File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
1212- Now targeting torch version 1.12, up from 1.11.
1313- ` OnnxExporter ` accepts a ` device ` argument to enable tracing on other devices.
14+ - ` FinalRewardTestCheck ` can now be configured with a different key and can optionally use windowed data.
1415
1516### Deprecations
1617
Original file line number Diff line number Diff line change @@ -33,13 +33,23 @@ def __init__(
3333 callback : LoggingMixin ,
3434 cutoff : float ,
3535 test_length : int ,
36+ key : str = "training/scaled_reward" ,
37+ use_windowed : bool = False ,
3638 ):
3739 super ().__init__ (cycle = test_length )
3840 self ._cb = callback
3941 self ._cutoff = cutoff
42+ self ._key = key
43+ self ._use_windowed = use_windowed
4044
4145 def end_cycle (self ):
42- reward = self ._cb .scalar_logs ["training/scaled_reward" ]
46+ if self ._use_windowed :
47+ data = self ._cb .windowed_scalar [self ._key ]
48+ reward = sum (data ) / len (data )
49+ else :
50+ reward = self ._cb .scalar_logs [self ._key ]
51+
4352 if reward < self ._cutoff :
4453 raise Exception (f"Reward too low: { reward } " )
54+
4555 raise TrainingShutdownException ()
You can’t perform that action at this time.
0 commit comments