Allow user to set `precision` in CWT, increase default to 12 #570
Conversation
Higher `precision` doesn't help with lower scales
Thanks, I am fine with exposing this parameter. I will take a look at the info you found soon and review it a little closer.
@grlee77 Some interesting findings here (see bottom), to my surprise. I'm curious about the rationale behind the integrated wavelet, which isn't documented on the site you found; perhaps it can be improved further.
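For reference, a minimal sketch of that mechanism, mirroring the internals that the script further below replicates (`integrate_wavelet` is pywt's own helper; the 64/256 scale thresholds assume `cmor2-2`'s default support of [-8, 8]):

```python
import numpy as np
import pywt
from pywt._functions import integrate_wavelet

wav = pywt.ContinuousWavelet('cmor2-2')
for precision in (10, 12):
    # the wavelet is sampled on 2**precision points over its support
    # (x in [-8, 8] here) and cumulatively integrated once, up front
    int_psi, x = integrate_wavelet(wav, precision=precision)
    step = x[1] - x[0]
    for scale in (16, 64, 256):  # a few illustrative scales
        # per scale, the convolution kernel is built by integer-indexing
        # int_psi, with index step ~ 2**precision / (scale * (x[-1] - x[0]))
        j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
        n_unique = len(np.unique(j.astype(int)))
        print(f"precision={precision}, scale={scale}: "
              f"kernel samples={len(j)}, distinct table samples={n_unique}")
# Once scale exceeds 2**precision / (x[-1] - x[0]) (64 for precision=10,
# 256 for precision=12), integer truncation repeats table entries and the
# kernel becomes a staircase, hence the distortion at high scales. Lower
# scales only skip entries, which is why raising precision doesn't change them.
```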
… implement the precision option
Hi @grlee77, is that PR still planned?
I've been meaning to write about this for a while, and I think I still intend to, but in short: I strongly discourage pywt's CWT; some info here.
*write to the team
Wow, that's a detailed analysis @OverLordGoldDragon, thanks for sharing. Does this PR help improve the behavior demonstrated in that StackExchange post?
Thanks @OverLordGoldDragon for the work. May I also ask about your ssqueezepy package: what would it cost to merge it with PyWavelets? I believe it addresses most of the drawbacks in your StackExchange post.
@rgommers Yes, but significant problems remain, and they are unsolvable with parameter changes. Sorry for being blunt, but I'd rather state it sooner than wait another year+ with many people still using it: PyWavelets' CWT is severely flawed and should be reimplemented entirely. In the meantime, a warning should be thrown against its use.
Here's the script used in the SE post, in messy but runnable form.

```python
# -*- coding: utf-8 -*-
"""
Demonstrate the utility of `wavespin.toolkit.validate_filterbank()` while
explaining wavelet filterbank concepts.
Code reproduces
"How to validate a wavelet filterbank (CWT)?", John Muradeli,
https://dsp.stackexchange.com/a/86069/50076
Code follows the order of the post, except for the scipy and PyWavelets
examples, moved to bottom.
"""
import numpy as np
# from wavespin import Scattering1D, TimeFrequencyScattering1D
# from wavespin.scattering1d.filter_bank import morlet_1d
# from wavespin.toolkit import validate_filterbank
# from wavespin.visuals import (plot, imshow, plotscat, filterbank_jtfs_1d,
# filterbank_scattering, filterbank_heatmap)
from ssqueezepy import cwt, ssq_cwt, TestSignals
from ssqueezepy.utils import cwt_scalebounds, make_scales
from ssqueezepy.visuals import plot, imshow, plotscat
from scipy import signal
from scipy.fft import fft, ifft, ifftshift
# import warnings
# warnings.filterwarnings("error")
#%% helper method ############################################################
# def validate(ts, **kw):
# psi1_f = [p[0] for p in ts.psi1_f]
# phi_f = ts.phi_f[0]
# kw['verbose'] = kw.get('verbose', 1)
# _ = validate_filterbank(psi1_f, phi_f, for_real_inputs=True,
# unimodal=True, **kw)
def to_numpy(ts):
for i, pf in enumerate(ts.psi1_f):
ts.psi1_f[i] = {k: (v.numpy() if hasattr(v, 'numpy') else v)
for k, v in pf.items()}
ts.phi_f = {k: (v.numpy() if hasattr(v, 'numpy') else v)
for k, v in ts.phi_f.items()}
#%% common configs ###########################################################
N = 2048
Q = (8, 8)
log2_N = int(np.log2(N))
ckw = dict(shape=N, Q=Q, max_order=1)#, frontend='numpy')
pkw = dict(lp_sum=1, lp_phi=1, plot_kw={'w': .9}, second_order=0)
hkw = dict(w=.8, h=.9)
#%% Baseline: Generic filterbank #############################################
# ts = Scattering1D(**ckw, max_pad_factor=1, analytic=False, normalize='l1',
# J=log2_N - 2)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# validate(ts)
# #%% Actually generic #########################################################
# # ("Generic" as in matching a naive CWT implementation)
# ts = Scattering1D(**ckw, max_pad_factor=1, analytic=False, normalize='l1',
# J=log2_N)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# filterbank_scattering(ts, **pkw, zoom=8)
# validate(ts)
# #%%
# filterbank_heatmap(ts, **hkw)
#%% Sufficiently padded #######################################################
# ts = Scattering1D(**ckw, max_pad_factor=None, analytic=False, normalize='l1',
# J=log2_N)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# filterbank_scattering(ts, **pkw, zoom=8)
# filterbank_heatmap(ts, **hkw)
#%% Energy norm, simple `sqrt(2)`
# ts = Scattering1D(**ckw, max_pad_factor=None, analytic=False, normalize='l1',
# J=log2_N)
# to_numpy(ts)
# for p in ts.psi1_f:
# p[0] *= np.sqrt(2)
# filterbank_scattering(ts, **pkw)
# filterbank_scattering(ts, **pkw, zoom=8)
# validate(ts)
#%% L2 norm ##################################################################
# ts = Scattering1D(**ckw, max_pad_factor=None, analytic=False, normalize='l2',
# J=log2_N)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# filterbank_scattering(ts, **pkw, zoom=8)
#%% Incomplete tiling ########################################################
# ts = Scattering1D(**ckw, max_pad_factor=None, analytic=False, normalize='l1',
# J=log2_N, r_psi=.0001)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# validate(ts)
# there isn't user-facing code for the under-tiling example, since the filterbank
# is designed to avoid this; it's achieved via
# `num_intermediate = Q` -> `num_intermediate = Q - 2` in
# `wavespin.scattering1d.filter_bank.py`.
# similar undertiling can be achieved with `Q - 1` but different parameters,
# kept things simple for this example and used an extreme `r_psi`.
#%% Incorrect frequency-bandwidth tiling #####################################
# f_min = N//100
# f_max = N//2 - 10
# f_all = np.round(np.logspace(np.log10(f_min), np.log10(f_max), 80)
# ).astype(int)
# pf_all = np.array([morlet_1d(N, xi=f/N, sigma=20/N) for f in f_all[:]])
# # zero-mean
# lp_sum = np.sum(np.abs(pf_all)**2, axis=0)
# _pkw = dict(w=.7, h=.9, show=1)
# plot(pf_all.real.T, color='tab:blue',
# title="Wavelet filterbank, exp-spaced freqs, const. bandwidth", **_pkw)
# plot(lp_sum, title="Littlewood-Paley sum", **_pkw, ylims=(0, None))
# _ = validate_filterbank(pf_all, for_real_inputs=True, unimodal=True, verbose=True)
#%% High redundancy ##########################################################
# ckw['Q'] = (16, 2)
# ts = Scattering1D(**ckw, max_pad_factor=None, analytic=0, normalize='l1-energy',
# J=log2_N, r_psi=.99)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# validate(ts)
#%% Proper filterbank
# ckw['Q'] = 8
# ts = Scattering1D(**ckw, max_pad_factor=None, analytic=False,
# normalize='l1-energy', J=log2_N)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# filterbank_scattering(ts, **pkw, zoom=8)
#%% Tight frame attempt
# ckw['Q'] = 256
# ts = Scattering1D(**ckw, max_pad_factor=None, analytic=1,
# normalize='l1-energy', J=log2_N, r_psi=.98)
# to_numpy(ts)
# filterbank_scattering(ts, **pkw)
# filterbank_scattering(ts, **pkw, zoom=10)
#%% Temporal peak
# pf = morlet_1d(N, xi=1.5/N, sigma=1.5/N)
# pt = ifftshift(ifft(pf))
# _ckw = dict(w=.7, h=.9)
# plot(pf.real, show=1, **_ckw, title="Low frequency Morlet, freq domain")
# plot(pt, complex=2, **_ckw, show=1, title="Time domain")
# _ = validate_filterbank([pf])
#%% Decay ####################################################################
# pf = morlet_1d(N, xi=200/N, sigma=10/N)
# pf += np.roll(pf, 100) / 10
# _ckw = dict(w=.7, h=.9)
# plot(pf.real, show=1, **_ckw, title="Non-permanently decayed wavelet")
# _ = validate_filterbank([pf])
#%% non-smooth decay case ####
# pf = morlet_1d(N, xi=200/N, sigma=10/N)
# slc = pf.copy()
# slc[:210] = 0
# pf[210:] = 0
# pf += np.roll(slc, 50)
# pf[210:260] = pf[209]
# pt = ifftshift(ifft(pf))
# _ckw = dict(w=.7, h=.9, show=1)
# plot(pf, **_ckw, title="Non-smoothly decayed wavelet; freq domain")
# idxs = np.arange(N//2-250, N//2+250+1)
# plot(idxs, pt[idxs], complex=2, **_ckw, title="Time domain (zoomed)")
# _ = validate_filterbank([pf])
#%% Aliasing #################################################################
# _N = 1024
# pf0 = morlet_1d(_N, xi=450/_N, sigma=40/_N)
# pf1 = fft(ifft(pf0)[::2]).real # imag stays zero
# pf2 = fft(ifft(pf0)[::4]).real
# _ckw = dict(w=.7, h=.9)
# plot(pf0.real, show=1, **_ckw, title="Morlet")
# plot(pf1.real, show=1, **_ckw, title="Morlet subsampled by 2")
# # need at least 6 filters for alias detector to work properly
# psi_fs = [pf0, pf1, pf2, fft(ifft(pf2)[::2]),
# fft(ifft(pf2)[::4]), fft(ifft(pf2)[::8])]
# _ = validate_filterbank(psi_fs)
#%% Analyticity ##############################################################
# ckw['Q'] = 16
# ts0 = Scattering1D(**ckw, max_pad_factor=None, analytic=False,
# normalize='l1', J=log2_N-7, r_psi=.99, T=1, smart_paths=0)
# ts1 = Scattering1D(**ckw, max_pad_factor=None, analytic=True,
# normalize='l1', J=log2_N-7, r_psi=.99, T=1, smart_paths=0)
# t = np.linspace(0, 1, N, 1)
# x = np.cos(2*np.pi * (N//2 - 16) * t)
# plot(x)
#%% CWT pure sine
# out0 = ts0(x)[1:]
# out1 = ts1(x)[1:]
# _ckw = dict(w=.7, h=.9, abs=1)
# imshow(out0, **_ckw, title="|cwt(x, strict_analytic=False)|")
# imshow(out1, **_ckw, title="|cwt(x, strict_analytic=True)|")
#%% SSQ_CWT, hchirp
# automated scale generation won't always reach largest possible scale,
# we exaggerate `N`
# _N = 4096
# min_scale, max_scale = cwt_scalebounds('gmw', N=3*_N, preset='maximal',
# use_padded_N=1)
# scales = make_scales(_N, min_scale, max_scale, nv=32, scaletype='log',
# wavelet='gmw')
# x = TestSignals().hchirp(N=_N, fmin=.25)[0]
# # the two wavelets are configured to yield similar time-frequency resolution
# Tx0, *_ = ssq_cwt(x, scales=scales, wavelet=('morlet', {'mu': 2.5}))
# Tx1, *_ = ssq_cwt(x, scales=scales, wavelet=('gmw', {'gamma': 1, 'beta': 1}))
# _ckw = dict(abs=1, w=.7, h=.9, yticks=0)
# imshow(Tx0, **_ckw, title="|SSQ_CWT(hyperbolic_chirp, 'morlet')|")
# imshow(Tx1, **_ckw, title="|SSQ_CWT(hyperbolic_chirp, 'gmw')|")
#%% time resolution & tail images
# See https://github.com/jonathanlilly/jLab/issues/13
#%% non-halved Nyquist ####
# pf = morlet_1d(N, xi=.48, sigma=40/N)
# pf[len(pf)//2+1:] = 0
# plot(pf)
# _ = validate_filterbank([pf, pf])
#%% Analytic & anti-analytic example #########################################
# jtfs = TimeFrequencyScattering1D(shape=N, Q=8, J=log2_N, J_fr=4, F=2**4,
# max_pad_factor_fr=None, max_pad_factor=None,
# frontend='numpy')
# _ = filterbank_jtfs_1d(jtfs, zoom=-1, lp_sum=1, plot_kw={'w': .9, 'h': .85})
#%% Non-zero phase, non-zero mean ############################################
# Show scipy's filterbank
# _N = 4096
# wavelet = signal.morlet2
# t = np.linspace(0, 1, _N, 1)
# data = np.hstack([np.cos(2*np.pi * 4 * t),
# np.cos(2*np.pi * len(t)//2 * t)])
# data = np.cos(2*np.pi * len(t)//3 * t)
# dtype = np.complex128
# # be fair to `widths` as scipy provides no bounds check and it's easy to
# # distort, but also not too fair as to stay far from Nyquist and DC,
# # as the implem should account for these cases
# widths = np.logspace(np.log10(1.5), np.log10(1600), 100)
# # replicate its convolution then assert equality
# output = np.empty((len(widths), len(data)), dtype=dtype)
# wavs = []
# for ind, width in enumerate(widths):
# Nw = np.min([10 * width, len(data)])
# wd = np.conj(wavelet(Nw, s=width)[::-1])
# output[ind] = signal.convolve(data, wd, mode='same')
# wavs.append(wd)
# # assert equality
# output_scipy = signal.cwt(data, wavelet, widths)
# assert np.allclose(output, output_scipy)
# # print report
# _ = validate_filterbank(wavs, for_real_inputs=True, unimodal=True,
# is_time_domain=True, criterion_amplitude=1e-3,
# verbose=1)
#%% show freq-domain filterbank ###
# recreate logic of `validate_filterbank`, which pads consistently with
# `np.convolve(, mode='same')` (same as `signal.convolve`)
# see further below for full validation ("Scipy: fully implement"...)
# wavs = [p.squeeze() for p in wavs]
# # fetch max length
# max_len = max(len(p) for p in wavs)
# # pad to next power of 2
# max_len_wav = int(2**(1 + np.round(np.log2(max_len))))
# # take to freq or pad to max length
# _wavs_f = [] # store processed filters
# for p in wavs:
# if len(p) != max_len_wav:
# orig_len = len(p)
# p = np.pad(p, [0, max_len_wav - orig_len])
# center_idx = int(np.ceil(orig_len / 2))
# p = np.roll(p, -(center_idx - 1))
# p = fft(p)
# else:
# center_idx = int(np.ceil(len(p) / 2))
# p = np.roll(p, -(center_idx - 1))
# p = fft(p)
# _wavs_f.append(p)
# wavs_f = np.array(_wavs_f)
# plot(wavs_f.T, complex=1, show=1,
# title="scipy.morlet2 filterbank, real & imag parts")
#%% show example wavelets near DC & Nyquist ####
# _kw = dict(complex=1, show=1, w=.6, h=.8)
# plot(wavs_f[1], **_kw, title="Example wavelet: nonzero phase")
# plotscat(wavs_f[-3][:20], **_kw, title="Example wavelet: nonzero mean")
#%% zoom away from boundary effects, and account for ssqueezepy's conservative
# high freq coverage (which it can do safely, unlike scipy)
# _ckw = dict(w=.8, h=.9)
# tidxs = np.arange(40, 240)
# imshow(output[:, tidxs], xticks=tidxs,
# title="scipy.cwt of high-freq pure sine, real part, zoomed", **_ckw)
# output_s = cwt(data, padtype='zero')[0]
# fidxs = np.arange(14, 100)
# imshow(output_s[fidxs][:, tidxs], title="ssqueezepy.cwt, real part, zoomed",
# **_ckw, yticks=fidxs, xticks=tidxs)
#%% show individual rows, overlapped
# _ckw = dict(color='tab:blue', show=1, yticks=0, w=.7, h=.9)
# idx = np.argmax(np.sum(np.abs(output.real)**2, axis=-1))
# plot(output[idx-7:idx+7].real.T[20:40], **_ckw,
# title="scipy.cwt, real part, 14 high-energy rows, zoomed")
# idx = np.argmax(np.sum(np.abs(output_s.real)**2, axis=-1))
# plot(output_s[idx-7:idx+7].real.T[20:40],
# title="ssqueezepy.cwt, real part, 14 high-energy rows, zoomed", **_ckw)
#%% zero-mean case ####
# _N = 4096
# t = np.linspace(0, 1, _N, 1)
# x = np.cos(2*np.pi * 96 * t)
# v0 = np.cos(2*np.pi * .5 * t) / 1.3
# v1 = np.ones(len(t))
# x0 = x + v0
# x1 = x + v1
# # implement padding, be extra safe
# x0p = np.pad(x0, 2*_N, mode='reflect')
# x1p = np.pad(x1, 2*_N, mode='reflect')
# morl_fn = lambda N, *a, **k: signal.morlet2(min(N, _N), w=5, *a, **k)
# out0 = signal.cwt(x0p, morl_fn, widths)
# out1 = signal.cwt(x1p, morl_fn, widths)
# mx = max(np.abs(out0).max(), np.abs(out1).max())
# _ckw = dict(abs=1, w=.7, h=.9, norm=(0, mx), yticks=0)
# # show unpadded
# imshow(out0[:, 2*_N:-2*_N], **_ckw, title="|scipy.cwt(x0)|")
# imshow(out1[:, 2*_N:-2*_N], **_ckw, title="|scipy.cwt(x1)|")
#%% show the signals
# _ckw = dict(w=.6, h=.8, show=1)
# plot(x0, **_ckw, title="x0")
# plot(x1, **_ckw, title="x1")
#%% Scipy: fully implement output-matching demonstration #####################
# Show that the frequency domain values shown above match scipy's unpadded
# procedure that's done via `convolve`
# x = np.random.randn(len(data))
# output_fftconv = np.empty((len(widths), len(x)), dtype=dtype)
# pad_right = (max_len_wav - len(x)) // 2
# pad_left = max_len_wav - len(x) - pad_right
# xp = np.pad(x, [pad_left, pad_right])
# xpf = fft(xp)
# for ind, wav_f in enumerate(wavs_f):
# o = ifft(wav_f * xpf)
# output_fftconv[ind] = o[pad_left:-pad_right]
# # assert equality
# output_scipy = signal.cwt(x, wavelet, widths)
# assert np.allclose(output_fftconv, output_scipy)
# note we don't (and can't) account for unpadding, so the match isn't exact,
# but it's an excellent approximation (and likely the best we can do in the
# general case due to aliasing)
#%% Bonus: PyWavelets ########################################################
# replicate pywt internals to demonstrate flawed sampling
import pywt
from pywt._extensions._pywt import DiscreteContinuousWavelet
from pywt._functions import integrate_wavelet
wavelet_name = 'cmor2-2'
wavelet = DiscreteContinuousWavelet(wavelet_name)
scales = np.logspace(np.log10(4.1), np.log10(1000), 40)
int_psi, x_ = integrate_wavelet(wavelet, precision=10)
int_psi = np.conj(int_psi)
psis = []
psis_nol1 = []
max_len_wav = None
for scale in scales[::-1]:
step = x_[1] - x_[0]
j = np.arange(scale * (x_[-1] - x_[0]) + 1) / (scale * step)
j = j.astype(int)
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
p = int_psi[j][::-1]
p = -np.diff(p)
if max_len_wav is None:
# set based on largest scale
max_len_wav = int(2**(1 + np.round(np.log2(len(p)))))
# pad to common length, center about n=0 based on length, take to freq, append
# repeat scipy's logic per pywt's default `method='conv'`
if len(p) < max_len_wav:
orig_len = len(p)
p = np.pad(p, [0, max_len_wav - orig_len])
center_idx = int(np.ceil(orig_len / 2))
p = np.roll(p, -(center_idx - 1))
else:
center_idx = int(np.ceil(len(p) / 2))
p = np.roll(p, -(center_idx - 1))
pf = fft(p)
psis_nol1.append(pf.copy())
pf /= np.abs(pf).max()
psis.append(pf)
psis = np.array(psis)
plot(psis.T, abs=1, color='tab:blue', show=1, w=.8,
title="'cmor2-2' filterbank; L1-normed; abs")
plot(psis.T.real, color='tab:blue', show=0, w=.8)
plot(psis.T.imag, color='tab:orange', show=1, w=.8,
title="'cmor2-2' filterbank; L1-normed; real & imag")
plot(psis[-1].real, color='tab:blue', w=.8)
plot(psis[-1].imag, color='tab:orange', show=1, w=.8,
title="Highest freq filter of same filterbank; real & imag")
plot(psis[0].real, color='tab:blue', w=.8)
plot(psis[0].imag, color='tab:orange', show=1, w=.8,
title="Lowest freq filter of same filterbank; real & imag")
#%% PyWavelets: fully implement output-matching demonstration ################
N = 8192
x = np.random.randn(N)
output_fftconv_pywt = np.zeros((len(scales), N), dtype='complex128')
pad_right = (max_len_wav - N) // 2
pad_left = max_len_wav - N - pad_right
xp = np.pad(x, [pad_left, pad_right])
xpf = fft(xp)
for i, (scale, pf) in enumerate(zip(scales, psis_nol1[::-1])):
c = ifft(pf * xpf)[pad_left:-pad_right] * np.sqrt(scale)
output_fftconv_pywt[i] = c
# assert equality
output_pywt = pywt.cwt(x, scales, wavelet_name)[0]
assert np.allclose(output_fftconv_pywt, output_pywt)
```
Addresses this Issue. Detailed explanation here; a summary:
1. `precision` distorts CWT, and heavily at high scales.
2. Scales above 64 are increasingly distorted.
3. `precision=10` is low, and `precision=12` does not add much computation cost, while considerably remedying 1-3.

At the least, the user should get to decide `precision` instead of carving it in stone, and a doc/comment should be made about these caveats.

Improvement comparison
Code
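A rough sketch of the cost point (not the attached comparison code, and not a benchmark; it reuses the same `integrate_wavelet` internals as the script earlier in the thread): raising `precision` from 10 to 12 quadruples the precomputed wavelet table, but the per-scale kernel length, and hence the convolution work, is set by the scale alone.

```python
import numpy as np
from pywt._extensions._pywt import DiscreteContinuousWavelet
from pywt._functions import integrate_wavelet

wavelet = DiscreteContinuousWavelet('cmor2-2')
scale = 500  # a large scale, where the current precision=10 distorts heavily
for precision in (10, 12):
    int_psi, x = integrate_wavelet(wavelet, precision=precision)
    step = x[1] - x[0]
    j = (np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)).astype(int)
    j = j[j < int_psi.size]
    kernel = int_psi[j][::-1]  # what gets convolved with the signal (pre-diff)
    print(f"precision={precision}: table length={int_psi.size}, "
          f"kernel length={kernel.size}")
# The table grows 1024 -> 4096 (a one-time cost), while the kernel stays at
# ~scale * 16 samples either way, so the per-scale convolution cost is
# essentially unchanged at precision=12.
```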