Commit 7b513f1

cherry-pick changes to reward-test-check (#166)
Picking some changes from #150 to work towards landing that ~soon.
1 parent a8d1db1 commit 7b513f1

2 files changed

Lines changed: 12 additions & 1 deletion

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Now targetting torch version 1.12, up from 1.11.
 - `OnnxExporter` accepts a `device` argument to enable tracing on other devices.
+- `FinalRewardTestCheck` can now be configured with another key and to use windowed data.
 
 ### Deprecations
 
emote/callbacks/testing.py

Lines changed: 11 additions & 1 deletion
@@ -33,13 +33,23 @@ def __init__(
         callback: LoggingMixin,
         cutoff: float,
         test_length: int,
+        key: str = "training/scaled_reward",
+        use_windowed: bool = False,
     ):
         super().__init__(cycle=test_length)
         self._cb = callback
         self._cutoff = cutoff
+        self._key = key
+        self._use_windowed = use_windowed
 
     def end_cycle(self):
-        reward = self._cb.scalar_logs["training/scaled_reward"]
+        if self._use_windowed:
+            data = self._cb.windowed_scalar[self._key]
+            reward = sum(data) / len(data)
+        else:
+            reward = self._cb.scalar_logs[self._key]
+
         if reward < self._cutoff:
             raise Exception(f"Reward too low: {reward}")
+
         raise TrainingShutdownException()
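
For reviewers who want to try the new lookup in isolation, here is a minimal standalone sketch of the reward selection that `end_cycle` performs after this change. It does not import emote: `select_reward` and `fake_cb` are hypothetical names introduced only for illustration, and the assumption that `windowed_scalar` maps each key to a sequence of recent values (a `deque` here) is inferred from the `sum(data) / len(data)` line in the diff, not confirmed against the library.

```python
from collections import deque
from types import SimpleNamespace


def select_reward(cb, key="training/scaled_reward", use_windowed=False):
    """Standalone mirror of the reward lookup in the diff (illustrative only)."""
    if use_windowed:
        data = cb.windowed_scalar[key]
        # Arithmetic mean over the current window. An empty window would
        # raise ZeroDivisionError here, exactly as the code in the diff would.
        return sum(data) / len(data)
    # Default path: the most recently logged scalar for the key.
    return cb.scalar_logs[key]


# Hypothetical stand-in for a LoggingMixin-style callback (assumed shape).
fake_cb = SimpleNamespace(
    scalar_logs={"training/scaled_reward": 0.5},
    windowed_scalar={"training/scaled_reward": deque([1.0, 2.0, 3.0], maxlen=100)},
)

latest = select_reward(fake_cb)                       # last logged value
windowed = select_reward(fake_cb, use_windowed=True)  # mean of the window
```

With these made-up values, the default path returns the last logged scalar and the windowed path returns the mean of the three windowed entries, so a noisy final step is less likely to fail the cutoff check when `use_windowed=True`. The sketch also shows that the windowed path divides by the window length, so it fails on an empty window.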
