# Source code for wdmwavelet

"""Wilson-Daubechies-Meyer (WDM) wavelet transform.
"""
import numpy as np
from scipy.special import betainc
from math import ceil

# Double-precision numerical constants shared by the transforms below.
PI = float(np.pi)             # pi, as a plain Python float
SQRT2 = float(np.sqrt(2.0))   # sqrt(2); IEEE sqrt is correctly rounded

def _rolloff_idx(flat_width, trans_width):
    """Computer start and end index of Meyer roll-off edge.
    
    Parameters
    ----------
    flat_width : float
        Width of half flat interval.
    trans_width : float
        Width if transition interval.

    Returns
    -------
    (sidx, eidx) : tuple
        Start and end indices of transition interval.
    """
    sidx = ceil(flat_width)
    eidx = ceil(flat_width + trans_width)
    return sidx, eidx


def _gen_rolloff(flat_width, trans_width, steepness):
    """Sample the Meyer roll-off edge on integer indices.

    The edge value at integer index ``k`` in ``[sidx, eidx)`` is
    ``cos(pi / 2 * I(s, s; (k - flat_width) / trans_width))``, where ``I``
    is ``scipy.special.betainc`` and ``s`` is ``steepness``.

    Parameters
    ----------
    flat_width : float
        Width of the flat half-interval preceding the edge.
    trans_width : float
        Width of the transition interval (must be non-zero).
    steepness : float
        Used as both shape parameters of ``betainc``.

    Returns
    -------
    ((sidx, eidx), edge) : tuple
        Index range from ``_rolloff_idx`` and the sampled edge values,
        one per index in ``range(sidx, eidx)``.
    """
    bounds = _rolloff_idx(flat_width, trans_width)
    positions = np.arange(*bounds, dtype=int)
    frac = (positions - flat_width) / trans_width
    edge = np.cos(PI / 2 * betainc(steepness, steepness, frac))
    return bounds, edge

