-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathforecast.py
More file actions
209 lines (166 loc) · 6.41 KB
/
forecast.py
File metadata and controls
209 lines (166 loc) · 6.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
from typing import Tuple, List
import torch
import numpy as np
from integrators import (
euler_integrate_batched,
rk4_integrate_batched,
)
def forecast_distribution_batched(
    model,
    x_tau: torch.Tensor,
    S: int = 50,
    steps: int = 12,
    integrator: str = "euler",
    device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    One-step predictive distribution via the CFM memory ODE.

    Implements:
        Z_0 ~ N(x_tau, sigma_min^2 I_d),
        dZ_t/dt = v(t, Z_t),
        x_{tau+1} ≈ Z_1.

    Args:
        model: CFM forecaster (DenseCFMForecaster or TopKCFMForecaster),
            exposing model.hp.sigma_min and model.drift(z, t).
        x_tau: (d,) current state.
        S: number of Monte Carlo samples.
        steps: ODE integration steps over t in [0, 1].
        integrator: "euler" or "rk4".
        device: computation device.

    Returns:
        mu: (d,) empirical mean of the Z_1 samples.
        cov_est: (d, d) empirical covariance of the Z_1 samples
            (zero matrix when S == 1, where covariance is undefined).
        samples: (S, d) all Z_1 samples.
    """
    state = x_tau.to(device=device, dtype=torch.float32)
    dim = state.shape[0]
    sigma_min = float(model.hp.sigma_min)

    # Sample the initial Gaussian cloud Z_0 ~ N(x_tau, sigma_min^2 I).
    init_cov = (sigma_min ** 2) * torch.eye(dim, device=device)
    init_dist = torch.distributions.MultivariateNormal(
        state, covariance_matrix=init_cov
    )
    z0 = init_dist.sample((S,))  # (S, d)

    def vector_field(z, t):
        return model.drift(z, t)

    # Push the whole batch of particles through the flow t: 0 -> 1.
    if integrator == "euler":
        z1 = euler_integrate_batched(vector_field, z0, steps=steps)
    elif integrator == "rk4":
        z1 = rk4_integrate_batched(vector_field, z0, steps=steps)
    else:
        raise ValueError(f"Unknown integrator: {integrator}")

    # Empirical moments of the pushed-forward samples.
    mu = z1.mean(dim=0)  # (d,)
    if S > 1:
        cov_est = torch.cov(z1.T)
    else:
        # A single sample has no covariance; report zeros.
        cov_est = torch.zeros(dim, dim, device=device)
    return mu, cov_est, z1
def multi_step_forecast_torch(
    model,
    seq_np: np.ndarray,
    horizon: int,
    steps: int = 12,
    S: int = 50,
    integrator: str = "euler",
    device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """
    Autoregressive multi-step predictor with particle propagation.

    Instead of collapsing to the mean at each step, the entire ensemble
    of S particles is propagated through the flow, step after step.

    Args:
        model: CFM forecaster with model.hp.sigma_min and model.drift(z, t).
        seq_np: observed prefix of one trajectory, shape (T_obs, d).
        horizon: number of steps to forecast (may be 0).
        steps: ODE integration steps per forecast step.
        S: number of MC samples.
        integrator: "euler" or "rk4".
        device: torch device.

    Returns:
        pred_mu: (horizon, d) predicted mean at each future step (for metrics);
            an empty (0, d) tensor when horizon == 0.
        pred_samples: list of length horizon, each element a (S, d) tensor
            of samples for that step.

    Raises:
        ValueError: if `integrator` is not "euler" or "rk4" (checked up
            front, before any sampling work is done).
    """
    # Resolve the stepper once: fail fast on a bad name and avoid
    # re-evaluating a loop-invariant string comparison every iteration.
    if integrator == "euler":
        integrate = euler_integrate_batched
    elif integrator == "rk4":
        integrate = rk4_integrate_batched
    else:
        raise ValueError(f"Unknown integrator: {integrator}")

    seq = torch.tensor(seq_np, dtype=torch.float32, device=device)
    x_last = seq[-1]  # last observed state x_tau
    d = x_last.shape[0]
    sigma_min = float(model.hp.sigma_min)

    # Initial ensemble Z: (S, d), sampled around the last observation.
    cov0 = (sigma_min ** 2) * torch.eye(d, device=device)
    dist0 = torch.distributions.MultivariateNormal(x_last, covariance_matrix=cov0)
    z_particles = dist0.sample((S,))

    def vector_field(z, t):
        return model.drift(z, t)

    mu_rows: List[torch.Tensor] = []
    pred_samples: List[torch.Tensor] = []
    for _ in range(horizon):
        z_next = integrate(vector_field, z_particles, steps=steps)
        mu_rows.append(z_next.mean(dim=0))  # (d,)
        pred_samples.append(z_next)
        z_particles = z_next

    # torch.cat/stack on an empty list raises; handle horizon == 0
    # explicitly so callers get a well-shaped empty result.
    if mu_rows:
        pred_mu = torch.stack(mu_rows, dim=0)  # (horizon, d)
    else:
        pred_mu = torch.empty(0, d, device=device)
    return pred_mu, pred_samples
def generate_trajectories_torch(
    model,
    z0_np: np.ndarray,
    horizon: int,
    dt: float,
    S: int = 50,
    steps: int = 100,
    integrator: str = "euler",
    device: torch.device = torch.device("cpu"),
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate an ensemble of unconditional trajectories from the CFM memory ODE.

    This differs from multi-step forecasting in that we:
        - start from a Gaussian cloud around a single initial condition z0,
        - evolve each particle forward for 'horizon' steps,
        - and do NOT condition on an observed prefix (purely model dynamics).

    Mathematically, for each ensemble member s:
        Z_0^{(s)} ~ N(z0, sigma_min^2 I_d),
        Z_{k+1}^{(s)} = Phi(Z_k^{(s)}),  k = 0, ..., horizon-1,
    where Phi is the flow over t in [0, 1] induced by dZ_t/dt = v(t, Z_t).

    Args
    ----
    model : CFM forecaster
    z0_np : np.ndarray, shape (d,)
        Initial condition around which we sample the ensemble.
    horizon : int
        Number of discrete steps to generate (so we return horizon+1 time points).
    dt : float
        Physical time step size for the output time array (for plotting).
    S : int
        Number of ensemble members (trajectories).
    steps : int
        ODE integration steps per unit interval [0, 1].
    integrator : {"euler", "rk4"}
        ODE integrator.
    device : torch.device

    Returns
    -------
    time : np.ndarray, shape (horizon+1,)
        Discrete time points 0, dt, 2dt, ..., horizon*dt.
    traj : np.ndarray, shape (S, horizon+1, d)
        Generated trajectories. traj[s, k, :] is Z_k^{(s)}.

    Raises
    ------
    ValueError
        If `integrator` is not "euler" or "rk4" (checked before any work).
    """
    # Resolve the stepper once: fail fast on a bad name and hoist the
    # loop-invariant string dispatch out of the rollout loop.
    if integrator == "euler":
        integrate = euler_integrate_batched
    elif integrator == "rk4":
        integrate = rk4_integrate_batched
    else:
        raise ValueError(f"Unknown integrator: {integrator}")

    z0_np = np.asarray(z0_np, dtype=np.float32)
    d = z0_np.shape[0]
    z0 = torch.tensor(z0_np, dtype=torch.float32, device=device)
    sigma_min = float(model.hp.sigma_min)
    cov0 = (sigma_min ** 2) * torch.eye(d, device=device)
    dist0 = torch.distributions.MultivariateNormal(z0, covariance_matrix=cov0)

    def vector_field(z, t):
        return model.drift(z, t)

    # Both outputs are converted to NumPy below, so no caller can
    # backpropagate through them; disabling autograd avoids retaining
    # the graph for S * horizon drift evaluations.
    with torch.no_grad():
        # Initial ensemble Z_0^{(s)}: (S, d)
        Z = dist0.sample((S,))
        traj = torch.empty(S, horizon + 1, d, device=device)
        traj[:, 0, :] = Z
        for t_idx in range(1, horizon + 1):
            Z = integrate(vector_field, Z, steps=steps)
            traj[:, t_idx, :] = Z

    time = torch.arange(horizon + 1, device=device).float() * dt
    return time.cpu().numpy(), traj.cpu().numpy()