Skip to content

Commit afa4ad3

Browse files
test/minibatch: use Float64 for time, state, and parameters
The minibatch test was using `Float32` for `u0`, `tspan`, and the network parameters (via the default `Lux.setup` eltype). It is a test of the minibatching loop, not of Float32 throughput, and the Float32 choice exposes a fragile codepath in `SciMLSensitivity`'s `InterpolatingAdjoint` checkpoint re-solve: 1. The sub-checkpoint forward re-solve seeds its initial `dt` with `abs(cpsol_t[end] - cpsol_t[end - 1])` from the previous re-solve. 2. With `Float32` time, an adaptive Tsit5 step that lands one ulp past `interval[2]` is followed by a tiny "correction" step back to `interval[2]`, leaving the last two `cpsol_t` entries 1 ulp apart (~1.2e-7). 3. That ulp value becomes the `dt` for the next re-solve over an interval ~0.052 wide. The Tsit5 controller, given that microscopic seed, ends up emitting a `cpsol` whose time vector has only the start point. 4. On the iteration after that, `cpsol_t[end - 1]` becomes `cpsol_t[0]` and the test errors out with `BoundsError: attempt to access 1-element Vector at index [0]`. `Float64` has ~1e-16 epsilon, so the same ulp-collision pattern produces a `dt` far below any abstol/reltol the controller cares about and the feedback loop never amplifies into a 1-point cpsol. The test passes against the unmodified, registered SciMLSensitivity v7.103.0 with the Float64 change alone, no upstream patch needed. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent c5a4c13 commit afa4ad3

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

test/minibatch.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,17 @@ function callback(state, l) #callback function to observe training
2424
return l < 1.0e-2
2525
end
2626

27-
u0 = Float32[200.0]
27+
u0 = [200.0]
2828
datasize = 30
29-
tspan = (0.0f0, 1.5f0)
29+
tspan = (0.0, 1.5)
3030

3131
t = range(tspan[1], tspan[2], length = datasize)
3232
true_prob = ODEProblem(true_sol, u0, tspan)
3333
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
3434

3535
ann = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1, tanh))
3636
pp, st = Lux.setup(rng, ann)
37-
pp = ComponentArray(pp)
37+
pp = ComponentArray{Float64}(pp)
3838

3939
prob = ODEProblem{false}(dudt_, u0, tspan, pp)
4040

0 commit comments

Comments
 (0)