Skip to content

Commit 7dfe0d7

Browse files
authored
Merge pull request #37 from CosmoStat/needlet
add needlet filters
2 parents 3c0fb97 + 8ccf857 commit 7dfe0d7

1 file changed

Lines changed: 116 additions & 69 deletions

File tree

pycs/sparsity/mrs/mrs_starlet.py

Lines changed: 116 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,61 @@
1717
Related Geometric Multiscale Analysis,
1818
Cambridge University Press, Cambridge (GB), 2016.
1919
20-
Example how to use the Class:
20+
Example how to use the Class with 5 scales i.e. 4 wavelet scales + coarse resolution):
2121
CW = MRS_starlet() # Create the class
22+
CW.init_starlet(Nside, nscale=5)
2223
CW.transform(Image) # Starlet transform of a 2D np array
2324
CW.stat() # print statistics of all scales
2425
r = CW.recons() # reconstruct an image from its coefficients
26+
CW.plot_filter() # plot the filters in harmonic space which are used in the wavelet decomposition
2527
more examples are given at the end of this file.
28+
29+
Class variables are:
30+
nx = 0 # number of pixel of the healpix map
31+
ns = 0 # number of scales
32+
coef = 0.0 # Starlet coefficients
33+
TabNorm = 0.0 # Coefficient normalixation table
34+
SigmaNoise = 1.0 # noise standard deviation
35+
TabNsigma = 0 # detection level per scale
36+
verb = False
37+
nside = 0 # nside of the input map
38+
lmax = 0 # lmax used in spherical harmonic decomposition
39+
ALM_iter = 0 # numnber of iteration for the inverse spherical harmonic decomposition
40+
TabNameCode = ["Full python", "c++ Binding", "c++ binary"]
41+
TypeCode = 0 # 0 for 'Full python', '1' for 'c++ Binding' and 2 for'c++ binary'
42+
Tablmax =0 # lmax for each scale
43+
TabSigma =0 # Standard deviation of the Gaussian which fits the scaling function at every scale
44+
TabPhi = 0 # Scaling function for each scale
45+
TabPsi = 0 # Wavelet function for each scale
46+
Tabh = 0 # h filter for each scale
47+
Tabg = 0 # g filter for each scale
48+
TabResol = 0 # Resolution of each wavelet scale in arc minute
49+
PixelResol = 0 # pixel sizr in arc minute
50+
l2norm = False # if True, normlaize the coefficients (l2 normalization) such that the noise standart deviation remains constant through the scales.
51+
NeedletFilter = False # If True, use needlet filters instead of spline filters
52+
53+
Class functions are:
54+
def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None, Needlet=None):
55+
def info(self): # Print information relative to the intialisation.
56+
def stat(self): # Print Min, Max, Mean and standard deviation of all scales.
57+
def plot_filter(self, wavelet=True, scaling=False, hfilter=False, gfilter=False): # plot the filters which are used.
58+
def transform(self, data, WTname=None, opt=None): # Compute the wavelet transform
59+
WaveletScale = def get_scale(self, j): # returns the wavelet coefficients at a given scale : return self.coef[j,:]
60+
def put_scale(self, ScaleCoef, j): # insert a scale in the wavelet transform: self.coef[j,:] = ScaleCoef
61+
Rec = recons(self): reconstruct an image from its wavelet coefficients
62+
DenoiseMap = denoising(self, data, SigmaNoise=0, Nsigma=3, ThresCoarse=False, hard=True): # perform a denoising in wavelet space
63+
64+
def tvs(self,j,min=None,max=None,title=None,sigma=None,lut=None,filename=None, dpi=100) : plot the scale j
65+
def tv(self, log=False, unit="", title="", minimum=None, maximum=None, cbar=True): plot all wavelet scales
66+
def dump(self) : print all variable values of the class
67+
SigmaNoise = get_noise(self) : estimate the noise standard deviation from the first wavelet scale
68+
TabNsigma = get_tabsigma(self, Nsigma=3) # Cretate the table for the detection level. Nsigma can be either a number of an array.
69+
def threshold(self,SigmaNoise=0,Nsigma=3,ThresCoarse=False,hard=True,FirstDetectScale=0,KillCoarse=False,Verbose=False): # Threshold the wavelet coefficient
70+
CopiedClass = copy(self, name="wt"): # Return a copy of the class
71+
def eval_computation_time(self): # Compare the computation time of different implementations of the wavelet transform
2672
"""
2773

28-
import importlib.util
29-
30-
spec = importlib.util.find_spec("pymrs")
31-
if spec is None:
32-
# print("pymrs is available at:", spec.origin)
33-
MRS_CXX = False
34-
else:
35-
pymrs = importlib.util.module_from_spec(spec)
36-
spec.loader.exec_module(pymrs)
37-
MRS_CXX = True
38-
39-
40-
import numpy as np
41-
import random
42-
import os, sys
43-
from scipy import ndimage
44-
import healpy as hp
45-
from astropy.io import fits
46-
import matplotlib.pyplot as plt
47-
from astropy.io import fits
48-
from importlib import reload
49-
from pycs.misc.cosmostat_init import *
50-
from pycs.misc.stats import *
51-
from pycs.misc.mr_prog import *
52-
from pycs.sparsity.mrs.mrs_tools import *
53-
import getpass
54-
import time
55-
from scipy.optimize import curve_fit
74+
5675

5776

5877

@@ -67,13 +86,14 @@ def test_mrs_class(Init=None):
6786
else:
6887
Nside = 1024
6988
d = np.random.normal(size=(Nside**2 * 12))
70-
Ns = 5
89+
Ns = 7
7190
ALM_iter = 0
91+
Needlet=False
92+
lmax=2048
7293
C = CMRStarlet()
73-
C.init_starlet(Nside, nscale=Ns, ALM_iter=ALM_iter)
94+
C.init_starlet(Nside, nscale=Ns, ALM_iter=ALM_iter, lmax=lmax, Needlet=Needlet)
7495
print("PYTHON Code computation time:")
7596
start = time.time()
76-
C.TypeCode=1
7797
C.transform(d)
7898
end = time.time()
7999
print(f"Execution time python: {end - start:.4f} seconds")
@@ -197,7 +217,7 @@ class CMRStarlet:
197217
SigmaNoise = 1.0 # noise standard deviation
198218
TabNsigma = 0 # detection level per scale
199219
verb = False
200-
nside = 0
220+
nside = 0
201221
lmax = 0
202222
ALM_iter = 0
203223
TabNameCode = ["Full python", "c++ Binding", "c++ binary"]
@@ -211,6 +231,7 @@ class CMRStarlet:
211231
TabResol = 0 # Resolution of each wavelet scale in arc minute
212232
PixelResol = 0 # pixel sizr in arc minute
213233
l2norm = False
234+
NeedletFilter = False # If True, use need filters instead of spline filters
214235

215236
# __init__ is the constructor
216237
def __init__(self, name="wt", verb=False):
@@ -230,7 +251,7 @@ def __init__(self, name="wt", verb=False):
230251
self.name = name # self.name is an object variable
231252
self.verb = verb
232253

233-
def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
254+
def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None, Needlet=None):
234255
"""
235256
Initialize the scale for a given image size and a number of scales.
236257
Parameters
@@ -245,6 +266,10 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
245266
-------
246267
None.
247268
"""
269+
if Needlet:
270+
self.NeedletFilter=True
271+
self.TypeCode = 0
272+
248273
self.nside = np.int64(nside)
249274
self.nx = 12 * self.nside * self.nside
250275
if lmax != 0:
@@ -265,6 +290,8 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
265290
else:
266291
if nscale == 0:
267292
nscale = np.int64(np.log(self.nside) // 1) + 1
293+
self.TabSigma = np.zeros(nscale)
294+
268295
self.ns = np.int64(nscale)
269296

270297
if ALM_iter != 0:
@@ -274,12 +301,27 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
274301
CMRS = pymrs.MRS()
275302
CMRS.alloc(nside, self.ns, self.lmax, self.ALM_iter, self.verb)
276303

277-
if TabResolSigma is None:
278-
self.Tablmax, self.TabSigma, self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_default_filters(self.nside, self.ns)
279-
# print("Default TabSigma = ", self.TabSigma)
304+
if Needlet is None:
305+
if TabResolSigma is None:
306+
self.Tablmax, self.TabSigma, self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_default_filters(self.nside, self.ns)
307+
# print("Default TabSigma = ", self.TabSigma)
308+
else:
309+
self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_sigmafilters(self.TabSigma, self.lmax, Phi0Spline=False)
280310
else:
281-
self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_sigmafilters(self.TabSigma, self.lmax, Phi0Spline=False)
282-
311+
filters = mrs_needlet_filters(self.lmax, NbrScale=self.ns)
312+
self.Tabh = filters["TabFilterH"]
313+
self.Tabg = filters["TabFilterG"]
314+
self.TabPhi = filters["TabPhi"]
315+
self.TabPsi = filters["TabPsi"]
316+
self.Tablmax = np.zeros((self.ns))
317+
lm = self.lmax
318+
for j in range(0,self.ns):
319+
self.Tablmax[j] = lm
320+
lm = lm / 2
321+
self.TabSigma = splinelmax2sigma(self.Tablmax)
322+
323+
# print("ns = ", self.ns, self.TabSigma.shape )
324+
283325
self.TabResol = np.zeros(self.ns)
284326
self.TabResol[0] = self.TabSigma[0]
285327
for j in range(1,self.ns-1):
@@ -295,11 +337,12 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
295337
DeltaPhi0 = np.sqrt(1. - Phi0Lmax**2)
296338
# the alm m around lmax are often not accurate, and 1 if a conservative value, the
297339
# the correct value is most likely between 0.9 and 1.
298-
self.TabNorm[0] = np.sqrt( DeltaPhi0**2 + (sigma_filter(self.TabPsi[:,0], self.nside, lmax=self.lmax, PixelWindow=PixelWindow))**2)
299-
for j in range(1,self.ns):
340+
for j in range(0,self.ns):
300341
self.TabNorm[j] = sigma_filter(self.TabPsi[:,j], self.nside, lmax=self.lmax, PixelWindow=PixelWindow)
301-
302-
def info(self): # sound is a method (a method is a function of an object)
342+
if Needlet is None:
343+
self.TabNorm[0] = np.sqrt( DeltaPhi0**2 + (sigma_filter(self.TabPsi[:,0], self.nside, lmax=self.lmax, PixelWindow=PixelWindow))**2)
344+
345+
def info(self):
303346
"""
304347
Print information relative to the intialisation.
305348
"""
@@ -377,7 +420,10 @@ def transform(self, data, WTname=None, opt=None):
377420
self.coef = mrs_uwttrans(im, self.ns, self.lmax, opt=opt, verbose=self.verb, path="./", cxx=True)
378421
else:
379422
# print(self.ns, self.TabPhi.shape)
380-
self.coef = wt_phi_filter_trans(im, self.TabPhi)
423+
if self.NeedletFilter is False:
424+
self.coef = wt_phi_filter_trans(im, self.TabPhi)
425+
else:
426+
self.coef = mrs_needlet_transform(im, self.TabPsi)
381427
# self.coef = mrs_uwttrans(im,self.ns,self.lmax,opt=None,verbose=self.verb,path="./",cxx=False)
382428

383429
if self.l2norm is True:
@@ -400,7 +446,12 @@ def recons(self):
400446
if self.l2norm is True:
401447
for j in range(self.ns):
402448
self.coef[j, :] = self.coef[j, :] * self.TabNorm[j]
403-
return np.sum(self.coef, axis=0)
449+
450+
if self.NeedletFilter is False:
451+
rec = np.sum(self.coef, axis=0)
452+
else:
453+
rec = mrs_needlet_recons(self.coef, self.TabPsi)
454+
return rec
404455

405456
def denoising(self, data, SigmaNoise=0, Nsigma=3, ThresCoarse=False, hard=True):
406457
"""
@@ -435,6 +486,21 @@ def denoising(self, data, SigmaNoise=0, Nsigma=3, ThresCoarse=False, hard=True):
435486
)
436487
return self.recons()
437488

