1919from pycs .misc .cosmostat_init import *
2020from pycs .misc .cosmostat_init import writefits
2121
22-
2322##
2423# Function that calls mr_gmca to perform blind source separation on the
2524# input data.
@@ -75,7 +74,9 @@ def mr_gmca(data, opt=None, path="./", remove_files=True, verbose=False, FileOut
7574
7675 result = readfits (file_out_source )
7776 est_mixmat = readfits (file_out_mat )
77+ est_mixmat = est_mixmat .T
7878 est_invmixmat = readfits (file_out_invmat )
79+ est_invmixmat = est_invmixmat .T
7980
8081 # Return the mr_transform results (and the output file names).
8182 if remove_files :
@@ -87,3 +88,92 @@ def mr_gmca(data, opt=None, path="./", remove_files=True, verbose=False, FileOut
8788 return result , est_mixmat , est_invmixmat
8889 else :
8990 return result , est_mixmat , est_invmixmat
91+
92+
93+ # Main
94+ if __name__ == "__main__" :
95+ # to run the test, need to install scikit-image and import the two following pakages
96+ from skimage import data , color
97+ from skimage .transform import resize
98+ import matplotlib .pyplot as plt
99+ from pycs .sparsity .sparse2d .bss_eval import *
100+
101+ sources = load_source_images ()
102+ # mixed_images, A = mix_sources_images(sources)
103+ # np.random.seed(0)
104+ mixed_images , A = mix_sources_images_noise (sources , noise_level = 0.1 )
105+ info (sources )
106+
107+ print ("Mixing Matrix (A):\n " , A )
108+ print ("Sources shape:" , sources .shape ) # (2, H, W)
109+ print ("Mixed images shape:" , mixed_images .shape ) # (3, H, W)
110+
111+ optF = '-S2 -K3 -t14 -n5' # with bi-orthogonal WT abd final denoising at 3sigma
112+ optStarlet = '-S2 -K3 -t2 -n5' # with starlet and final denoising at 3sigma
113+ optCurvelet = '-E1 -S2 -K3 -t28 -n5 ' # with curvelet abd final denoising at 3sigma
114+
115+ corrPerm = False
116+ verbose = False
117+ SRec , Emat , Eimat = mr_gmca (mixed_images , opt = optF , remove_files = False , verbose = verbose )
118+ SRec = reorder_and_fix_sign (sources , SRec )
119+ # print(" ==> Bi-Orth Wavelet Source err = ", compute_sdr(sources, SRec))
120+ # print("A_true shape:", A.shape)
121+ # print("A_est shape:", Emat.shape)
122+ # error = amari_error(A, Emat)
123+ # print(" ==> Mixing matrix err = ", amari_error(A, Emat))
124+ # print("Mixing Matrix (A):\n", Emat)
125+ CA , NMSE = evaluate (A , sources , Emat , SRec , corrPerm = corrPerm )
126+ print (' ==> Bi-Orth Wavelet Source err: CA = %.4f | NMSE= %.4f' % (CA , NMSE ))
127+
128+ # Visualization
129+ fig , axs = plt .subplots (1 , 5 , figsize = (15 , 5 ))
130+ axs [0 ].imshow (sources [0 ], cmap = 'gray' )
131+ axs [0 ].set_title ("Source Image 1" )
132+ axs [1 ].imshow (sources [1 ], cmap = 'gray' )
133+ axs [1 ].set_title ("Source Image 2" )
134+ for i in range (3 ):
135+ axs [i + 2 ].imshow (mixed_images [i ], cmap = 'gray' )
136+ axs [i + 2 ].set_title (f"Mixed Image { i + 1 } " )
137+ fig , axs = plt .subplots (1 , 2 , figsize = (15 , 5 ))
138+ axs [0 ].imshow (SRec [0 ], cmap = 'gray' )
139+ axs [0 ].set_title ("7/9 WT GMCA Image 1" )
140+ axs [1 ].imshow (SRec [1 ], cmap = 'gray' )
141+ axs [1 ].set_title ("7/9 WT GMCA Image 2" )
142+ for ax in axs :
143+ ax .axis ('off' )
144+ plt .tight_layout ()
145+ plt .show ()
146+
147+ # ----- STARLET ------
148+ SRec , Emat , Eimat = mr_gmca (mixed_images , opt = optStarlet , remove_files = False , verbose = verbose )
149+ SRec = reorder_and_fix_sign (sources , SRec )
150+ CA , NMSE = evaluate (A , sources , Emat , SRec , corrPerm = corrPerm )
151+ print ('==> Starlet Source err: CA = %.4f | NMSE = %.4f' % (CA , NMSE ))
152+
153+ # Visualization
154+ fig , axs = plt .subplots (1 , 2 , figsize = (15 , 5 ))
155+ axs [0 ].imshow (SRec [0 ], cmap = 'gray' )
156+ axs [0 ].set_title ("Starlet GMCA Image 1" )
157+ axs [1 ].imshow (SRec [1 ], cmap = 'gray' )
158+ axs [1 ].set_title ("Starlet GMCA Image 2" )
159+ for ax in axs :
160+ ax .axis ('off' )
161+ plt .tight_layout ()
162+ plt .show ()
163+
164+ # ----- CURVELET ------
165+ SRec , Emat , Eimat = mr_gmca (mixed_images , opt = optCurvelet , remove_files = False , verbose = verbose )
166+ SRec = reorder_and_fix_sign (sources , SRec )
167+ CA , NMSE = evaluate (A , sources , Emat , SRec , corrPerm = corrPerm )
168+ print ('==> Curvelet Source err: CA = %.4f | NMSE = %.4f' % (CA , NMSE ))
169+
170+ # Visualization
171+ fig , axs = plt .subplots (1 , 2 , figsize = (15 , 5 ))
172+ axs [0 ].imshow (SRec [0 ], cmap = 'gray' )
173+ axs [0 ].set_title ("Curvelet GMCA Image 1" )
174+ axs [1 ].imshow (SRec [1 ], cmap = 'gray' )
175+ axs [1 ].set_title ("Curvelet GMCA Image 2" )
176+ for ax in axs :
177+ ax .axis ('off' )
178+ plt .tight_layout ()
179+ plt .show ()
0 commit comments