Skip to content

Commit 24c420c

Browse files
author
Jennifer Pollack
committed
Update fixtures and unit tests for training_preprocessing bug dataset_type fix
1 parent 44fb78d commit 24c420c

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

src/wf_psf/tests/test_data/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
)
8383

8484
data = RecursiveNamespace(
85-
train=RecursiveNamespace(
85+
training=RecursiveNamespace(
8686
data_dir="data",
8787
file="coherent_euclid_dataset/train_Euclid_res_200_TrainStars_id_001.npy",
8888
),

src/wf_psf/tests/test_data/training_preprocessing_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ def mock_data():
6565
def test_process_sed_data(data_params, simPSF):
6666
# Test processing SED data without initialization
6767
data_handler = DataHandler(
68-
"train", data_params, simPSF, n_bins_lambda=10, load_data=False
68+
"training", data_params, simPSF, n_bins_lambda=10, load_data=False
6969
)
7070
assert data_handler.sed_data is None # SED data should not be processed
7171

7272
# Test processing SED data with initialization
7373
data_handler = DataHandler(
74-
"train", data_params, simPSF, n_bins_lambda=10, load_data=True
74+
"training", data_params, simPSF, n_bins_lambda=10, load_data=True
7575
)
7676
assert data_handler.sed_data is not None # SED data should be processed
7777

@@ -94,11 +94,11 @@ def test_load_train_dataset(tmp_path, data_params, simPSF):
9494

9595
# Initialize DataHandler instance
9696
data_params = RecursiveNamespace(
97-
train=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy")
97+
training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy")
9898
)
9999

100100
n_bins_lambda = 10
101-
data_handler = DataHandler("train", data_params, simPSF, n_bins_lambda, load_data=False)
101+
data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False)
102102

103103
# Call the load_dataset method
104104
data_handler.load_dataset()
@@ -158,15 +158,15 @@ def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF):
158158
np.save(temp_data_file, mock_dataset)
159159

160160
data_params = RecursiveNamespace(
161-
train=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy")
161+
training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy")
162162
)
163163

164164
n_bins_lambda = 10
165-
data_handler = DataHandler("train", data_params, simPSF, n_bins_lambda, load_data=False)
165+
data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, load_data=False)
166166

167167
with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning:
168168
data_handler.load_dataset()
169-
mock_warning.assert_called_with("Missing 'noisy_stars' in train dataset.")
169+
mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.")
170170

171171
def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF):
172172
"""Test that a warning is raised if 'stars' is missing in test data."""
@@ -201,7 +201,7 @@ def test_process_sed_data(data_params, simPSF):
201201
}
202202
# Initialize DataHandler instance
203203
n_bins_lambda = 4
204-
data_handler = DataHandler("train", data_params, simPSF, n_bins_lambda, False)
204+
data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False)
205205

206206
data_handler.dataset = mock_dataset
207207
data_handler.process_sed_data()

0 commit comments

Comments
 (0)