489+
def get_scale(self, j):
490+
"""
491+
Return the scale j in self.coef
492+
Parameters
493+
----------
494+
j : int
495+
Scale number. It must be in [0:self.ns].
496+
Returns
497+
-------
498+
None.
499+
500+
"""
501+
return self.coef[j, :]
502+
503+
438504
def put_scale(self, ScaleCoef, j):
439505
"""
440506
Replace the scale j in self.coef by the 2D array ScaleCoef.
@@ -451,17 +517,8 @@ def put_scale(self, ScaleCoef, j):
451517
"""
452518
self.coef[j, :] = ScaleCoef
453519

454-
def tvs(
455-
self,
456-
j,
457-
min=None,
458-
max=None,
459-
title=None,
460-
sigma=None,
461-
lut=None,
462-
filename=None,
463-
dpi=100,
464-
):
520+
521+
def tvs(self,j,min=None,max=None,title=None,sigma=None,lut=None,filename=None, dpi=100):
465522
"""
466523
Display the scale j
467524
Parameters
@@ -561,16 +618,7 @@ def get_tabsigma(self, Nsigma=3):
561618
TabNsigma = Nsigma[:nscale]
562619
return TabNsigma
563620

564-
def threshold(
565-
self,
566-
SigmaNoise=0,
567-
Nsigma=3,
568-
ThresCoarse=False,
569-
hard=True,
570-
FirstDetectScale=0,
571-
KillCoarse=False,
572-
Verbose=False,
573-
):
621+
def threshold(self,SigmaNoise=0,Nsigma=3,ThresCoarse=False,hard=True,FirstDetectScale=0,KillCoarse=False,Verbose=False):
574622
"""
575623
Apply a hard or a soft thresholding on the coefficients self.coef
576624
Parameters
@@ -658,6 +706,7 @@ def copy(self, name="wt"):
658706
x = self
659707
x.name = name
660708
x.coef = np.zeros((x.ns, x.nx))
709+
x.coef[:,:] = self.coef[:,:]
661710
x.TabNorm = np.copy(self.TabNorm)
662711
return x
663712

@@ -928,7 +977,7 @@ def plotsig(T, x=None, title="Spherical wavelet Filters", xlabel="X", ylabel="T[
928977

929978
plt.figure(figsize=(10, 6))
930979

931-
plt.plot(x, T / T.max(), label=f"{legend_prefix}")
980+
plt.plot(x, T , label=f"{legend_prefix}")
932981

933982
plt.title(title)
934983
plt.xlabel(xlabel)
@@ -1120,8 +1169,6 @@ def test_wt_hfilter_trans():
11201169
w1 = wt_trans(d, nscales=4)
11211170
return wts
11221171

1123-
1124-
11251172
def get_sigma_from_spline(lmax, hfilter=False):
11261173
# Range of l values (spherical harmonics degrees)
11271174
l_vals = np.arange(0, lmax+1)

0 commit comments

Comments
 (0)