Skip to content

Commit 808df92

Browse files
authored
Merge pull request #33 from CosmoStat/spericalwavelet
Spericalwavelet
2 parents 9fc6ad0 + 60cf2b2 commit 808df92

3 files changed

Lines changed: 708 additions & 39 deletions

File tree

pycs/sparsity/mrs/mrs_starlet.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
import numpy as np
2+
import random
3+
4+
import os, sys
5+
from scipy import ndimage
6+
import healpy as hp
7+
from astropy.io import fits
8+
import matplotlib.pyplot as plt
9+
from astropy.io import fits
10+
from importlib import reload
11+
from pycs.misc.cosmostat_init import *
12+
from pycs.misc.mr_prog import *
13+
from pycs.sparsity.mrs.mrs_tools import *
14+
15+
def mrs_starlet(map, nscale=None, lmax=None):
16+
nside = gnside(map)
17+
if nscale is None:
18+
Ns = np.int64(np.log2(nside) - 2)
19+
else:
20+
Ns = nscale
21+
22+
npix = map.shape[0]
23+
w = wt_trans(map, lmax=lmax,nscales=Ns-1)
24+
trans = w.T
25+
return trans
26+
27+
def mrs_istarlet(trans):
28+
r = np.sum(trans, axis=0)
29+
return r
30+
31+
32+
def mrs_uwttrans(map, nscale=None, lmax=None, opt=None, verbose=False, path="./", progpath=None, cxx=False):
33+
nside = gnside(map)
34+
if nscale is None:
35+
Ns = np.log2(nside) - 2
36+
else:
37+
Ns = nscale
38+
39+
if cxx:
40+
optParam = " "
41+
if opt is not None:
42+
optParam = " " + opt
43+
if lmax is not None:
44+
optParam = " -l " + str(lmax) + optParam
45+
if nscale is not None:
46+
optParam = " -n " + str(nscale) + optParam
47+
if progpath is None:
48+
prog = "mrs_uwttrans"
49+
else:
50+
prog = progpath + "mrs_uwttrans"
51+
p = mrs_prog(
52+
map,
53+
prog=prog,
54+
verbose=verbose,
55+
opt=optParam,
56+
OutputFormatisHealpix=False,
57+
path=path,
58+
)
59+
else:
60+
npix = map.shape[0]
61+
w = wt_trans(map, lmax=lmax,nscales=Ns-1)
62+
p = np.zeros(Ns, npix)
63+
for j in range(Ns):
64+
print(j+1)
65+
p[j,:] = w[:,j]
66+
67+
return p
68+
69+
70+
def mrs_uwtrecons(Tmap, lmax=None, opt=None, verbose=False, path="./", progpath=None):
71+
optParam = " "
72+
if opt is not None:
73+
optParam = " " + opt
74+
if lmax is not None:
75+
optParam = " -l " + str(lmax) + optParam
76+
if progpath is None:
77+
prog = "mrs_uwttrans"
78+
else:
79+
prog = progpath + "mrs_uwttrans -r "
80+
p = mrs_prog(
81+
Tmap,
82+
prog=prog,
83+
verbose=verbose,
84+
opt=optParam,
85+
InputFormatisHealpix=False,
86+
OutputFormatisHealpix=True,
87+
path=path,
88+
)
89+
return p
90+
91+
92+
93+
# Wavelet filtering
94+
95+
def spline2(size, l, lc):
96+
"""
97+
Compute a non-negative decreasing spline, with value 1 at index 0.
98+
99+
Parameters
100+
----------
101+
size: int
102+
size of the spline
103+
l: float
104+
spline parameter
105+
lc: float
106+
spline parameter
107+
108+
Returns
109+
-------
110+
np.ndarray
111+
(size,) float array, spline
112+
"""
113+
114+
res = np.arange(0, size+1)
115+
res = 2*l*res/(lc*size)
116+
res = (3/2) * 1/12 * (abs(res-2)**3 - 4*abs(res-1)**3 + 6*abs(res)**3 - 4*abs(res+1)**3 + abs(res+2)**3)
117+
return res
118+
119+
120+
def compute_h(size, lc):
121+
"""
122+
Compute a low-pass filter.
123+
124+
Parameters
125+
----------
126+
size: int
127+
size of the filter
128+
lc: float
129+
cutoff parameter
130+
131+
Returns
132+
-------
133+
np.ndarray
134+
(size,) float array, filter
135+
"""
136+
137+
tab1 = spline2(size, 2*lc, 1)
138+
tab2 = spline2(size, lc, 1)
139+
h = tab1/(tab2+1e-6)
140+
h[np.int64(size/(2*lc)):size] = 0
141+
return h
142+
143+
144+
def compute_g(size, lc):
145+
"""
146+
Compute a high-pass filter.
147+
148+
Parameters
149+
----------
150+
size: int
151+
size of the filter
152+
lc: float
153+
cutoff parameter
154+
155+
Returns
156+
-------
157+
np.ndarray
158+
(size,) float array, filter
159+
"""
160+
161+
tab1 = spline2(size, 2*lc, 1)
162+
tab2 = spline2(size, lc, 1)
163+
g = (tab2-tab1)/(tab2+1e-6)
164+
g[np.int64(size/(2*lc)):size] = 1
165+
return g
166+
167+
168+
def get_wt_filters(lmax, nscales):
169+
"""Compute wavelet filters.
170+
171+
Parameters
172+
----------
173+
lmax: int
174+
maximum l
175+
nscales: int
176+
number of wavelet detail scales
177+
178+
Returns
179+
-------
180+
np.ndarray
181+
(lmax+1,nscales+1) float array, filters
182+
"""
183+
184+
wt_filters = np.ones((lmax+1, nscales+1))
185+
wt_filters[:, 1:] = np.array([compute_h(lmax, 2**scale) for scale in range(nscales)]).T
186+
wt_filters[:, :nscales] -= wt_filters[:, 1:(nscales+1)]
187+
return wt_filters
188+
189+
190+
def wt_trans(inputs, nscales=3, lmax=None, alm_in=False, nside=None, alm_out=False):
191+
"""Wavelet transform an array.
192+
193+
Parameters
194+
----------
195+
inputs: np.ndarray
196+
(p,) or (n,p) float array, map or stack of n maps / if alm_in, (t,) or (n,t) complex array, alm or stack
197+
of n alms
198+
nscales: int
199+
number of wavelet detail scales
200+
lmax: int
201+
maximum l (default: 3*nside / if alm_in, deduced from inputs)
202+
alm_in: bool
203+
inputs is alm
204+
nside: int
205+
nside of the output Healpix maps (default: deduced from maps)
206+
alm_out: bool
207+
output is alm
208+
209+
Returns
210+
-------
211+
np.ndarray
212+
(p,nscales+1) or (n,p,scales+1) float array, wavelet transform of the input array or stack of the wavelet
213+
transforms of the n input arrays / if alm_out, (t,nscales+1) or (n,t,scales+1) complex array, alm of the
214+
wavelet transform of the input array or stack of the alms of the wavelet transforms of the n input arrays
215+
"""
216+
dim_inputs = len(np.shape(inputs))
217+
maps = None # to remove warnings
218+
219+
if alm_in:
220+
alms = inputs
221+
if nside is None and not alm_out:
222+
raise ValueError("nside is missing")
223+
if not alm_out:
224+
maps = alm2map(alms, nside)
225+
if lmax is None:
226+
lmax = hp.Alm.getlmax(np.shape(alms)[-1])
227+
228+
else:
229+
maps = inputs
230+
if dim_inputs == 1:
231+
nside = hp.get_nside(maps)
232+
else:
233+
nside = hp.get_nside(maps[0, :])
234+
if lmax is None:
235+
lmax = 3 * nside
236+
alms = map2alm(maps, lmax=lmax)
237+
238+
if not alm_out:
239+
l_scale = maps.copy()
240+
if dim_inputs == 1:
241+
npix = len(maps)
242+
wts = np.zeros((npix, nscales + 1))
243+
else:
244+
npix = np.shape(maps)[1]
245+
wts = np.zeros((np.shape(maps)[0], npix, nscales + 1))
246+
else:
247+
l_scale = alms.copy()
248+
if dim_inputs == 1:
249+
npix = np.size(alms)
250+
wts = np.zeros((npix, nscales + 1), dtype='complex')
251+
else:
252+
npix = np.shape(alms)[1]
253+
wts = np.zeros((np.shape(maps)[0], npix, nscales + 1), dtype='complex')
254+
255+
scale = 1
256+
for j in range(nscales):
257+
h = compute_h(lmax, scale)
258+
if not alm_out:
259+
m = alm2map(alm_product(alms, h), nside)
260+
else:
261+
m = alm_product(alms, h)
262+
h_scale = l_scale - m
263+
l_scale = m
264+
if dim_inputs == 1:
265+
wts[:, j] = h_scale
266+
else:
267+
wts[:, :, j] = h_scale
268+
scale *= 2
269+
270+
if dim_inputs == 1:
271+
wts[:, nscales] = l_scale
272+
else:
273+
wts[:, :, nscales] = l_scale
274+
275+
return wts
276+
277+
278+
def wt_rec(wts):
279+
"""Reconstruct a wavelet decomposition.
280+
281+
Parameters
282+
----------
283+
wts: np.ndarray
284+
(p,nscales+1) or (n,p,scales+1) float array, wavelet transform of a map or stack of the wavelet transforms of n
285+
maps
286+
287+
Returns
288+
-------
289+
np.ndarray
290+
(p,) or (n,p,) float array, reconstructed map or stack of n reconstructed maps
291+
"""
292+
293+
return np.sum(wts, axis=-1)
294+
295+
296+
# Plots
297+
298+
def mrs_tv(maps, log=False, unit='', title='', minimum=None, maximum=None, cbar=True):
299+
"""Plot one or more Healpix maps in Mollweide projection.
300+
301+
Parameters
302+
----------
303+
maps: np.ndarray
304+
(p,) or (n,p) float array, map or stack of n maps
305+
log: bool
306+
logarithmic scale
307+
unit: str
308+
unit of the data
309+
title: str
310+
title of the plots
311+
minimum: float
312+
minimum range value (default: min(maps, maps2))
313+
maximum: float
314+
maximum range value (default: max(maps, maps2))
315+
cbar: bool
316+
show color bar
317+
318+
Returns
319+
-------
320+
None
321+
"""
322+
323+
if len(np.shape(maps)) == 1:
324+
maps = np.expand_dims(maps, axis=0)
325+
326+
if minimum is None:
327+
minimum = np.min(maps)
328+
329+
if maximum is None:
330+
maximum = np.max(maps)
331+
332+
if not log:
333+
def f(x): return x
334+
else:
335+
def f(x): return np.log10(x - minimum + 1)
336+
for i in range(np.shape(maps)[0]):
337+
if title:
338+
tit = title + ": Scale " + str(i+1)
339+
else:
340+
tit = "Scale " + str(i+1)
341+
hp.mollview(f(maps[i, :]), fig=None, unit=unit, title=tit, min=f(minimum), max=f(maximum), cbar=cbar)
342+
343+

0 commit comments

Comments
 (0)