Skip to content

Commit 48824bc

Browse files
modified the l1-norm on the sphere function to use the new mrs_uwttrans code + fixed bug in mre_uwttrans
1 parent 4395587 commit 48824bc

2 files changed

Lines changed: 70 additions & 33 deletions

File tree

pycs/astro/wl/hos_peaks_l1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pycs.sparsity.sparse2d.dct_inpainting import dct_inpainting
2121
from pycs.misc.im_isospec import *
2222
from pycs.astro.wl.mass_mapping import *
23-
from pycs.sparsity.mrs.mrs_tools import mrs_uwttrans
23+
from pycs.sparsity.mrs.mrs_starlet import mrs_uwttrans
2424

2525

2626
def get_wt_noiselevel(W, NoiseSigmaMap, Mask=None):
@@ -617,7 +617,8 @@ def get_norm_wtl1_sphere(
617617
nbins = 40
618618

619619
# Perform undecimated wavelet transform on the spherical map
620-
WT = mrs_uwttrans(Map, verbose=False, path=path)
620+
# WT = mrs_uwttrans(Map, verbose=False, path=path, cxx=True)
621+
WT = mrs_uwttrans(Map, nscale=nscales, verbose=False, path=path, cxx=False)
621622

622623
l1norm_coll = []
623624
bins_coll = []

pycs/sparsity/mrs/mrs_starlet.py

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pycs.misc.mr_prog import *
1313
from pycs.sparsity.mrs.mrs_tools import *
1414

15+
1516
def mrs_starlet(map, nscale=None, lmax=None):
1617
nside = gnside(map)
1718
if nscale is None:
@@ -20,16 +21,26 @@ def mrs_starlet(map, nscale=None, lmax=None):
2021
Ns = nscale
2122

2223
npix = map.shape[0]
23-
w = wt_trans(map, lmax=lmax,nscales=Ns-1)
24+
w = wt_trans(map, lmax=lmax, nscales=Ns - 1)
2425
trans = w.T
2526
return trans
2627

28+
2729
def mrs_istarlet(trans):
2830
r = np.sum(trans, axis=0)
2931
return r
3032

3133

32-
def mrs_uwttrans(map, nscale=None, lmax=None, opt=None, verbose=False, path="./", progpath=None, cxx=False):
34+
def mrs_uwttrans(
35+
map,
36+
nscale=None,
37+
lmax=None,
38+
opt=None,
39+
verbose=False,
40+
path="./",
41+
progpath=None,
42+
cxx=False,
43+
):
3344
nside = gnside(map)
3445
if nscale is None:
3546
Ns = np.log2(nside) - 2
@@ -58,11 +69,10 @@ def mrs_uwttrans(map, nscale=None, lmax=None, opt=None, verbose=False, path="./"
5869
)
5970
else:
6071
npix = map.shape[0]
61-
w = wt_trans(map, lmax=lmax,nscales=Ns-1)
62-
p = np.zeros(Ns, npix)
72+
w = wt_trans(map, lmax=lmax, nscales=Ns - 1)
73+
p = np.zeros((Ns, npix))
6374
for j in range(Ns):
64-
print(j+1)
65-
p[j,:] = w[:,j]
75+
p[j, :] = w[:, j]
6676

6777
return p
6878

@@ -89,9 +99,9 @@ def mrs_uwtrecons(Tmap, lmax=None, opt=None, verbose=False, path="./", progpath=
8999
return p
90100

91101

92-
93102
# Wavelet filtering
94103

104+
95105
def spline2(size, l, lc):
96106
"""
97107
Compute a non-negative decreasing spline, with value 1 at index 0.
@@ -111,9 +121,20 @@ def spline2(size, l, lc):
111121
(size,) float array, spline
112122
"""
113123

114-
res = np.arange(0, size+1)
115-
res = 2*l*res/(lc*size)
116-
res = (3/2) * 1/12 * (abs(res-2)**3 - 4*abs(res-1)**3 + 6*abs(res)**3 - 4*abs(res+1)**3 + abs(res+2)**3)
124+
res = np.arange(0, size + 1)
125+
res = 2 * l * res / (lc * size)
126+
res = (
127+
(3 / 2)
128+
* 1
129+
/ 12
130+
* (
131+
abs(res - 2) ** 3
132+
- 4 * abs(res - 1) ** 3
133+
+ 6 * abs(res) ** 3
134+
- 4 * abs(res + 1) ** 3
135+
+ abs(res + 2) ** 3
136+
)
137+
)
117138
return res
118139

119140

@@ -134,10 +155,10 @@ def compute_h(size, lc):
134155
(size,) float array, filter
135156
"""
136157

137-
tab1 = spline2(size, 2*lc, 1)
158+
tab1 = spline2(size, 2 * lc, 1)
138159
tab2 = spline2(size, lc, 1)
139-
h = tab1/(tab2+1e-6)
140-
h[np.int64(size/(2*lc)):size] = 0
160+
h = tab1 / (tab2 + 1e-6)
161+
h[np.int64(size / (2 * lc)) : size] = 0
141162
return h
142163

143164

@@ -158,10 +179,10 @@ def compute_g(size, lc):
158179
(size,) float array, filter
159180
"""
160181

161-
tab1 = spline2(size, 2*lc, 1)
182+
tab1 = spline2(size, 2 * lc, 1)
162183
tab2 = spline2(size, lc, 1)
163-
g = (tab2-tab1)/(tab2+1e-6)
164-
g[np.int64(size/(2*lc)):size] = 1
184+
g = (tab2 - tab1) / (tab2 + 1e-6)
185+
g[np.int64(size / (2 * lc)) : size] = 1
165186
return g
166187

167188

@@ -181,9 +202,11 @@ def get_wt_filters(lmax, nscales):
181202
(lmax+1,nscales+1) float array, filters
182203
"""
183204

184-
wt_filters = np.ones((lmax+1, nscales+1))
185-
wt_filters[:, 1:] = np.array([compute_h(lmax, 2**scale) for scale in range(nscales)]).T
186-
wt_filters[:, :nscales] -= wt_filters[:, 1:(nscales+1)]
205+
wt_filters = np.ones((lmax + 1, nscales + 1))
206+
wt_filters[:, 1:] = np.array(
207+
[compute_h(lmax, 2**scale) for scale in range(nscales)]
208+
).T
209+
wt_filters[:, :nscales] -= wt_filters[:, 1 : (nscales + 1)]
187210
return wt_filters
188211

189212

@@ -247,10 +270,10 @@ def wt_trans(inputs, nscales=3, lmax=None, alm_in=False, nside=None, alm_out=Fal
247270
l_scale = alms.copy()
248271
if dim_inputs == 1:
249272
npix = np.size(alms)
250-
wts = np.zeros((npix, nscales + 1), dtype='complex')
273+
wts = np.zeros((npix, nscales + 1), dtype="complex")
251274
else:
252275
npix = np.shape(alms)[1]
253-
wts = np.zeros((np.shape(maps)[0], npix, nscales + 1), dtype='complex')
276+
wts = np.zeros((np.shape(maps)[0], npix, nscales + 1), dtype="complex")
254277

255278
scale = 1
256279
for j in range(nscales):
@@ -295,7 +318,8 @@ def wt_rec(wts):
295318

296319
# Plots
297320

298-
def mrs_tv(maps, log=False, unit='', title='', minimum=None, maximum=None, cbar=True):
321+
322+
def mrs_tv(maps, log=False, unit="", title="", minimum=None, maximum=None, cbar=True):
299323
"""Plot one or more Healpix maps in Mollweide projection.
300324
301325
Parameters
@@ -325,19 +349,31 @@ def mrs_tv(maps, log=False, unit='', title='', minimum=None, maximum=None, cbar=
325349

326350
if minimum is None:
327351
minimum = np.min(maps)
328-
352+
329353
if maximum is None:
330354
maximum = np.max(maps)
331-
355+
332356
if not log:
333-
def f(x): return x
357+
358+
def f(x):
359+
return x
360+
334361
else:
335-
def f(x): return np.log10(x - minimum + 1)
362+
363+
def f(x):
364+
return np.log10(x - minimum + 1)
365+
336366
for i in range(np.shape(maps)[0]):
337367
if title:
338-
tit = title + ": Scale " + str(i+1)
368+
tit = title + ": Scale " + str(i + 1)
339369
else:
340-
tit = "Scale " + str(i+1)
341-
hp.mollview(f(maps[i, :]), fig=None, unit=unit, title=tit, min=f(minimum), max=f(maximum), cbar=cbar)
342-
343-
370+
tit = "Scale " + str(i + 1)
371+
hp.mollview(
372+
f(maps[i, :]),
373+
fig=None,
374+
unit=unit,
375+
title=tit,
376+
min=f(minimum),
377+
max=f(maximum),
378+
cbar=cbar,
379+
)

0 commit comments

Comments
 (0)