Skip to content

Commit 7c4b515

Browse files
authored
Merge pull request #142 from CosmoStat/case_study_psf_decontamination
WaveDiff: New PSF Model Features, Refactoring, Rotation Obscuration Mask, and PEP8 Compliance This commit merges the 'case_study_psf_decontamination' feature/refactor branch into `develop`, introducing significant new features, architectural changes, bug fixes, and CI/tooling improvements. Highlights: New Features --------------- - Added physical Zernike prior layer to PSF model (Closes #123) - Added obscuration mask rotation support to PSF models - Included module for simulating spatially-varying PSFs with validation tests (Closes #116, #133) - Integrated centroid estimation correction (Closes #113) - Corrected CCD z-axis misalignment handling - Added phase retrieval module with random seed support (Closes #91) - Added new data generation script to support structured model validation Bug Fixes ----------- - Fixed random seed bug in data-driven PSF reinitialization (Closes #132) Refactors and Internal Architecture ------------------------------------- - Replaced `TrainingDataHandler` and `TestDataHandler` with unified `DataHandler` - Refactored `train_utils.py` and `train.py` using the Strategy Pattern (Closes #89, #152) - Introduced helper functions and modular callbacks - Added custom loss/metric functions for masked datasets (e.g. for centroid and sample weights) - Renamed functions, classes, and variables to comply with PEP8 (Closes #139) - Removed deprecated code and unused imports Testing and Validation ------------------------- - Added `test_training/` with unit tests for `train.py` and `train_utils.py` CI / Tooling Updates ------------------------ - Replaced `black` with `ruff` for code formatting (Closes #119, #120, #145) - Unpinned `pytest` (Closes #125) - Increased TensorFlow version to >=2.11.0 (Closes #90) ⚠️ Notes -------- - This commit is large and includes many interdependent changes. Follow-up work includes: - Further refactoring of the data generation script - Cleanup of temporary or duplicate logic introduced during the transition - Finalisation of naming schemes and docstrings across modules - Preparation of a release-ready, slimmed-down develop branch Co-authored-by many developers
2 parents d18eb73 + b9f394d commit 7c4b515

120 files changed

Lines changed: 14570 additions & 3089 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/ci.yml

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,7 @@ jobs:
2828

2929
- name: Install dependencies
3030
run: python -m pip install ".[test]"
31-
31+
3232
- name: Test with pytest
3333
run: python -m pytest
3434

35-
# Add Black formatter
36-
- name: Install Black formatter
37-
run: python -m pip install black
38-
39-
- name: Check code formatting with Black
40-
run: black . --check --diff
41-
42-

config/data_config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Training and test data sets for training and/or metrics evaluation
22
data:
33
training:
4-
# Specify directory path to data; Default setting is /path/to/repo/data
4+
# Specify directory path to training dataset
55
data_dir: data/coherent_euclid_dataset/
6+
# Provide name of training dataset
67
file: train_Euclid_res_200_TrainStars_id_001.npy
78
# if training data set file does not exist, generate a new one by setting values below
89
stars: null
@@ -26,7 +27,9 @@ data:
2627
euclid_obsc: true
2728
n_stars: 200
2829
test:
30+
# Specify directory path to training dataset
2931
data_dir: data/coherent_euclid_dataset/
32+
# Provide name of test dataset
3033
file: test_Euclid_res_id_001.npy
3134
# If test data set file not provided produce a new one
3235
stars: null

config/metrics_config.yaml

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,85 @@
11
metrics:
22
# Specify the type of model weights to load by entering "psf_model" to load weights of final psf model or "checkpoint" to load weights from a checkpoint callback.
33
model_save_path: <enter psf_model or checkpoint>
4+
45
# Choose the training cycle for which to evaluate the psf_model. Can be: 1, 2, ...
56
saved_training_cycle: 2
7+
68
# Metrics-only run: Specify model_params for a pre-trained model else leave blank if running training + metrics
79
# Specify path to Parent Directory of Trained Model
810
trained_model_path: </path/to/parent/directory/of/trained/model>
11+
912
# Name of the Trained Model Config file stored in config sub-directory in the trained_model_path parent directory
1013
trained_model_config: <enter name of trained model config file>
11-
#Evaluate the monchromatic RMSE metric.
12-
eval_mono_metric_rmse: True
13-
#Evaluate the OPD RMSE metric.
14-
eval_opd_metric_rmse: True
15-
#Evaluate the super-resolution and the shape RMSE metrics for the train dataset.
16-
eval_train_shape_sr_metric_rmse: True
14+
15+
# Evaluate the monchromatic RMSE metric.
16+
eval_mono_metric: True
17+
18+
# Evaluate the OPD RMSE metric.
19+
eval_opd_metric: True
20+
21+
# Evaluate the super-resolution and the shape RMSE metrics for the train dataset.
22+
eval_train_shape_results_dict: False
23+
24+
# Evaluate the super-resolution and the shape RMSE metrics for the test dataset.
25+
eval_test_shape_results_dict: False
26+
1727
# Name of Plotting Config file - Enter name of yaml file to run plot metrics else if empty run metrics evaluation only
1828
plotting_config: <enter name of plotting_config .yaml file or leave empty>
29+
1930
ground_truth_model:
2031
model_params:
21-
#Model used as ground truth for the evaluation. Options are: 'poly' for polychromatic and 'physical' [not available].
22-
model_name: poly
32+
# PSF model used as ground truth for the evaluation. Options are: 'ground_truth_poly' for polychromatic and 'ground_truth_physical_poly' for polychromatic model with the physical layer extension PSF models.
33+
model_name: <ground_truth_poly or ground_truth_physical_poly>
2334

2435
# Evaluation parameters
25-
#Number of bins used for the ground truth model poly PSF generation
36+
# Number of bins used for the ground truth model poly PSF generation
2637
n_bins_lda: 20
2738

28-
#Downsampling rate to match the oversampled model to the specified telescope's sampling.
39+
# Downsampling rate to match the oversampled model to the specified telescope's sampling.
2940
output_Q: 3
3041

31-
#Oversampling rate used for the OPD/WFE PSF model.
42+
# Oversampling rate used for the OPD/WFE PSF model.
3243
oversampling_rate: 3
3344

34-
#Dimension of the pixel PSF postage stamp
45+
# Dimension of the pixel PSF postage stamp
3546
output_dim: 32
3647

37-
#Dimension of the OPD/Wavefront space."
48+
# Dimension of the OPD/Wavefront space."
3849
pupil_diameter: 256
3950

40-
#Boolean to define if we use sample weights based on the noise standard deviation estimation
51+
# Top-hat filter to avoid the aliasing effect in the obscuration mask
52+
LP_filter_length: 2
53+
54+
# Boolean to define if we use sample weights based on the noise standard deviation estimation
4155
use_sample_weights: True
4256

43-
#Interpolation type for the physical poly model. Options are: 'none', 'all', 'top_K', 'independent_Zk'."
57+
# Flag to use Zernike prior
58+
use_prior: False
59+
60+
# Correct centroids
61+
correct_centroids: False
62+
63+
# Sigma of the window function used to compute the centroids. Default Euclid value is 2.5
64+
sigma_centroid_window: 2.5
65+
66+
# Default reference_shifts value for observations at Euclid conditions, i.e., pixel sampling and telescope parameters.
67+
reference_shifts: [-1/3, -1/3]
68+
69+
# Rotation angle (in degrees) for the obscuration mask.
70+
# Must be a multiple of 90 (0, 90, 180, 270, etc.).
71+
# Rotation is counterclockwise.
72+
obscuration_rotation_angle: 0
73+
74+
# Consider CCD missalignments
75+
add_ccd_misalignments: False
76+
77+
# CCD missalignments input file path
78+
# This should be refactored. It might be better to directly look for the `tiles.npy` in
79+
# the `data/assets/` directory in the repository
80+
ccd_misalignments_input_path:
81+
82+
# Interpolation type for the physical poly model. Options are: 'none', 'all', 'top_K', 'independent_Zk'."
4483
interpolation_type: None
4584

4685
# SED intepolation points per bin
@@ -55,68 +94,68 @@ metrics:
5594
# Standard deviation of the multiplicative SED Gaussian noise.
5695
sed_sigma: 0
5796

58-
#Limits of the PSF field coordinates for the x axis.
97+
# Limits of the PSF field coordinates for the x axis.
5998
x_lims: [0.0, 1.0e+3]
6099

61-
#Limits of the PSF field coordinates for the y axis.
100+
# Limits of the PSF field coordinates for the y axis.
62101
y_lims: [0.0, 1.0e+3]
63102

64-
# Hyperparameters for Parametric model
103+
# Hyperparameters for the Parametric model
65104
param_hparams:
66105
# Random seed for Tensor Flow Initialization
67106
random_seed: 3877572
68107

69108
# Parameter for the l2 loss function for the Optical path differences (OPD)/WFE
70109
l2_param: 0.
71110

72-
#Zernike polynomial modes to use on the parametric part.
111+
# Zernike polynomial modes to use on the parametric part.
73112
n_zernikes: 45
74113

75-
#Max polynomial degree of the parametric part.
114+
# Max polynomial degree of the parametric part.
76115
d_max: 2
77116

78-
#Flag to save optimisation history for parametric model
117+
# Flag to save optimisation history for parametric model
79118
save_optim_history_param: true
80119

81120
# Hyperparameters for non-parametric model
82121
nonparam_hparams:
83-
#Max polynomial degree of the non-parametric part.
122+
# Max polynomial degree of the non-parametric part.
84123
d_max_nonparam: 5
85124

86125
# Number of graph features
87126
num_graph_features: 10
88127

89-
#L1 regularisation parameter for the non-parametric part."
128+
# L1 regularisation parameter for the non-parametric part."
90129
l1_rate: 1.0e-8
91130

92-
#Flag to enable Projected learning for DD_features to be used with `poly` or `semiparametric` model.
131+
# Flag to enable Projected learning for DD_features to be used with `poly` or `semiparametric` model.
93132
project_dd_features: False
94133

95-
#Flag to reset DD_features to be used with `poly` or `semiparametric` model
134+
# Flag to reset DD_features to be used with `poly` or `semiparametric` model
96135
reset_dd_features: False
97136

98-
#Flag to save optimisation history for non-parametric model
137+
# Flag to save optimisation history for non-parametric model
99138
save_optim_history_nonparam: True
100139

101140
metrics_hparams:
102141
# Batch size to use for the evaluation.
103142
batch_size: 16
104143

105-
#Save RMS error for each super resolved PSF in the test dataset in addition to the mean across the FOV."
106-
#Flag to get Super-Resolution pixel PSF RMSE for each individual test star.
107-
#If `True`, the relative pixel RMSE of each star is added to ther saving dictionary.
144+
# Save RMS error for each super resolved PSF in the test dataset in addition to the mean across the FOV."
145+
# Flag to get Super-Resolution pixel PSF RMSE for each individual test star.
146+
# If `True`, the relative pixel RMSE of each star is added to ther saving dictionary.
108147
opt_stars_rel_pix_rmse: False
109148

110149
## Specific parameters
111150
# Parameter for the l2 loss of the OPD.
112151
l2_param: 0.
113152

114153
## Define the resolution at which you'd like to measure the shape of the PSFs
115-
#Downsampling rate from the high-resolution pixel modelling space.
154+
# Downsampling rate from the high-resolution pixel modelling space.
116155
# Recommended value: 1
117156
output_Q: 1
118157

119-
#Dimension of the pixel PSF postage stamp; it should be big enough so that most of the signal is contained inside the postage stamp.
158+
# Dimension of the pixel PSF postage stamp; it should be big enough so that most of the signal is contained inside the postage stamp.
120159
# It also depends on the Q values used.
121160
# Recommended value: 64 or higher
122161
output_dim: 64

0 commit comments

Comments
 (0)