Skip to content

Commit c86baa1

Browse files
committed
gmca test
1 parent de55355 commit c86baa1

1 file changed

Lines changed: 91 additions & 1 deletion

File tree

pycs/sparsity/sparse2d/mr_gmca.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pycs.misc.cosmostat_init import *
2020
from pycs.misc.cosmostat_init import writefits
2121

22-
2322
##
2423
# Function that calls mr_gmca to perform blind source separation on the
2524
# input data.
@@ -75,7 +74,9 @@ def mr_gmca(data, opt=None, path="./", remove_files=True, verbose=False, FileOut
7574

7675
result = readfits(file_out_source)
7776
est_mixmat = readfits(file_out_mat)
77+
est_mixmat = est_mixmat.T
7878
est_invmixmat = readfits(file_out_invmat)
79+
est_invmixmat = est_invmixmat.T
7980

8081
# Return the mr_transform results (and the output file names).
8182
if remove_files:
@@ -87,3 +88,92 @@ def mr_gmca(data, opt=None, path="./", remove_files=True, verbose=False, FileOut
8788
return result, est_mixmat, est_invmixmat
8889
else:
8990
return result, est_mixmat, est_invmixmat
91+
92+
93+
# Main
94+
if __name__ == "__main__":
95+
# to run the test, need to install scikit-image and import the two following pakages
96+
from skimage import data, color
97+
from skimage.transform import resize
98+
import matplotlib.pyplot as plt
99+
from pycs.sparsity.sparse2d.bss_eval import *
100+
101+
sources = load_source_images()
102+
# mixed_images, A = mix_sources_images(sources)
103+
# np.random.seed(0)
104+
mixed_images, A = mix_sources_images_noise(sources, noise_level=0.1)
105+
info(sources)
106+
107+
print("Mixing Matrix (A):\n", A)
108+
print("Sources shape:", sources.shape) # (2, H, W)
109+
print("Mixed images shape:", mixed_images.shape) # (3, H, W)
110+
111+
optF = '-S2 -K3 -t14 -n5' # with bi-orthogonal WT abd final denoising at 3sigma
112+
optStarlet = '-S2 -K3 -t2 -n5' # with starlet and final denoising at 3sigma
113+
optCurvelet = '-E1 -S2 -K3 -t28 -n5 ' # with curvelet abd final denoising at 3sigma
114+
115+
corrPerm=False
116+
verbose=False
117+
SRec, Emat, Eimat = mr_gmca(mixed_images, opt=optF, remove_files=False, verbose=verbose)
118+
SRec = reorder_and_fix_sign(sources, SRec)
119+
# print(" ==> Bi-Orth Wavelet Source err = ", compute_sdr(sources, SRec))
120+
# print("A_true shape:", A.shape)
121+
# print("A_est shape:", Emat.shape)
122+
# error = amari_error(A, Emat)
123+
# print(" ==> Mixing matrix err = ", amari_error(A, Emat))
124+
# print("Mixing Matrix (A):\n", Emat)
125+
CA, NMSE = evaluate(A, sources, Emat, SRec, corrPerm=corrPerm)
126+
print(' ==> Bi-Orth Wavelet Source err: CA = %.4f | NMSE= %.4f' % (CA, NMSE))
127+
128+
# Visualization
129+
fig, axs = plt.subplots(1, 5, figsize=(15, 5))
130+
axs[0].imshow(sources[0], cmap='gray')
131+
axs[0].set_title("Source Image 1")
132+
axs[1].imshow(sources[1], cmap='gray')
133+
axs[1].set_title("Source Image 2")
134+
for i in range(3):
135+
axs[i+2].imshow(mixed_images[i], cmap='gray')
136+
axs[i+2].set_title(f"Mixed Image {i+1}")
137+
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
138+
axs[0].imshow(SRec[0], cmap='gray')
139+
axs[0].set_title("7/9 WT GMCA Image 1")
140+
axs[1].imshow(SRec[1], cmap='gray')
141+
axs[1].set_title("7/9 WT GMCA Image 2")
142+
for ax in axs:
143+
ax.axis('off')
144+
plt.tight_layout()
145+
plt.show()
146+
147+
# ----- STARLET ------
148+
SRec, Emat, Eimat = mr_gmca(mixed_images, opt=optStarlet, remove_files=False, verbose=verbose)
149+
SRec = reorder_and_fix_sign(sources, SRec)
150+
CA, NMSE = evaluate(A, sources, Emat, SRec, corrPerm=corrPerm)
151+
print('==> Starlet Source err: CA = %.4f | NMSE = %.4f' % (CA, NMSE))
152+
153+
# Visualization
154+
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
155+
axs[0].imshow(SRec[0], cmap='gray')
156+
axs[0].set_title("Starlet GMCA Image 1")
157+
axs[1].imshow(SRec[1], cmap='gray')
158+
axs[1].set_title("Starlet GMCA Image 2")
159+
for ax in axs:
160+
ax.axis('off')
161+
plt.tight_layout()
162+
plt.show()
163+
164+
# ----- CURVELET ------
165+
SRec, Emat, Eimat = mr_gmca(mixed_images, opt=optCurvelet, remove_files=False, verbose=verbose)
166+
SRec = reorder_and_fix_sign(sources, SRec)
167+
CA, NMSE = evaluate(A, sources, Emat, SRec, corrPerm=corrPerm)
168+
print('==> Curvelet Source err: CA = %.4f | NMSE = %.4f' % (CA, NMSE))
169+
170+
# Visualization
171+
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
172+
axs[0].imshow(SRec[0], cmap='gray')
173+
axs[0].set_title("Curvelet GMCA Image 1")
174+
axs[1].imshow(SRec[1], cmap='gray')
175+
axs[1].set_title("Curvelet GMCA Image 2")
176+
for ax in axs:
177+
ax.axis('off')
178+
plt.tight_layout()
179+
plt.show()

0 commit comments

Comments
 (0)