Skip to content

Commit 24ffc08

Browse files
author
Jennifer Pollack
committed
Update authors lists, doc strings, and code organisation
1 parent 24c420c commit 24ffc08

2 files changed

Lines changed: 75 additions & 34 deletions

File tree

src/wf_psf/training/train.py

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
A module which defines the classes and methods
44
to manage training of the psf model.
55
6-
:Author: Jennifer Pollack <jennifer.pollack@cea.fr>
6+
:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>
77
88
"""
99

@@ -18,12 +18,26 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21+
def get_gpu_info():
22+
"""Get GPU Information.
23+
24+
A function to return GPU
25+
device name.
26+
27+
Returns
28+
-------
29+
device_name: str
30+
Name of GPU device
31+
32+
"""
33+
device_name = tf.test.gpu_device_name()
34+
return device_name
35+
2136
def setup_training():
2237
"""Set up Training.
2338
2439
A function to setup training.
2540
26-
2741
"""
2842
device_name = get_gpu_info()
2943
logger.info(f"Found GPU at: {device_name}")
@@ -257,23 +271,32 @@ def _prepare_callbacks(
257271
)
258272

259273

260-
def get_gpu_info():
261-
"""Get GPU Information.
262-
263-
A function to return GPU
264-
device name.
274+
def get_loss_metrics_monitor_and_outputs(training_handler, data_conf):
275+
"""Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle.
276+
277+
Parameters
278+
----------
279+
training_handler: TrainingParamsHandler
280+
TrainingParamsHandler object containing training parameters
281+
data_conf: object
282+
Data configuration object containing training and test data
265283
266284
Returns
267285
-------
268-
device_name: str
269-
Name of GPU device
270-
286+
loss: tf.keras.losses.Loss
287+
Loss function to be used for training
288+
param_metrics: list
289+
List of metrics for the parametric model
290+
non_param_metrics: list
291+
List of metrics for the non-parametric model
292+
monitor: str
293+
Metric to monitor for saving the model
294+
outputs: tf.Tensor
295+
Tensor containing the outputs for training
296+
output_val: tf.Tensor
297+
Tensor containing the outputs for validation
298+
271299
"""
272-
device_name = tf.test.gpu_device_name()
273-
return device_name
274-
275-
def get_loss_metrics_monitor_and_outputs(training_handler, data_conf):
276-
"""Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle."""
277300

278301
if training_handler.training_hparams.loss == "mask_mse":
279302
loss = train_utils.MaskedMeanSquaredError()
@@ -304,25 +327,46 @@ def train(
304327
optimizer_dir,
305328
psf_model_dir,
306329
):
307-
"""Train.
330+
"""
331+
Train the PSF model over one or more parametric and non-parametric training cycles.
308332
309-
A function to train the psf model.
333+
This function manages multi-cycle training of a parametric + non-parametric PSF model,
334+
including initialization, loss/metric configuration, optimizer setup, model checkpointing,
335+
and optional projection or resetting of non-parametric features. Each cycle can include
336+
both parametric and non-parametric training stages, and training history is saved for each.
310337
311338
Parameters
312339
----------
313-
training_params: Recursive Namespace object
314-
Recursive Namespace object containing the training parameters
315-
training_data: obj
316-
TrainingDataHandler object containing the training data parameters
317-
test_data: object
318-
TestDataHandler object containing the test data parameters
319-
checkpoint_dir: str
320-
Absolute path to checkpoint directory
321-
optimizer_dir: str
322-
Absolute path to optimizer history directory
323-
psf_model_dir: str
324-
Absolute path to psf model directory
340+
training_params : RecursiveNamespace
341+
Contains all training configuration parameters, including:
342+
- learning rates per cycle
343+
- number of epochs per component per cycle
344+
- model type and training behavior flags
345+
- multi-cycle definitions and callbacks
346+
347+
data_conf : object
348+
Contains training and validation datasets via attributes:
349+
- data_conf.training_data: TrainingDataHandler instance with SEDs and positions
350+
- data_conf.test_data: TestDataHandler instance with validation SEDs and positions
351+
352+
checkpoint_dir : str
353+
Directory where model checkpoints will be saved during training.
354+
355+
optimizer_dir : str
356+
Directory where the optimizer history (as a NumPy .npy file) will be stored.
357+
358+
psf_model_dir : str
359+
Directory where the final trained PSF model weights will be saved per cycle.
360+
361+
Returns
362+
-------
363+
None
325364
365+
Side Effects
366+
------------
367+
- Saves model weights to `psf_model_dir` per training cycle (or final one if not all saved)
368+
- Saves optimizer histories to `optimizer_dir`
369+
- Logs cycle information and time durations
326370
"""
327371
# Start measuring elapsed time
328372
starting_time = time.time()

src/wf_psf/training/train_utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
process for the PSF model. These functions help with managing training cycles,
66
callbacks, and related operations.
77
8-
Author: Tobias Liaudat <tobias.liaudat@cea.fr>
8+
Authors: Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>,
9+
Jennifer Pollack <jennifer.pollack@cea.fr>
910
"""
1011

1112
import numpy as np
@@ -470,7 +471,6 @@ def train_cycle_part(
470471
callbacks: Optional[list[Callable]] = None,
471472
sample_weight: Optional[tf.Tensor] = None,
472473
verbose: int = 1,
473-
first_run: bool = False,
474474
cycle_part: str = "parametric",
475475
) -> tf.keras.Model:
476476
"""
@@ -514,9 +514,6 @@ def train_cycle_part(
514514
verbose: int, optional
515515
Verbosity mode (0, 1, or 2). Default is 1.
516516
517-
first_run: bool, optional
518-
Flag indicating if this is the first run (affects how the model is built). Default is False.
519-
520517
cycle_part: str, optional
521518
Specifies which part of the model to train ("parametric" or "non-parametric"). Default is "parametric".
522519

0 commit comments

Comments
 (0)