1212from pycs .misc .mr_prog import *
1313from pycs .sparsity .mrs .mrs_tools import *
1414
15+
1516def mrs_starlet (map , nscale = None , lmax = None ):
1617 nside = gnside (map )
1718 if nscale is None :
@@ -20,16 +21,26 @@ def mrs_starlet(map, nscale=None, lmax=None):
2021 Ns = nscale
2122
2223 npix = map .shape [0 ]
23- w = wt_trans (map , lmax = lmax ,nscales = Ns - 1 )
24+ w = wt_trans (map , lmax = lmax , nscales = Ns - 1 )
2425 trans = w .T
2526 return trans
2627
28+
2729def mrs_istarlet (trans ):
2830 r = np .sum (trans , axis = 0 )
2931 return r
3032
3133
32- def mrs_uwttrans (map , nscale = None , lmax = None , opt = None , verbose = False , path = "./" , progpath = None , cxx = False ):
34+ def mrs_uwttrans (
35+ map ,
36+ nscale = None ,
37+ lmax = None ,
38+ opt = None ,
39+ verbose = False ,
40+ path = "./" ,
41+ progpath = None ,
42+ cxx = False ,
43+ ):
3344 nside = gnside (map )
3445 if nscale is None :
3546 Ns = np .log2 (nside ) - 2
@@ -58,11 +69,10 @@ def mrs_uwttrans(map, nscale=None, lmax=None, opt=None, verbose=False, path="./"
5869 )
5970 else :
6071 npix = map .shape [0 ]
61- w = wt_trans (map , lmax = lmax ,nscales = Ns - 1 )
62- p = np .zeros (Ns , npix )
72+ w = wt_trans (map , lmax = lmax , nscales = Ns - 1 )
73+ p = np .zeros (( Ns , npix ) )
6374 for j in range (Ns ):
64- print (j + 1 )
65- p [j ,:] = w [:,j ]
75+ p [j , :] = w [:, j ]
6676
6777 return p
6878
@@ -89,9 +99,9 @@ def mrs_uwtrecons(Tmap, lmax=None, opt=None, verbose=False, path="./", progpath=
8999 return p
90100
91101
92-
93102# Wavelet filtering
94103
104+
95105def spline2 (size , l , lc ):
96106 """
97107 Compute a non-negative decreasing spline, with value 1 at index 0.
@@ -111,9 +121,20 @@ def spline2(size, l, lc):
111121 (size,) float array, spline
112122 """
113123
114- res = np .arange (0 , size + 1 )
115- res = 2 * l * res / (lc * size )
116- res = (3 / 2 ) * 1 / 12 * (abs (res - 2 )** 3 - 4 * abs (res - 1 )** 3 + 6 * abs (res )** 3 - 4 * abs (res + 1 )** 3 + abs (res + 2 )** 3 )
124+ res = np .arange (0 , size + 1 )
125+ res = 2 * l * res / (lc * size )
126+ res = (
127+ (3 / 2 )
128+ * 1
129+ / 12
130+ * (
131+ abs (res - 2 ) ** 3
132+ - 4 * abs (res - 1 ) ** 3
133+ + 6 * abs (res ) ** 3
134+ - 4 * abs (res + 1 ) ** 3
135+ + abs (res + 2 ) ** 3
136+ )
137+ )
117138 return res
118139
119140
@@ -134,10 +155,10 @@ def compute_h(size, lc):
134155 (size,) float array, filter
135156 """
136157
137- tab1 = spline2 (size , 2 * lc , 1 )
158+ tab1 = spline2 (size , 2 * lc , 1 )
138159 tab2 = spline2 (size , lc , 1 )
139- h = tab1 / (tab2 + 1e-6 )
140- h [np .int64 (size / ( 2 * lc )): size ] = 0
160+ h = tab1 / (tab2 + 1e-6 )
161+ h [np .int64 (size / ( 2 * lc )) : size ] = 0
141162 return h
142163
143164
@@ -158,10 +179,10 @@ def compute_g(size, lc):
158179 (size,) float array, filter
159180 """
160181
161- tab1 = spline2 (size , 2 * lc , 1 )
182+ tab1 = spline2 (size , 2 * lc , 1 )
162183 tab2 = spline2 (size , lc , 1 )
163- g = (tab2 - tab1 )/ (tab2 + 1e-6 )
164- g [np .int64 (size / ( 2 * lc )): size ] = 1
184+ g = (tab2 - tab1 ) / (tab2 + 1e-6 )
185+ g [np .int64 (size / ( 2 * lc )) : size ] = 1
165186 return g
166187
167188
@@ -181,9 +202,11 @@ def get_wt_filters(lmax, nscales):
181202 (lmax+1,nscales+1) float array, filters
182203 """
183204
184- wt_filters = np .ones ((lmax + 1 , nscales + 1 ))
185- wt_filters [:, 1 :] = np .array ([compute_h (lmax , 2 ** scale ) for scale in range (nscales )]).T
186- wt_filters [:, :nscales ] -= wt_filters [:, 1 :(nscales + 1 )]
205+ wt_filters = np .ones ((lmax + 1 , nscales + 1 ))
206+ wt_filters [:, 1 :] = np .array (
207+ [compute_h (lmax , 2 ** scale ) for scale in range (nscales )]
208+ ).T
209+ wt_filters [:, :nscales ] -= wt_filters [:, 1 : (nscales + 1 )]
187210 return wt_filters
188211
189212
@@ -247,10 +270,10 @@ def wt_trans(inputs, nscales=3, lmax=None, alm_in=False, nside=None, alm_out=Fal
247270 l_scale = alms .copy ()
248271 if dim_inputs == 1 :
249272 npix = np .size (alms )
250- wts = np .zeros ((npix , nscales + 1 ), dtype = ' complex' )
273+ wts = np .zeros ((npix , nscales + 1 ), dtype = " complex" )
251274 else :
252275 npix = np .shape (alms )[1 ]
253- wts = np .zeros ((np .shape (maps )[0 ], npix , nscales + 1 ), dtype = ' complex' )
276+ wts = np .zeros ((np .shape (maps )[0 ], npix , nscales + 1 ), dtype = " complex" )
254277
255278 scale = 1
256279 for j in range (nscales ):
@@ -295,7 +318,8 @@ def wt_rec(wts):
295318
296319# Plots
297320
298- def mrs_tv (maps , log = False , unit = '' , title = '' , minimum = None , maximum = None , cbar = True ):
321+
322+ def mrs_tv (maps , log = False , unit = "" , title = "" , minimum = None , maximum = None , cbar = True ):
299323 """Plot one or more Healpix maps in Mollweide projection.
300324
301325 Parameters
@@ -325,19 +349,31 @@ def mrs_tv(maps, log=False, unit='', title='', minimum=None, maximum=None, cbar=
325349
326350 if minimum is None :
327351 minimum = np .min (maps )
328-
352+
329353 if maximum is None :
330354 maximum = np .max (maps )
331-
355+
332356 if not log :
333- def f (x ): return x
357+
358+ def f (x ):
359+ return x
360+
334361 else :
335- def f (x ): return np .log10 (x - minimum + 1 )
362+
363+ def f (x ):
364+ return np .log10 (x - minimum + 1 )
365+
336366 for i in range (np .shape (maps )[0 ]):
337367 if title :
338- tit = title + ": Scale " + str (i + 1 )
368+ tit = title + ": Scale " + str (i + 1 )
339369 else :
340- tit = "Scale " + str (i + 1 )
341- hp .mollview (f (maps [i , :]), fig = None , unit = unit , title = tit , min = f (minimum ), max = f (maximum ), cbar = cbar )
342-
343-
370+ tit = "Scale " + str (i + 1 )
371+ hp .mollview (
372+ f (maps [i , :]),
373+ fig = None ,
374+ unit = unit ,
375+ title = tit ,
376+ min = f (minimum ),
377+ max = f (maximum ),
378+ cbar = cbar ,
379+ )
0 commit comments