def rwdm(x, nt, nf, trans_width, steepness, axis=-1):  # TODO: delete nt or nf
    """Compute WDM transform for real input.

    Parameters
    ----------
    x : array_like
        Real input array. Its length along ``axis`` must be ``nt * nf``.
    nt : int
        N_t. Spacing, in ``rfft`` bins, between the centres of adjacent
        frequency bands; every output row holds ``2 * nt`` coefficients.
        Must be >= 1.
    nf : int
        N_f. Must be a positive even integer; the output has ``nf // 2``
        rows.
    trans_width : float
        Width of the transition interval, in ``rfft`` bins. Values below
        1 are raised to 1; values above ``nt`` are rejected.
    steepness : float
        Steepness of the transition interval (both shape parameters of
        the ``betainc`` roll-off).
    axis : int
        Axis to be transformed. It is moved to the end and not restored.

    Returns
    -------
    w : np.ndarray
        WDM coefficients, stored in the last two axes with shape
        ``(nf // 2, 2 * nt)``; the leading axes are the remaining axes of
        ``x``. Row ``m`` for ``1 <= m < nf // 2`` holds band ``m``; row 0
        holds band 0 in its even columns and band ``nf // 2`` in its odd
        columns.

    Raises
    ------
    ValueError
        If ``nf`` is odd or < 2, ``nt`` < 1, ``trans_width`` > ``nt``,
        ``x`` is 0-d, or the transformed axis length differs from
        ``nt * nf``.
    """
    if nf & 1:
        raise ValueError("nf must be even.")
    if nf < 2 or nt < 1:
        raise ValueError("nt must be >= 1 and nf must be >= 2.")
    if trans_width < 1.:
        trans_width = 1.
    if trans_width > nt:
        # flat_width would become negative and the band windows would no
        # longer fit between their neighbours.
        raise ValueError("trans_width must not exceed nt.")
    flat_width = (nt - trans_width) / 2

    x = np.asarray(x)
    if x.ndim == 0:
        raise ValueError("x must have at least one dimension.")
    if axis != -1:
        x = np.moveaxis(x, axis, -1)
    if x.shape[-1] != nt * nf:
        # Every band below slices the rfft assuming exactly nt * nf samples;
        # any other length would silently produce wrong coefficients.
        raise ValueError(
            f"length of the transformed axis ({x.shape[-1]}) must equal "
            f"nt * nf ({nt * nf})."
        )

    (trans_sidx, trans_eidx), edge = _gen_rolloff(flat_width, trans_width,
                                                  steepness)
    edge_rev = edge[::-1]

    # Edge index ranges inside a local band window whose centre is at local
    # index nt: "_r" is the upper edge, "_l" the mirrored lower edge.
    trans_sidx_r = trans_sidx + nt
    trans_eidx_r = trans_eidx + nt
    trans_sidx_l = -trans_eidx + nt + 1
    trans_eidx_l = -trans_sidx + nt + 1

    # Transform.
    xf = np.fft.rfft(x, norm="ortho")
    ans = np.empty(xf.shape[:-1] + (nf // 2, 2 * nt), dtype=float)

    # m == 0: bins nt..0 in reverse, conjugated (local index j <-> bin nt - j).
    xf_ = xf[..., nt::-1].conj()
    xf_[..., trans_sidx_l:trans_eidx_l] *= edge_rev
    xf_[..., :trans_sidx_l] = 0.
    ans0 = np.fft.irfft(xf_, norm="ortho")
    ans[..., 0, ::2] = ans0[..., ::2] * SQRT2

    # 1 <= m < nf//2: window over rfft bins [(m - 1) * nt, (m + 1) * nt).
    for m in range(1, nf // 2):
        k_sidx = (m - 1) * nt
        k_eidx = k_sidx + 2 * nt
        xf_ = xf[..., k_sidx:k_eidx].copy()
        xf_[..., trans_sidx_r:trans_eidx_r] *= edge
        xf_[..., trans_sidx_l:trans_eidx_l] *= edge_rev
        xf_[..., trans_eidx_r:] = 0.
        xf_[..., :trans_sidx_l] = 0.
        ans_m = np.fft.ifft(xf_, norm="ortho")
        # Samples of parity neo keep 2*imag, the others 2*real; when
        # neo == 0 the even (imaginary) samples are also negated.
        neo = 1 ^ (m & 1)
        ans_m[..., neo::2].real = ans_m[..., neo::2].imag
        ans_m = 2 * ans_m.real
        if neo == 0:
            ans_m[..., ::2] *= -1
        ans[..., m, :] = ans_m

    # m == nf//2: the last nt + 1 rfft bins, up to Nyquist.
    xf_ = xf[..., -(nt + 1):].copy()
    xf_[..., trans_sidx_l:trans_eidx_l] *= edge_rev
    xf_[..., :trans_sidx_l] = 0.
    ans1 = np.fft.irfft(xf_, norm="ortho")
    # Keep the samples whose parity matches that of nf//2.
    eo = (nf // 2) & 1
    ans[..., 0, 1::2] = ans1[..., eo::2] * SQRT2
    return ans  # TODO: restore moved axis
def irwdm(w, trans_width, steepness):  # TODO: axis.
    """Compute the inverse of real WDM transform ``rwdm``.

    Parameters
    ----------
    w : array_like
        WDM coefficients as produced by ``rwdm``, stored in the last two
        axes with shape ``(nf // 2, 2 * nt)``; ``nt`` and ``nf`` are
        inferred from this shape.
    trans_width : float
        Width of transition interval; must match the value given to
        ``rwdm``. Values below 1 are raised to 1, exactly as ``rwdm`` does.
    steepness : float
        Steepness of transition interval; must match ``rwdm``.

    Returns
    -------
    np.ndarray
        Real signal of length ``nt * nf`` along the last axis; the leading
        axes are those of ``w`` before its last two.

    Raises
    ------
    ValueError
        If ``w`` has fewer than two dimensions, an empty second-to-last
        axis, an odd or too-short last axis, or if ``trans_width`` > ``nt``.
    """
    w = np.asarray(w)
    if w.ndim < 2:
        raise ValueError("w must have at least two dimensions.")
    w_shape = w.shape
    if w_shape[-2] < 1 or w_shape[-1] < 2 or w_shape[-1] % 2:
        raise ValueError(
            "w must have shape (..., nf // 2, 2 * nt) with nf // 2 >= 1 "
            "and nt >= 1."
        )
    nt = w_shape[-1] // 2
    nf = w_shape[-2] * 2
    # Same clamping as rwdm so both directions build the same window.
    if trans_width < 1.:
        trans_width = 1.
    if trans_width > nt:
        raise ValueError("trans_width must not exceed nt.")
    flat_width = (nt - trans_width) / 2
    n = nf * nt

    xf = np.zeros(w_shape[:-2] + (n // 2 + 1,), dtype=complex)
    (trans_sidx, trans_eidx), edge = _gen_rolloff(flat_width, trans_width,
                                                  steepness)
    edge_rev = edge[::-1]
    # Edge index ranges inside a local band window centred at index nt.
    trans_sidx_r = trans_sidx + nt
    trans_eidx_r = trans_eidx + nt
    trans_sidx_l = -trans_eidx + nt + 1
    trans_eidx_l = -trans_sidx + nt + 1

    # m == 0: even columns of row 0.
    w_ = np.fft.fft(w[..., 0, ::2], norm="ortho")
    xf[..., :trans_sidx] = w_[..., :trans_sidx]
    xf[..., trans_sidx:trans_eidx] += w_[..., trans_sidx:trans_eidx] * edge

    # 1 <= m < nf//2: re-apply rwdm's sample-parity pattern (multiply by 1j
    # at parity neo, negate even samples when neo == 0), then place the
    # band spectrum at rfft offset (m - 1) * nt.
    for m in range(1, nf // 2):
        k_shift = (m - 1) * nt
        k_sidx_r = trans_sidx_r + k_shift
        k_eidx_r = trans_eidx_r + k_shift
        k_sidx_l = trans_sidx_l + k_shift
        k_eidx_l = trans_eidx_l + k_shift
        neo = 1 ^ (m & 1)
        w_ = np.array(w[..., m, :], dtype=complex)
        w_[..., neo::2] *= 1j
        if neo == 0:
            w_[..., ::2] *= -1
        w_ = np.fft.fft(w_, norm="ortho")
        xf[..., k_sidx_l:k_eidx_l] += (
            w_[..., trans_sidx_l:trans_eidx_l] * edge_rev)
        xf[..., k_eidx_l:k_sidx_r] = w_[..., trans_eidx_l:trans_sidx_r]
        xf[..., k_sidx_r:k_eidx_r] += w_[..., trans_sidx_r:trans_eidx_r] * edge

    # m == nf//2: odd columns of row 0, placed in the last nt + 1 rfft bins.
    k_shift = (nf // 2 - 1) * nt
    k_sidx_l = trans_sidx_l + k_shift
    k_eidx_l = trans_eidx_l + k_shift
    # rwdm keeps the samples of parity eo of this band's length-2*nt signal.
    # Odd samples (eo == 1) need a half-sample phase ramp, and local bin nt
    # aliases bin 0 with factor -1. Even samples (eo == 0) need no ramp and
    # alias with factor +1. Always using the odd-sample correction would be
    # wrong whenever nf//2 is even.
    eo = (nf // 2) & 1
    w_ = np.fft.fft(w[..., 0, 1::2], norm="ortho")
    if eo:
        w_ = w_ * np.exp(-1j * PI * np.arange(nt) / nt)
        nyquist = -w_[..., 0]
    else:
        nyquist = w_[..., 0]
    xf[..., k_sidx_l:k_eidx_l] += w_[..., trans_sidx_l:trans_eidx_l] * edge_rev
    xf[..., k_eidx_l:-1] = w_[..., trans_eidx_l:]
    xf[..., -1] = nyquist

    return np.fft.irfft(xf, n=n, norm="ortho")
# Debug # nt = 6 # nf = 6 # n = nt * nf # x = np.arange(2*n).reshape(2,-1) # trans_width = 2 # w = rwdm(x, nt, nf, trans_width, 2) # xf_ = np.fft.rfft(x, norm="ortho") # xf = irwdm(w, trans_width, 2) # pass