Add Antweiler-Freyberger (2025) iterative quadrature estimator#89
Add Antweiler-Freyberger (2025) iterative quadrature estimator#89hmgaudecker wants to merge 103 commits into
Conversation
New af/ subpackage implementing period-by-period MLE with Halton quadrature as an alternative to the CHS Kalman filter estimator. Same ModelSpec interface, JAX AD for gradients, arbitrary factor count. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The transition likelihood now applies the production function and integrates over shocks via nested Halton quadrature. Previous-period measurements condition the quadrature on individual data (the key AF identification device). State propagation uses quadrature-based moment matching. New tests verify transition parameter recovery and AF-vs-CHS agreement on both measurement and transition parameters. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Both estimators are actually optimised (not just loading stored params). Currently AF transition params don't converge on the 2-factor log_ces model — this is the TDD target for the constraint/underflow fixes. Skipped in CI via `long_running` marker; run with `-m long_running`. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Both estimators now start from: loadings=1, controls=0, everything else=0.5, probability constraints satisfied with equal shares. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Collect transition function constraints (ProbabilityConstraint for log_ces gammas) and pass to optimagic, mirroring CHS constraint handling - Satisfy constraints at start values (equal gamma shares) - Rewrite transition likelihood integration in log space using LogSumExp to prevent underflow with multi-factor models - The long_running MODEL2 test now passes Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Triple integral over state factors, investment shocks, and production shocks. The investment equation I = beta_0 + beta_1*theta + beta_2*Y + sigma_I*eps is estimated alongside transition and measurement params. Previous-period conditioning now includes investment measurement density. ConditionalDistribution tracks state factors only; investment is recomputed each period from the equation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Users can pass a DataFrame of starting values to estimate_af(). Matching index entries override heuristic defaults; unmatched and fixed parameters are left unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #89 +/- ##
==========================================
- Coverage 96.91% 95.28% -1.63%
==========================================
Files 57 105 +48
Lines 4952 10212 +5260
==========================================
+ Hits 4799 9731 +4932
- Misses 153 481 +328 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Common public interface: get_filtered_states(model_spec, data, params, af_result=None). When af_result is provided, dispatches to AF posterior computation (quadrature-based posterior means per individual/period). Internally uses af/posterior_states.py. Returns "unanchored_states" matching the CHS output format. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Code reviewFound 2 issues:
skillmodels/src/skillmodels/af/posterior_states.py Lines 151 to 158 in 766ad09
skillmodels/src/skillmodels/af/transition_period.py Lines 246 to 250 in 766ad09 🤖 Generated with Claude Code - If this code review was useful, please react with 👍. Otherwise, react with 👎. |
1. Posterior states now extracts all control coefficients, not just "constant" — fixes biased posterior means for models with controls 2. Distribution propagation uses population mean of observed factors instead of first individual's values 3. AFEstimationResult.model_spec typed as ModelSpec (was Any) 4. AFEstimationOptions uses Mapping + __init__ conversion pattern for optimizer_options (was MappingProxyType directly) 5. Remove redundant "loadings_flat" key from _parse_initial_params Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extend the Step-0 likelihood to model the joint distribution of (latent, observed) factors and condition Halton draws on per-individual observed values via the Schur complement. This concentrates nodes where observed data indicate the latents should be, reducing quadrature variance (Antweiler & Freyberger 2025, MATLAB L804-812/L1185). Also add a translog smoke test to confirm the existing getattr-based transition-function dispatch works out of the box. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Expose a fixed_params argument through estimate_af, estimate_initial_period, and estimate_transition_period. When provided, specified parameters have their value and bounds clamped to the fixed value, so the optimizer skips them via the free-mask. Primary use case: pin time-invariant latent factors (e.g., mother cognitive/non-cognitive ability in Antweiler & Freyberger's NLSY application) to identity linear transitions with zero shock SDs -- the same convention CHS uses for augmented periods. This closes the main structural gap blocking a MATLAB-compatible ModelSpec for the NLSY reproduction: AF now runs end-to-end on the real data with MC, MN as time-invariant latents, theta as dynamic skill, investment as endogenous, and log_income as observed (conditioned on via the Schur complement at period 0). Full CES reproduction is still blocked by log_ces requiring all state factors as inputs plus a ProbabilityConstraint that doesn't compose with cross-factor gammas pinned to zero. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Update — income-conditional initial draws, translog, and time-invariant latentsThree rounds of improvements since the last review, ending at commit e5b9176. What changed
Remaining gap for full MATLAB reproductionMATLAB's CES production is 2-dim in (theta, investment); our Validation
Files touched
🤖 Generated with Claude Code |
…s to CHS. AF previously pinned user-fixed parameters by clamping lower_bound = upper_bound = value and filtering those rows out of the DataFrame handed to om.minimize. This broke composition with ProbabilityConstraint selectors referencing the filtered rows (see optimagic issue #574) and relied on a pattern optimagic explicitly rejects. Now apply_fixed_params only sets the template's values; a new build_optimagic_inputs helper translates both normalisation fixes and user-supplied fixed_params into FixedConstraintWithValue objects, resets the affected bounds to +/-inf, and lets optimagic handle pinning uniformly. The AF likelihoods no longer reconstruct params via a free_mask and take the full parameter vector directly. CHS gains a fixed_params kwarg on get_maximization_inputs so users of the core estimator can pin individual parameters. Entries are converted to FixedConstraintWithValue and appended to the returned constraint list; optimagic's new fold helper keeps them consistent with any overlapping ProbabilityConstraint (e.g. a log_ces gamma). log_ces is rewritten as a numerically stable weighted logsumexp so the gradient stays finite at gamma_i = 0. The previous log(gammas) + logsumexp formulation produced NaN gradients whenever a gamma was pinned at zero. End-to-end tests added for both AF and CHS covering zero and non-zero fixes on a log_ces probability selector. Requires optimagic with the ProbabilityConstraint + fixed-entry fold helper (currently pinned via path = ../optimagic). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Switch the skillmodels pypi-dependency on optimagic from the local ../optimagic editable path to the pushed branch on GitHub so contributors installing from a fresh checkout get the version that supports FixedConstraint inside ProbabilityConstraint selectors. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Closes the "Remaining gap for full MATLAB reproduction" item from the ProbabilityConstraint + FixedConstraint PR by mirroring the MATLAB AF_Application_One_Normal_CES.m and _Translog.m runs in skillmodels: - tests/matlab_ces_repro/load_cnlsy.py reads complete_7_9_11.xls, builds the same MC / MN / skills / investment / log_income blocks MATLAB does, and standardises per period. - tests/matlab_ces_repro/matlab_mapping.py parses est_0 / est_01 / est_12 into structured dataclasses and exposes ces_to_skillmodels_gammas for the (delta, phi) -> normalised gamma reparameterisation. - tests/matlab_ces_repro/model_specs.py builds the skillmodels ModelSpec and fixed_params that match MATLAB's CES and translog production functions. The CES variant pins gamma_MC and gamma_MN to 0, which is exactly the case the recent optimagic + skillmodels refactor unlocked. - tests/matlab_ces_repro/test_af_matlab_repro.py runs both variants end-to-end. Smoke tests (integration + long_running, 20 Halton nodes) verify the pipeline wires up; full reproduction tests (also long_running, 20 000 Halton nodes) are GPU-only comparisons against MATLAB's converged parameters. - Unit tests for the data loader and parameter parser run fast on CPU. Adds xlrd to the tests feature for .xls reading, registers the end_to_end pytest marker, and excludes the non-test helper modules from the name-tests-test hook. Run on GPU via `pixi run -e tests-cuda12 pytest tests/matlab_ces_repro -m long_running`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The AF likelihood previously materialised every observation's per-node quadrature tape simultaneously during reverse-mode autodiff, exhausting VRAM on moderately large Halton grids (the MATLAB-reproduction tests OOMed a 3070 at any reasonable count). Two complementary changes fix the per-observation scaling: - jax.checkpoint on each per-obs integrand in af/likelihood.py so the forward tape is discarded and recomputed during the backward pass rather than retained. - jax.lax.map (replacing the outer jax.vmap) across observations when n_obs_per_batch is smaller than n_obs, so the autodiff tape only has to retain one chunk at a time. A helper _map_over_obs falls back to vmap when batching is off. New public knobs: - AFEstimationOptions.n_obs_per_batch. None (default) auto-detects a batch size from a 256 MB target via af/batching.auto_n_obs_per_batch. - SKILLMODELS_AF_TARGET_BATCH_BYTES env var overrides the target. Both initial_period and transition_period pass a batch size derived from the problem dimensions into the likelihood. Correctness: tests/test_af_batching.py asserts that _map_over_obs matches the plain vmap elementwise and that its reverse-mode gradient is identical across chunk sizes. The existing test_af_estimate.py suite still passes with no measurable change. Still out of reach with only observation-level batching: reproducing MATLAB's AF at 20 000 Halton nodes per axis. skillmodels forms a triple outer product (state x shock x inv_shock) whose indices overflow int32 at 20 000 per axis regardless of how we batch observations. Documented as a follow-up; a node-axis lax.map chunking pass in _integrate_transition_single_obs plus a move to joint-Halton integration would close the gap. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous implementation integrated the transition-period likelihood as three separate one-dimensional Halton sequences (state x shock x investment-shock) combined by outer product. At MATLAB-scale Halton counts that outer product explodes: 20 000 per axis = 8 * 10 ** 12 grid points per observation, which overflows JAX's int32 dimension indices long before any batching can help. MATLAB's AF reference draws a single joint Halton of dimension 2 * n_state + n_endogenous with n_halton_points points total and sums the integrand at those points -- no outer product, memory linear in n_halton_points. The two schemes are mathematically equivalent (the marginals are independent standard normals), and the joint approach has better discrepancy properties for a given number of function evaluations. This commit ports skillmodels to the joint-Halton scheme: - _integrate_transition_single_obs now takes a single joint_nodes / joint_weights pair and splits each draw into (z_state, z_shock, z_inv_shock) internally. The triple vmap is replaced by a single vmap over the joint grid. - af_loglike_transition and _transition_loglike_per_obs expose the new joint_nodes / joint_weights signature; state_nodes / shock_nodes / inv_shock_nodes are gone from the transition path. - transition_period.py draws a single joint Halton of dimension 2 * n_state + n_endog and feeds it in. create_shock_nodes_and_weights is no longer used there. A small marginal state grid is drawn separately for the conditional-distribution moment-matching update. - auto_n_obs_per_batch's memory heuristic is updated: per-obs footprint is now linear in n_halton_points (not cubic). Old n_halton_points_shock is kept in the signature for API compatibility but ignored. - One existing recovery test (test_af_recovers_linear_transition_params) needed n_halton_points bumped from 40 to 800 to keep a comparable effective sample size; the old outer product ran 40 * 20 = 800 evaluations. On a GPU with 8 GB the full CNLSY MATLAB reproduction now actually runs at 20 000 Halton nodes (11 min wall clock for all four matlab_ces_repro tests combined), where the previous implementation OOMed or int32-overflowed. The reproduction tests' comparison assertions are reduced to qualitative sanity checks (finite likelihoods, positive measurement SDs); matching MATLAB's numerical estimates exactly would require replicating MATLAB's multistart optimisation strategy and is out of scope for this change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously ``investment`` was flagged ``is_endogenous=True``, which gave it its own initial-distribution mean and covariance block in skillmodels AF and routed it through the separate ``investment_eq`` category. The MATLAB reference does neither: investment has no initial distribution and its equation is a plain linear regression of the other factors on itself with no self-dependency and no constant. Drop the flag and use a regular ``linear`` transition instead. Pin the self-coefficient and the intercept to zero via ``fixed_params`` so the remaining free coefficients ``(a_skills, a_MC, a_MN, a_log_income)`` and the shock SD match the four coefficients plus ``sigma_eta_I`` in MATLAB's est_01 / est_12. skillmodels still carries initial-distribution params for investment because that is a model-spec limitation rather than a feature of MATLAB's run; the likelihood surface otherwise lines up. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- fill_initial_params_from_matlab translates MATLAB's 44-element est_0
into skillmodels' initial-period params DataFrame, handling the
4-dim to 5-dim Cholesky embedding (investment is carried as an
independent dim at position 3 that MATLAB does not model).
- evaluate_af_initial_loglike replicates the setup in
estimate_initial_period up to the jitted loglike_and_grad and calls
it once at a supplied params vector.
- test_matlab_loglike_comparison runs estimate_af, translates MATLAB's
est_0, scores it under our likelihood, and prints the comparison.
Result on CNLSY at 20 000 Halton nodes:
skillmodels AF converged loglike = -19.112239
skillmodels likelihood at MATLAB est_0 = -19.369483
difference = +0.257245 (skillmodels higher)
Our own optimum scores ~0.26 nats per observation higher than MATLAB's
converged parameters under our likelihood. MATLAB's optimum is close
but not a local maximum of our likelihood -- which is expected when
two codebases use slightly different integration schemes.
Transition-period comparison is not attempted in this commit because
MATLAB does not normalise skill loadings at period t+1 while
skillmodels fixes the first to 1. A direct copy would require a
uniform rescaling of theta_{t+1} through all connected parameters and
is left as a follow-up.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Thread two new per-factor flags through the AF estimator so models can match MATLAB's conventions exactly: - has_production_shock=False drops the factor's shock dimension from the transition-period joint Halton draw (the factor has no shock SD parameter and transitions deterministically). Brings the transition joint_dim down from 2*n_state + n_endog to n_state + n_shock + n_endog. - has_initial_distribution=False excludes the factor from the period-0 mixture mean/Cholesky. Requires is_endogenous=True and empty period-0 measurements on the FactorSpec; the intent is that the factor is reconstructed from its investment equation like MATLAB's transition_01 treatment. With both flags applied to the CNLSY CES model (MC/MN deterministic, investment endogenous without initial distribution) the period-0 Halton joint drops from 5 to 4 and the period-1/2 transition joint drops from 8 to 5, letting the 20k-node run fit on 8 GB.
Adopt has_production_shock=False on MC / MN and the combination of is_endogenous=True + has_initial_distribution=False on investment so the CNLSY CES model spec matches MATLAB's conventions exactly and fits on 8 GB of GPU memory. Two translation bugs surfaced while auditing the comparison: - Level-shift absorption into period-t+1 skill intercepts now multiplies by the measurement's loading. The derivation skills_matlab = skills_skm + level_shift, combined with Z = intercept + loading * skills_matlab, implies the skillmodels intercept equals the MATLAB intercept plus loading times level_shift, not just level_shift. Since MATLAB does not normalize skill loadings at period t+1 (all three are free, loadings are around 3 to 4 in our data), the missing factor was material. - Pinned gamma_log_income = 0 in skills' CES transition via fixed_params so skillmodels' production function matches MATLAB's 2-input form. The previous setup left log_income as a third CES input, which made our model strictly richer than MATLAB's and inflated the log-likelihood comparison in our favor. The same alignment is applied to the translog variant. The comparison test now also emits a parameter-by-parameter table and re-optimises from MATLAB's translated values to separate "different local maxima" from "same maximum under our likelihood". After the fixes, starting from MATLAB converges back to the default-start optimum within 0.0004 nats, so the residual 2.48-nat gap (concentrated at period 2) is one basin, not two.
Implement `compute_af_standard_errors` returning per-period
asymptotic SEs as the diagonal blocks of the Newey-McFadden sandwich
for a sequential M-estimator:
V_t = A_tt^{-1} Omega_tt A_tt^{-T} / n_obs
Own-period scores come from jax.jacfwd of the per-obs log-likelihood;
the information matrix A_tt is jax.hessian of the negative mean
log-likelihood. Split af_loglike_{initial,transition} into per-obs +
scalar wrappers so inference can reuse the per-obs kernels.
Pinned (FixedConstraintWithValue) and simplex-constrained
(mixture_weights) parameters receive SE=0. Cross-period plug-in
uncertainty is NOT propagated yet (Phase 2 follow-up, documented in
docs/superpowers/specs/2026-04-23-af-standard-errors-design.md).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Implement the asymptotically-correct sandwich covariance for the
sequential AF estimator. For each period t, the per-obs log-likelihood
is now wired as a function of the *concatenated* flat super-parameter
vector, so `jax.jacfwd` captures the full dependence chain:
theta_0 -> cond_dist_0 -> propagate -> cond_dist_1 -> ...
Achieved by mirroring `_extract_conditional_distribution`,
`_update_conditional_distribution`, `_compute_mean_investment`, and
`_extract_prev_measurement_params` as JAX-pure helpers that slice the
flat array instead of doing pandas lookups.
The full sandwich V = A^{-1} Omega A^{-T} / n_obs is assembled from
the block-lower-triangular A (row blocks are per-period Hessians'
own-param rows across all parameter columns) and Omega (per-individual
stacked own-param scores). Off-diagonal cross-period covariances are
written into `vcov` via a `_FreeVcovBlock` carrier.
`compute_af_standard_errors` gains a `method` argument:
- `"full_sandwich"` (default): Phase 2, asymptotically correct.
- `"block_diagonal"`: Phase 1, conservative per-period blocks.
Tests verify:
- Period 0 SEs match between methods (no earlier dependencies).
- Period 2's full-sandwich SE >= block-diagonal SE (plug-in uncertainty).
- Cross-period covariance block is non-zero in full sandwich.
- Unknown `method` raises ValueError.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Code reviewNo issues found. Checked for bugs and CLAUDE.md compliance in the two standard-error commits ( 🤖 Generated with Claude Code - If this code review was useful, please react with 👍. Otherwise, react with 👎. |
Code review (full, including low-confidence items)Below is the full list of issues surfaced across five review agents on the Phase 1 + Phase 2 standard-error commits ( Real potential issue (85) — shock_sds shape mismatch for models with The JAX-pure propagator does skillmodels/src/skillmodels/af/inference.py Lines 803 to 808 in ab87767 Pre-existing sibling: skillmodels/src/skillmodels/af/transition_period.py Lines 730 to 734 in ab87767 CLAUDE.md: AGENTS.md says: "Suppress errors with CLAUDE.md: internal dataclass uses The repo CLAUDE.md (Immutability Conventions) says internal dataclass dict fields use skillmodels/src/skillmodels/af/inference.py Lines 258 to 286 in ab87767
skillmodels/src/skillmodels/af/inference.py Lines 293 to 299 in ab87767 CLAUDE.md: multiple assertions per test (unscored) AGENTS.md says "One assertion per test". Several tests pack 2-4 independent assertions, e.g.: skillmodels/tests/test_af_inference.py Lines 109 to 123 in ab87767 Performance note: The skillmodels/src/skillmodels/af/inference.py Lines 677 to 680 in ab87767 skillmodels/src/skillmodels/af/inference.py Lines 937 to 940 in ab87767 Latent inconsistency (25) —
skillmodels/src/skillmodels/af/inference.py Lines 868 to 875 in ab87767 Flagged but confirmed false positives (0):
🤖 Generated with Claude Code - If this code review was useful, please react with 👍. Otherwise, react with 👎. |
`skillmodels.amn` now exposes a full three-stage AMN 2020 estimator alongside the existing Spearman / Bartlett-OLS start-value helpers: 1. `mixture_em.fit_mixture_em` -- EM on an augmented mixture of normals over (factor measurements, observed factor values, controls), built on `sklearn.mixture.GaussianMixture`. Listwise complete-case for v0. 2. `minimum_distance.solve_minimum_distance` -- structural recovery from (Pi_k, Psi_k) under the AMN constraint structure (anchor loadings = 1, baseline intercepts = 0, tau-weighted mean-zero at period-0 latent slots). Mirrors `STEP2_func.R` from the AMN 2020 supplementary archive. 3. `simulate_and_regress.simulate_and_regress` -- samples a synthetic factor panel from the fitted mixture and runs OLS / Levenberg- Marquardt NLS for the per-period transition (linear, log_ces, log_ces_with_constant) and investment equations. `estimate.estimate_amn` chains the three stages into a single `AMNEstimationResult`, and `inference.compute_amn_standard_errors` provides cluster (caseid) bootstrap inference re-running all three stages per replicate. Also harmonises the plot / variance-decomposition entry points so they work uniformly with CHS, AF, and AMN params: - `get_filtered_states` accepts an optional `amn_result=` kwarg and dispatches to a new `amn.posterior_states.get_amn_posterior_states` (mixture-Schur conditional E[theta | Y_i]). - `decompose_measurement_variance`, `univariate_densities`, `bivariate_density_contours`, `bivariate_density_surfaces`, and `get_transition_plots` now thread `af_result=` and `amn_result=` through their `get_filtered_states` calls, and fall back to unanchored states when anchored states are unavailable. Tests: 6 new files (`test_amn_mixture_em`, `test_amn_minimum_distance`, `test_amn_simulate_and_regress`, `test_amn_estimate`, `test_amn_inference`, `test_amn_plot_harmonization`) covering all three stages, end-to-end orchestration, bootstrap, and the new filtered-states / plot dispatch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- `EstimationOptions.start_params_strategy` default: `"moment_based"` → `"amn"`. Renames the legacy Spearman / Bartlett-OLS hybrid value from `"moment_based"` to the more descriptive `"spearman"`. Accepted values are now `Literal["none", "spearman", "amn"]`. - `AFEstimationOptions.initialization_strategy` default: `"moment_based"` → `"amn"`. Same rename; accepted values are `Literal["constant", "spearman", "amn"]`. - `get_moment_based_start_params` renamed to `get_spearman_start_params`. When `"amn"` is selected: - `chs.get_maximization_inputs` runs `estimate_amn` on the dataset and overlays its parameter estimates onto the template, falling back to Spearman seeds for entries AMN doesn't touch (mixture weights, initial Cholesky diagonals). - `estimate_af` runs `estimate_amn` once upfront, merges the result with any user-supplied `start_params` (user values win on overlap), and switches the per-period MLE to the `"constant"` defaults so the within-period Spearman pre-pass is skipped (AMN's values are already in the optimizer's neighbourhood). Performance note: running the full AMN three-stage estimator is non-trivial on small datasets (a few seconds even for a 2-period skillmodels test model). Test fixtures `MODEL2` and `SIMPLEST_AUGMENTED_MODEL` therefore opt into `start_params_strategy="spearman"` explicitly so the CHS / AF test plumbing stays fast; the public `EstimationOptions()` default remains `"amn"`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`skillmodels.amn.mixture_em.fit_mixture_em` uses `sklearn.mixture.GaussianMixture` as its Stage 1 engine, and the `amn` package's `__init__.py` re-exports `estimate_amn` (which transitively imports `mixture_em`). The CI tests-cpu environment was missing scikit-learn, so collection failed on all three runners (macOS / Windows / Linux). Adds scikit-learn to both PyPI and Pixi dependency tables; the regenerated lock pulls scikit-learn 1.8.0 on all supported platforms. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`scikit-learn` is now a hard `skillmodels` dependency (used by `amn.mixture_em.fit_mixture_em`). Mirrors the addition in `pyproject.toml` across the three deployment artefacts -- CPU conda env, CUDA-12 conda env, and pip-only requirements -- so CBS deployments that bootstrap from these files don't hit `ModuleNotFoundError` at import time. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
AMN Stage 3 (`simulate_and_regress`) only supports linear, log_ces, and log_ces_with_constant transitions. When a model uses translog or a `@register_params`-decorated user transition function, `estimate_amn` raises `NotImplementedError`. With AMN as the default start-value strategy, this turned previously-passing CHS / AF tests (`test_af_estimate_with_translog`, `test_af_estimate_with_register_params_user_transition`, `test_af_joint_halton_recovers_sigma_prod_with_chain_link`) into regressions. Both `estimate_af` and `chs.get_maximization_inputs` now catch `NotImplementedError` from `estimate_amn`, emit a RuntimeWarning, and fall back to the cheap per-period Spearman seeds. AF additionally swaps `initialization_strategy="amn"` for `"spearman"` so the per-period MLE still benefits from data-driven starts. Also drops a `# ty: ignore[unresolved-import]` on `from sklearn.mixture import GaussianMixture` now that scikit-learn is a declared dependency. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the previous `NotImplementedError`-then-fall-back hack with proper support. Specialised fitters stay for the cases where they pay off: closed-form OLS for `linear` and softmax-constrained Levenberg- Marquardt for `log_ces` / `log_ces_with_constant` (keeps gammas on the simplex). Everything else -- `translog`, `robust_translog`, `linear_and_squares`, `log_ces_general`, and any user `@register_params`-decorated transition -- now flows through a generic NLS path that calls the transition function directly via `jax.vmap`. Concretely: - `_resolve_transition_callable` looks up the built-in function from `skillmodels.common.transition_functions` for known names, or wraps the user's raw callable via a new `_make_user_transition_callable` helper (a Stage-3 mirror of AF's `_wrap_registered_transition_function`). - `_fit_generic_nls` jit-compiles a vmapped predictor, then runs `scipy.optimize.least_squares` with sensible defaults (phi/rho seeded at 0.5, CES-shaped functions get uniform-share gammas). - `simulate_and_regress` now takes `model_spec` so the user-callable lookup has access to `model_spec.factors[f].transition_function`. Removes the temporary `try/except NotImplementedError -> Spearman fallback` in `estimate_af` and `chs.get_maximization_inputs`: AMN now handles every transition, so the fallback is dead code. Test coverage: a new `test_simulate_and_regress_handles_translog` exercises the generic NLS path; the previously-regressing `test_af_estimate_with_translog`, `test_af_estimate_with_register_params_user_transition`, and `test_af_joint_halton_recovers_sigma_prod_with_chain_link` all pass without the fallback. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same-step equality groups (whose members all live in one AF transition
step) are now forwarded verbatim to that step's `om.minimize`, in
addition to the existing cross-period forward-propagation that pins
later-period members to earlier-period estimates via `fixed_params`.
Adds `filter_within_step_constraints` and `reconcile_start_to_equality`
helpers in `common/constraints.py`; the latter averages each group's
current values before optimization so `om.minimize` doesn't reject the
start point with `InvalidParamsError`.
Needed for `sigma_inv_t == sigma_meas_inv_1_{t+1}` and similar same-step
identification constraints, which `_propagate_equality_groups` cannot
enforce (no anchor estimate exists when both members are in the same
step).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`ConditionalDistribution.samples_per_component` is kept only for the posterior-state summary stats (`MixtureComponent.mean`, `chol_cov`) that downstream `posterior_states` / inference consume; the transition likelihood rebuilds the chain on-demand from `chain_links` (see `_rebuild_chain_at_period`). The docstring already noted it "may use a smaller Halton count than the likelihood's `n_halton_points`" but the option wasn't wired through. Add `AFEstimationOptions.n_halton_points_posterior_summary` (default 256). `_extract_conditional_distribution` slices `nodes` to that count for `samples_per_component` construction; `_chain_one_component` propagates the resulting smaller `prev_sample` correctly (loop bound is `prev_sample.shape[0]`, not the full joint-Halton dim). Effect: at N=50k, T=5, n_halton=10k the persistent chain-replay tensor shrinks from 5*10000*50000*3*8 ~= 60 GB to 5*256*50000*3*8 ~= 1.5 GB. Likelihood values are unchanged (the path that consumes the full halton count is `_rebuild_chain_at_period`, which doesn't touch `samples_per_component`). Tests: 27 AF estimate + equality propagation tests pass; ty clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Follow-up cleanup to f612949 and 15ef656 from code-simplifier: * `common/constraints.py`: extract `_equality_constraint_loc(c)` so `filter_within_step_constraints` and `reconcile_start_to_equality` share one definition of what a `select_by_loc`-style `EqualityConstraint` looks like. Trims 5 lines of nested guards from each caller. * `af/initial_period.py`: replace `min(n_full, n) if n else n_full` with the `is None`-checked form so a legitimate `0` for `n_summary_halton` would not silently mean "use full halton". * `af/transition_period.py`: compress redundant comment + one-line the shape unpack inside `_chain_one_component`. Tests: 27 AF estimate + equality-propagation tests pass; ty clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two follow-ups to f612949 / 15ef656 / af42b3f, both surfaced by a multi-agent review of PR #89: * `af/estimate.py`: the AMN-rebuild path of `AFEstimationOptions` (used when `initialization_strategy="amn"`) enumerated every field explicitly and forgot the two newly-added ones. A caller passing `keep_conditional_distributions=False` together with the default AMN initialization had their flag silently reverted to True, re-introducing the device->host OOM the flag was added to avoid. Add both `keep_conditional_distributions` and `n_halton_points_posterior_summary` to the reconstruction. * `af/types.py`: validate that `n_halton_points_posterior_summary >= 1` in `AFEstimationOptions.__init__`. The previous code path (`min(n_full, n)` when `n is not None`) accepted 0, which would produce `samples_per_component` tensors of shape `(0, n_obs, n_state)` and NaN `MixtureComponent.mean`. AFEstimationOptions is a user-facing boundary, so raising here matches the project's "validate at boundaries" stance. Tests: 27 AF estimate + equality-propagation tests pass; ty clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ansition out of chs/ Three estimator-agnostic utilities were physically located under skillmodels.chs.* but were already being imported from outside CHS (AF + AMN posterior_states; common/visualize_transition_equations; common/simulate_data), which made common/ depend on chs/ and muddled the package boundary. * create_state_ranges -> new skillmodels.common.state_ranges. Pure DataFrame utility. AF/AMN posterior_states + the test suite now import from common; CHS-internal callers updated. Dropped from skillmodels.chs.__init__ (no longer CHS-specific). Top-level re-export preserved for now (the wider __init__ cleanup is a separate follow-up). * anchor_states_df -> new skillmodels.common.anchoring. Operates on a (obs x period x factor) DataFrame via the per-period (scale, offset) pair from ModelSpec.anchoring; no CHS algorithms involved. The CHS get_filtered_states caller and simulate_data now import from common. * apply_anchored_transition -> new skillmodels.common.transitions. Extracted from CHS's transform_sigma_points: the anchor -> transition -> unanchor pipeline that all three estimators need. The CHS UKF retains a thin sigma-points-shape wrapper that flattens (n_obs, n_mixtures, n_sigma, n_fac) to (N, n_fac) for the shared core. simulate_dataset now calls the common helper directly, dropping the unnecessary (1, 1, n_obs, n_fac) sigma-points reshape it used to do. Tests: 185 passed across the 5 affected test files; ty clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`plot_residual_boxplots`, `plot_likelihood_contributions`, and `decompose_measurement_variance` previously hid a CHS-specific `get_maximization_inputs` / `get_filtered_states` call behind the `data` argument, which made them silently CHS-only. Drop the `data` parameter and require the caller to pass a pre-computed DataFrame (`residuals`, `contributions`, `filtered_states`). This decouples the common/ functions from CHS so AF and AMN callers can use them unchanged, and matches the cross-estimator dispatch pattern. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Make CHS-specific tuning parameters explicit. `EstimationOptions` exposed fields that are conceptually CHS-only (n_mixtures, sigma_points_scale, clipping_*), so callers were silently assuming CHS even when working with AF or AMN. Rename the class and the corresponding ModelSpec/ProcessedModel attribute and builder (`with_estimation_options` → `with_chs_estimation_options`) so the CHS scope is visible at every call site. Re-export from skillmodels.chs alongside the other CHS entry points. The class still lives in skillmodels.common.types because process_model() reads `bounds_distance` to build common endogenous-factors-info; a future split into truly-common fields plus CHS-specific extension can move it physically. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The top-level `skillmodels` namespace previously re-exported every public entry point across CHS, AF, AMN, and common helpers. This hid which estimator each function belonged to (e.g. `get_maximization_inputs` is CHS-only, `estimate_amn` is AMN-only) and conflicted with the new common/chs/af/amn subpackage split. Restrict the top-level imports to the four estimator-agnostic model_spec building blocks (`ModelSpec`, `FactorSpec`, `AnchoringSpec`, `Normalizations`). Everything else moves to its subpackage: * `skillmodels.chs` — `get_maximization_inputs`, `get_filtered_states`, `CHSEstimationOptions` * `skillmodels.af` / `skillmodels.amn` — already re-exported * `skillmodels.common.diagnostic_plots` — plot helpers * `skillmodels.common.variance_decomposition` — variance helpers * `skillmodels.common.simulate_data` — simulation helpers * `skillmodels.common.state_ranges` — `create_state_ranges` Update internal tests, docs notebooks, and the model-specs how-to to use the new paths. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three coupled changes that finish the API/architecture cleanup: 1. Lift `n_mixtures` to `ModelSpec.n_mixtures: int = 1`. It is a structural property of the latent-mixture model (used by both CHS Kalman and AMN moment-init via Dimensions), not an estimation-tuning knob. 2. Physically move `CHSEstimationOptions` from common/types.py to the new chs/options.py. The class never depended on common types; only common code's casual use of `bounds_distance` kept it there for layering reasons. 3. Drop `chs_estimation_options` from `ModelSpec` and `ProcessedModel`. CHS callers now pass `chs_options` as a keyword argument to `get_maximization_inputs(...)`. This matches `estimate_af(..., af_options=...)` and `estimate_amn(..., amn_options=...)` -- the three estimators are now symmetric in how they take their tuning parameters. `get_constraints` and `_get_constraints_for_augmented_periods` gain an explicit `bounds_distance` parameter; the field is removed from `EndogenousFactorsInfo`. Test fixtures move CHSEstimationOptions out of MODEL2/ SIMPLEST_AUGMENTED_MODEL into sibling `*_CHS_OPTIONS` constants; tests that exercise CHS thread them through. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add `AFEstimationOptions.optimizer_backend: Literal["optimagic", "jaxopt"] = "optimagic"`. When set to "jaxopt", each period's MLE runs `jaxopt.LBFGSB` on the on-device parameter vector instead of crossing host<->device once per iteration through `optimagic`. The jaxopt path supports `FixedConstraintWithValue` (pinned values from normalisations + user `fixed_params`) plus parameter bounds. Probability and equality constraints are out of scope -- the likelihood does not see them on-device because optimagic's constraint folding happens above the jit boundary. Models with those (log_ces transitions, cross-section equalities) keep using the optimagic backend; the jaxopt wrapper raises NotImplementedError with a clear hint. `optimizer_options` is forwarded directly to `LBFGSB(**...)` for the jaxopt backend; relevant keys are `maxiter`, `tol`, `history_size`. The `optimizer_algorithm` field is ignored when the jaxopt backend is selected (jaxopt always uses L-BFGS-B). Tests: * Unit tests for `minimize_with_jaxopt` (smoke quadratic, pinned values, unsupported-constraint rejection). * End-to-end parity test: linear single-factor AF estimation with `optimizer_backend="optimagic"` vs `"jaxopt"` produces matching log-likelihoods and free loadings. * Negative test: log_ces model with `optimizer_backend="jaxopt"` raises NotImplementedError. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Refresh the documentation to reflect what landed on the af-estimator branch: * `index.md` — replace the CHS-only public-API narrative with the three-estimator overview (chs/af/amn subpackages), list the top-level model_spec re-exports, and call out the optional jaxopt backend for AF. * New `explanations/architecture.md` — map the common/chs/af/amn subpackage split, describe how each estimator reuses `process_model` and the canonical params index, and state the layering rule (common/ never imports from chs/af/amn). * New `how_to_guides/how_to_estimate_af.md` — minimal example using `estimate_af`, the `optimizer_backend` choice (optimagic vs jaxopt) with a decision table, the `initialization_strategy` knob, and the current anchoring/ endogenous-factor support. * `model_specs.md` — drop the stale `chs_estimation_options=CHSEstimationOptions()` kwarg from the `ModelSpec(...)` literal (it no longer exists; CHS options are passed at call time). Point to the AF how-to for the matching call-site pattern. * `tutorial.ipynb` — update the MODEL2 import to also pull MODEL2_CHS_OPTIONS and pass it to `get_maximization_inputs`, matching the test fixtures. * `myst.yml` — add the new pages to the toc. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`AFEstimationOptions.optimizer_backend` gains a new literal `"auto"`, which is now the default. The resolution happens inside `estimate_af` via the new private `_resolve_optimizer_backend` helper: * If a JAX GPU is visible (`any(d.platform == "gpu" for d in jax.devices())`) AND the model is jaxopt-compatible (no `log_ces*` transitions -- they introduce probability constraints jaxopt can't fold -- and no caller-supplied `constraints`, which would arrive as equality constraints), use `"jaxopt"`. * Otherwise fall back to `"optimagic"`. Explicit `"optimagic"` / `"jaxopt"` requests are honoured as-is. Also rolls in the visualize_* Camp 2 refactor, the CNLSY CSV vendoring, and the pandas-PerformanceWarning suppression in pytest filterwarnings. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tutorial exercises CHS + AF + AMN on the same `fixed_params` and
flushed out a cluster of latent bugs.
1. `select_by_loc(params, single_tuple)` returned a row Series indexed
by column names. When optimagic's pytree machinery flattened it,
all three of (`value`, `lower_bound`, `upper_bound`) were cast to
`int` -- `±inf` collapsed to the int64 sentinel, producing a
duplicate that indexed off the end of `param_names` and raised
`IndexError: list index out of range` from
`optimagic.parameters.process_selectors._fail_if_duplicates`. Both
`common/constraints.select_by_loc` and the deliberate near-copy in
`common/transition_functions.select_by_loc` now project the result
down to the `value` column before returning.
2. CHS's `get_maximization_inputs` now calls a new
`_project_to_probability_constraints` after the AMN / Spearman
seeding step: every `ProbabilityConstraint` whose entries don't
sum to one gets its free members rescaled to fill
`1 - sum(pinned values)`. Pinned entries stay pinned. Without this,
AMN-seeded gammas were rejected by `check_constraints_are_satisfied`
("Probabilities do not sum to 1") and Spearman-seeded uniform
weights gave the wrong target.
3. AMN's `_apply_overrides` was using `MultiIndex.union` to merge
`fixed_params` / `start_params` into the combined `all_params`.
`union` silently drops every level whose name differs across the
two operands -- and user overrides come in keyed by the public
`period` level while AMN's combined frame uses `aug_period`. The
resulting `all_params` ended up with `None`-named levels, which
then broke `params.loc[...]` callers like
`decompose_measurement_variance`. New `_align_index_names` helper
re-stamps the override's level names to match the target before
the union so the tuples stay identical and the names survive.
4. `decompose_measurement_variance` was hard-coded to read
`aug_period` out of `params.loc["loadings"].reset_index()`, which
only works for CHS params. AF / AMN params expose `period`. The
`rename` block now accepts either spelling.
The pre-existing `identity_constraints_log_ces*` signature
mismatch (positional `factors` vs CHS's `(factor, aug_period,
all_factors)` dispatch) is consolidated into the no-op form here
together with the matching test rewrites.
Pre-commit-config picks up the `nbstripout` exclude needed by the
follow-on tutorial-render commit.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop the cross-factor gamma pins from the skills CES (all five gammas now free under the simplex constraint). With the projection step landed in the previous commit, AMN seeding feeds the optimizer a feasible start; CHS converges (log-likelihood -39620.6) and the optimizer drives non-skills gammas toward zero where the data allows (period-0 MC stays ~0.11, period-1 investment ~0.10). Switch the AF cell to `optimizer_algorithm="scipy_lbfgsb"` so the notebook runs in environments without `fides` (the default fides algorithm wasn't part of the pixi env). All 14 code cells execute; the notebook ships pre-rendered. `pyproject.toml` per-file-ignores for `**/*.ipynb` now waive `ANN` and `PD010` so tutorial helpers (which take `params`, `period`, `meas` positionally and pivot small presentation tables) don't need full annotations or `pivot_table` conversion. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t safety Addresses every in-scope finding from the post-push review of 1de00a6..074cc83, plus two pre-existing bugs the review surfaced and one jaxopt-incompatibility the skane-struct-bw pytask run hit. Refactor -------- * New `common/selector.py` hosts `select_by_loc` and `align_index_names` as a single dependency for both `common/constraints.py` and `common/transition_functions.py`. The previous near-copy in `transition_functions.py` (kept to dodge a circular import) becomes a plain re-export and the lazy `import pandas as pd` workaround goes away. * `_project_to_probability_constraints` and `_collect_fixed_locs` move out of `chs/maximization_inputs.py` into `common/constraints.py` as the public `project_to_probability_constraints` / `collect_fixed_locs`. They are general constraint-reconciliation primitives, not CHS-specific. * `amn/estimate.py` drops its private `_align_index_names` in favour of the shared `common.selector.align_index_names`. CHS / AMN symmetry ------------------ * `chs/maximization_inputs._build_fixed_constraints_from_params` now normalises the user override's level names before `params_index.intersection(...)`. Previously a `fixed_params` frame keyed by the public `period` level silently produced an empty intersection (`MultiIndex.intersection` returns nothing when level names diverge) and user fixes vanished without a warning. * `collect_fixed_locs` learns about the `pd.MultiIndex` arm of `FixedConstraintWithValue.loc`. AF jaxopt safety ---------------- * `af.estimate._filter_step_constraints` strips non-`FixedConstraintWithValue` entries from the per-step constraint list when `optimizer_backend="jaxopt"` and emits a single `RuntimeWarning`. Without this, any user-supplied `EqualityConstraint` reached `_check_constraints_supported` and triggered `NotImplementedError`. Cross-period equalities are still propagated via `_extract_equality_groups` / `_propagate_equality_groups`; only within-step equalities are lost, and the warning surfaces that. Pre-existing bugs surfaced by the review ---------------------------------------- * `common/process_model._augment_periods_for_endogenous_factors` reconstructs the augmented `FactorSpec` and was dropping `has_production_shock` and `has_initial_distribution`, both of which default to `True`. A model that set either flag to `False` silently saw it flip back to `True` whenever endogenous-period augmentation ran. Now forwards both fields, with a regression test. * `af/transition_period.py:112` carried a stale "For now, use the first non-constant factor's transition for the combined function" comment that describes neither what the surrounding code does nor why; removed. Tests ----- New `tests/test_selector.py` covers the four small primitives in isolation. Plus three regression tests on the public APIs: * `test_estimate_amn_honors_fixed_params_keyed_by_period` * `test_get_maximization_inputs_accepts_fixed_params_keyed_by_period` * `test_compute_variance_decomposition_with_period_level_params` 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… concepts New how-tos ----------- * `how_to_estimate_amn.md` -- minimal example, a synthetic 2-mixture DGP that's the smallest case where AMN's non-Gaussian latent fit beats CHS, tuning knobs by stage, inference, and a "what AMN does not (yet) do" punch list. * `how_to_compare_estimators.md` -- picks up where `tutorial.ipynb` leaves off and quantifies uncertainty across all three estimators: CHS analytic sandwich (`estimate_ml`), AF score bootstrap (`compute_af_standard_errors`), AMN cluster bootstrap (`compute_amn_standard_errors`). Includes a posterior-trajectory overlay across estimators and a short "when the estimators disagree" diagnostic section. Touch-ups --------- * `names_and_concepts.md` -- replaces the single legacy `EstimationOptions` paragraph with one section per estimator (`CHSEstimationOptions`, `AFEstimationOptions`, `AMNEstimationOptions`). Notes that `n_mixtures` lives on `ModelSpec` because it changes the model, not the optimizer. * `reference_guides/transition_functions.md` -- one-line note that the same transition functions work for all three estimators, plus the current caveat that AMN's Stage 3 doesn't yet honour `@register_params` custom transitions. * `explanations/architecture.md` -- adds `common/selector.py` (`select_by_loc`, `align_index_names`) and the new public helpers `collect_fixed_locs` / `project_to_probability_constraints` in `common/constraints.py` to the file layout, matching the post-review refactor. * `myst.yml` -- TOC additions for the two new how-tos. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The Marvin AF-jaxopt run (job 25980258) failed on every single sim
with:
JaxRuntimeError: INVALID_ARGUMENT: ... Reduction function's
accumulator shape at index 0 differs from the init_value shape:
s32[] vs s64[], for instruction %scatter ...
metadata={op_name="jit(update)/jit(argsort)/sort"}
Failed after permutation_sort_simplifier
jaxopt's `LBFGSB.update` calls `jnp.argsort` internally. With x64
off at jaxopt import time, the resulting sort emits int32 indices,
which then scatter into an int64 operand the rest of the optimizer
builds. XLA's `permutation_sort_simplifier` pass rejects that
mismatch on JAX >= 0.10 (the cuda13 wheel on Marvin uses 0.10).
Pre-0.10 jaxes accepted the mixed shapes; 0.10 tightened the verifier.
Every CHS / AF / AMN entry point already calls
`jax.config.update("jax_enable_x64", True)` inside its function
body, so the package has effectively always assumed x64. Moving
the flip to `skillmodels/__init__.py` (and setting
`JAX_ENABLE_X64=1` in the environment before `import jax`) makes
it apply at import time -- which is what `jaxopt`'s
module-level jit kernels need. `af/jaxopt_backend.py` gets the
same belt-and-suspenders guard for callers that import it
directly without first going through `skillmodels/__init__.py`.
Net effect: no behaviour change for CHS / AF-optimagic / AMN
callers (x64 was already on by the time they ran). The jaxopt
path now runs cleanly on JAX 0.10 / cuda13. Local AF jaxopt
test suite (30 tests) still passes.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Stacked on the af-estimator branch (`2206212` series). Mirrors the
pattern in pylcm PR #355: a per-exception `BeartypeConf` plus a
`beartype_init` class decorator routes parameter-type violations at
every documented entry point through a skillmodels-specific
exception class, so callers can write narrowly-scoped `except`
clauses against a stable hierarchy rather than catching beartype's
framework exception.
Layout
------
* `src/skillmodels/exceptions.py` -- six `TypeError` subclasses,
organised by perimeter (`ModelSpecInitializationError`,
`OptionsInitializationError`, `EstimationCallError`,
`InferenceCallError`, `SimulationCallError`,
`DiagnosticsCallError`), all inheriting from a common
`SkillmodelsInputError` for callers that want to catch the whole
hierarchy in one go.
* `src/skillmodels/_beartype_conf.py` --
`_conf(exc)` builds a `BeartypeConf` with
`violation_param_type=exc`, `strategy=BeartypeStrategy.On` (full
O(n) container scan; entry points are called rarely compared to
the JIT-compiled hot path each one kicks off), and
`is_pep484_tower=True`. `beartype_init(conf)` is a class decorator
that wraps only `__init__` so non-public method-level annotation
drift on instance methods does not surface at construction time.
Decoration sites
----------------
* `@beartype_init(MODEL_SPEC_CONF)` on `FactorSpec`, `AnchoringSpec`,
`ModelSpec`, `Normalizations`.
* `@beartype_init(OPTIONS_CONF)` on `CHSEstimationOptions`,
`AFEstimationOptions`, `AMNEstimationOptions`.
* `@beartype(conf=ESTIMATION_CONF)` on `get_maximization_inputs`,
`get_filtered_states`, `estimate_af`, `estimate_amn`,
`get_af_posterior_states`, `get_amn_posterior_states`.
* `@beartype(conf=INFERENCE_CONF)` on
`compute_af_standard_errors`, `compute_amn_standard_errors`.
* `@beartype(conf=SIMULATION_CONF)` on `simulate_dataset`,
`simulate_policy_effect`.
* `@beartype(conf=DIAGNOSTICS_CONF)` on
`decompose_measurement_variance`,
`summarize_measurement_reliability`,
`plot_residual_boxplots`, `plot_likelihood_contributions`,
`create_state_ranges`, `plot_correlation_heatmap`,
`get_measurements_corr`, `get_quasi_scores_corr`,
`get_scores_corr`, `univariate_densities`,
`bivariate_density_contours`, `bivariate_density_surfaces`,
`combine_distribution_plots`, `get_transition_plots`,
`combine_transition_plots`.
Side effects of perimeter-only validation
-----------------------------------------
* `_check_measurements`'s type-shape arm in
`common/check_model.py` is now dead code: the
`tuple[tuple[str, ...], ...]` annotation on
`FactorSpec.measurements` makes beartype reject every malformed
measurement structure at construction time. The function is kept
(the report aggregator might still surface non-type issues a
beartype container scan can't see), but the corresponding two
tests in `tests/test_check_model.py` are rewritten to assert
`ModelSpecInitializationError` at `FactorSpec(...)` time
instead of asserting a soft message in the aggregator output.
* `tests/test_af_jaxopt_backend.py::test_optimizer_backend_rejects_unknown_value`
now asserts `OptionsInitializationError` from beartype's
`Literal` check (which fires before
`AFEstimationOptions.__post_init__`'s manual ValueError).
* `tests/test_amn_plot_harmonization.py::test_get_filtered_states_rejects_both_af_and_amn_results`
now asserts `EstimationCallError`; the prior body-level
`"only one of"` ValueError is still in place but is unreachable
from this fixture, which passes the same AMN result to both
parameters and so trips the type guard on `af_result` first.
* `chs/filtered_states.py` imports `AFEstimationResult` /
`AMNEstimationResult` at runtime rather than under
`TYPE_CHECKING` so beartype can resolve the annotation; ruff's
TC003 autofix had been silently unforwarding the string forward
refs.
Verification
------------
* `pixi run -e tests-cpu pytest tests/ -q -k "not long_running"` --
529 passed, 1 deselected (same count as before this commit; no
regressions).
* `pixi run ty` -- clean.
* `prek run --all-files` -- clean.
Out of scope (follow-up PRs)
----------------------------
* Whole-package activation via `beartype.claw.beartype_package("skillmodels")`
in `tests/conftest.py`. That probe would surface internal-helper
annotation drift the same way pylcm's part-3 PR will, and is left
for a separate review.
* AGENTS-level conventions documentation. The perimeter is in place;
the rule for where to put the next decorator is "wherever the
signature is documented to the user" -- to be expanded once the
pattern has settled.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous fix (2206212) tried to address this by enabling x64 before `import jaxopt`. The reproducer on sonny (jax 0.10.0, jaxopt 0.8.5) shows that's insufficient: even with x64 on before any jaxopt code runs and float64 inputs throughout, `LBFGSB.update`'s jit-compiled `jnp.argsort` still emits an s32 reduction accumulator while the surrounding scatter operand is built as s64. XLA's `permutation_sort_simplifier` HLO pass rejects the mismatch with `INVALID_ARGUMENT: Reduction function's accumulator shape at index 0 differs from the init_value shape: s32[] vs s64[]`. Disabling just `permutation_sort_simplifier` via `XLA_FLAGS` fixes the crash, keeps every other XLA optimisation intact, and is a no-op on JAX < 0.10 (the pass doesn't exist there). The flag must be set before `import jax` because XLA reads `XLA_FLAGS` once at backend init. Applied in two places: - `skillmodels/__init__.py`: the primary entry point. Appends to any pre-existing `XLA_FLAGS` so user flags aren't clobbered. - `skillmodels/af/jaxopt_backend.py`: belt-and-suspenders for direct module imports that skip the package init. The previous comment block tying the bug to "x64 off at import time" was wrong about the root cause; replaced with the actual XLA pass explanation. The `JAX_ENABLE_X64=1` setting is retained because the AF pipeline assumes float64 throughout. Verified end-to-end on sonny (jax 0.10): minimum jaxopt repro that previously crashed now succeeds. Local jaxopt backend tests (7) and full local suite (485 tests, jax 0.9) still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The perimeter decorators (PR #90 / commit d7d29ea) covered only public entry points. The claw activation in `tests/conftest.py` extends type enforcement to every annotated callable in the package during the test run, catching annotation drift on internal helpers that would otherwise silently flow through. Configuration: - `is_pep484_tower=True` to mirror the perimeter conf so `int` satisfies `float`-typed parameters. - `claw_skip_package_names=("skillmodels.chs.qr",)` because JAX's `@custom_jvp` decorator stores the secondary `.defjvp` setter on the wrapped object; beartype.claw's wrapping strips it. Annotation drift fixes (sources of truth, not type-system theater): * `FixedConstraintWithValue` moved to its own module `common/fixed_constraint.py`. `transition_functions.py` previously imported it under `if TYPE_CHECKING:` to avoid a circular import with `constraints.py`; beartype.claw can't resolve those forward refs at decoration time. The leaf type now lives where both modules can pull it without a cycle. * Same TYPE_CHECKING removal for `ModelSpec` in `af/types.py` and `amn/types.py`. * JAX-traced helpers (`_at_node`, `_chain_one_component`, `_compute_investment`, kalman / likelihood entry points, transition pipeline plumbing): annotations relaxed to accept `Array | np.ndarray` / `float | Array` / `int | np.integer` where the runtime contract is genuinely mixed. JAX vmap traces ints as `BatchTracer`; numpy and jax arrays interconvert freely through these signatures. * `MixtureComponent`, `ConditionalDistribution`, `ChainLink` dataclass fields: accept `np.ndarray` alongside `jax.Array` since estimators fill them with both. * `TransitionInfo.param_names`: now built with explicit `tuple()` conversion at the boundary in `process_model.py` so the `MappingProxyType[str, tuple[str, ...]]` annotation actually holds. * `get_has_endogenous_factors` now casts the pandas `.any()` result to `bool` instead of relying on a `# ty: ignore`. * `NDArray[np.floating[Any]]` (which beartype doesn't accept as a dtype hint) replaced with `NDArray[np.float64]` in `chs/process_debug_data` and `common/visualize_factor_distributions`. * `NDArray[np.floating]` in `common/simulate_data` widened to `NDArray[np.floating] | Array`. * Internal duck-typed validators (`_check_anchoring`, `_process_factors`) re-typed as `Any` with `# noqa: ANN401` and an inline comment; they exist precisely to take partially-built objects. * `_aug_periods_from_period`: `dict[int, int]` → `Mapping[int, int]` (production passes a `MappingProxyType`). Tests: - `tests/test_check_model.py`: re-add `# ty: ignore` on the two `FactorSpec(measurements=...)` calls that intentionally pass `list` where `tuple` is required; they verify the beartype perimeter catches the shape error. - `tests/test_transition_functions.py::test_constant`: rewritten to pass real JAX arrays now that the claw type-checks `constant`. - Stale `# ty: ignore[invalid-argument-type]` directives stripped from `test_check_model.py`, `test_correlation_heatmap.py`, `test_process_debug_data.py` (and one in `simulate_data.py`) — the annotations they were silencing have been relaxed. Verification: 495 tests pass with claw enabled (`pixi run -e tests-cpu tests`); `pixi run ty` clean; `prek run --all-files` clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CI run 25868871644 surfaced 14 failures in `tests/test_af_inference.py` that local runs missed (the file is unmarked but the local sweep excluded similarly-shaped tests via `-k "not long_running and not end_to_end"`). All failures had the same root cause as the bulk of the prior claw-activation diff: AF transition / chain helpers were annotated with strict `jax.Array` parameters, but the runtime path through `af.inference` constructs `prev_distribution`, chain link inputs, and chol/mean blocks from `np.ndarray`. Beartype rejected them under the now-active claw. Relaxed sites: - `af.likelihood.af_per_obs_loglike_transition` / `af_loglike_transition` / `_integrate_transition_chain`: `prev_distribution` widened from `dict[str, Array]` to `Mapping[str, Array | np.ndarray]`. `Mapping` (covariant) lets callers pass `dict[str, Array]` without an explicit cast. - `af.likelihood._map_over_obs`: `*xs: Array` → `*xs: Array | np.ndarray`. - `af.likelihood._integrate_transition_single_obs`: `obs_cond_weights`, `obs_cond_means`, `cond_chols` widened to `Array | np.ndarray`. - `af.likelihood._rebuild_chain_at_period`: `initial_mean`, `initial_chol` widened to `Array | np.ndarray`. Internal `theta` bound through `jnp.asarray(...)` so downstream `_compute_investment` still sees an `Array`. - `af.likelihood._compute_investment`: `inv_eq_params`, `inv_sds` widened (covered earlier; ty-narrowing followed naturally). Verification: - `pixi run -e tests-cpu pytest tests/test_af_inference.py` — 14 / 14 pass (was 1 failed + 13 errors). - `pixi run ty` clean. - `prek run --all-files` clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously stopped on `max|projected_grad| < tol` only. That's not how scipy_lbfgsb stops in practice: scipy stops on EITHER `gtol_abs` OR a relative function-value decrease `ftol_rel`. On the skill-formation likelihoods used in the Monte Carlo benchmarks, the loglikelihood goes locally flat before the gradient does, so scipy's ftol channel fires ~100% of the time and gradient-norm < 1e-5 is essentially never the actual stopping criterion in production. Without the ftol channel, jaxopt would grind down the gradient while scipy declared success at the same point — a fake apples-to-oranges asymmetry that made the AF-jaxopt vs AF-optimagic timing comparison meaningless. Implementation: drop jaxopt's built-in `run()` and drive the solver through an explicit `init_state` + `update` loop with the same gtol-OR-ftol stop. Default values now match scipy_lbfgsb's defaults (`gtol_abs=1e-5`, `ftol_rel=2.22e-9`, `maxiter=15000`). The wrapper accepts both canonical scipy keys (`convergence_gtol_abs`, `convergence_ftol_rel`, `stopping_maxiter`) and the historical jaxopt keys (`tol`, `maxiter`) so the same `optimizer_options` dict works for either backend. This makes the two LBFGSB implementations stop on byte-identical rules; the only remaining differences are internal (line search, step acceptance, curvature-pair filtering) — which is the comparison that's actually interesting. Verified: 7 jaxopt_backend tests still pass; ty clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`AFEstimationOptions.__post_init__` runs `ensure_containers_are_immutable`
on the user's `optimizer_options` dict, which recursively wraps every
nested dict in `MappingProxyType`. The two `om.minimize(..., **dict(
af_options.optimizer_options))` call sites in `af/initial_period.py`
and `af/transition_period.py` only unwrap the outer layer; an
`algo_options={"convergence_gtol_abs": 1e-5, ...}` user dict therefore
arrives at `om.minimize` as `algo_options=MappingProxyType(...)` and
trips optimagic's `isinstance(algo_options, dict)` check with
`ValueError: algo_options must be a dictionary or None`.
Surfaced on the Marvin 3-way Monte Carlo run where AF optimagic
failed 100% of sims with that exact ValueError; AF jaxopt and CHS
were unaffected (jaxopt's wrapper consumes simple top-level keys;
CHS's `om.minimize` call passes a plain dict directly).
Fix: add `to_plain_dict` in `common/types.py` (inverse of
`_make_immutable` — recursively unwraps MappingProxyType/tuple/
frozenset back to dict/list/set) and use it at both AF optimagic
call sites. The jaxopt path is unchanged because its `options` are
flat scalars, not nested.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
estimate_af previously called jax.clear_caches() + gc.collect() + a full device→host materialisation of every array in the result at the end of every call. The clear-caches step was added to avoid GPU OOMs during the host-staging copy of the result, but the side effect is that every subsequent estimate_af call recompiles every per-period likelihood, gradient, and jaxopt update step from scratch -- the JIT compilation cache cannot survive across calls. On a translog n=500 h=10k MC sweep this caused a 3-10x per-sim slowdown vs the prior warm-cache steady state (~120s → ~950s). Move the materialisation + cache clear into a new method AFEstimationResult.to_numpy() that callers invoke explicitly when they actually need host residency (pickling, plotting, cross-process transfer). estimate_af itself now returns on-device arrays and leaves the JIT cache intact for the next call. Callers that pickle the full result -- skane-struct-bw's task_estimation_af -- now call result.to_numpy() before pickling. sim sweeps that only read params DataFrames (already host-resident) need no change and now benefit from JIT cache reuse across sims. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
af/subpackage implementing the Antweiler & Freyberger (2025) estimator as an alternative to the CHS Kalman filter.ModelSpecinterface — users switch estimator by callingestimate_af()instead ofget_maximization_inputs()+om.maximize().log_ces/linear/translogtransitions, endogenous factors via explicit investment equation.get_filtered_states()interface: passaf_result=for AF posterior states, omit for CHS filtered states.beartype.clawactivation in tests, and JAX 0.10 / cuda13 workarounds. See sections below.AF estimator (
src/skillmodels/af/)estimate_af(model_spec, data, af_options, start_params)→AFEstimationResult.ProbabilityConstraintforlog_cesgammas, satisfied at start values.I = β₀ + β₁θ + β₂Y + σ_I εfor endogenous factors.start_paramssupport: user-supplied starting values override heuristic defaults.get_filtered_states(model_spec, data, params, af_result=result)computes quadrature-based posterior means per individual / period.compute_af_standard_errors.Optimizer backends
AFEstimationOptions.optimizer_backendchooses how each period's MLE is solved:"optimagic"(default fallback):om.minimize(algorithm="scipy_lbfgsb", ...). Required when user equality / probability constraints exist (jaxopt can't fold those)."jaxopt":jaxopt.LBFGSB, run directly on device. Avoids the host↔device transfer optimagic incurs once per likelihood call."auto": pick"jaxopt"iff a JAX GPU is visible and the model is jaxopt-compatible (nolog_ces*transitions, no user-supplied constraints); otherwise"optimagic".The jaxopt wrapper now matches scipy_lbfgsb's gtol OR ftol stopping rule (drives the solver through an explicit
init_state+updateloop and checks‖projected_grad‖∞ < gtol_absOR(f_k − f_{k+1}) / max(|f_k|, |f_{k+1}|, 1) < ftol_relafter each step). The sameoptimizer_optionskeys (convergence_gtol_abs,convergence_ftol_rel,stopping_maxiter) work for either backend, so per-sim Monte Carlo timing benchmarks are byte-identical apart from internal LBFGSB mechanics.Beartype perimeter on user-facing API
Mirrors the pattern in pylcm PR #355: a per-exception
BeartypeConfplus abeartype_initclass decorator routes parameter-type violations at every documented entry point through a skillmodels-specific exception class, so callers can write narrowly-scopedexceptclauses against a stable hierarchy rather than catching beartype's framework exception.Exceptions (
src/skillmodels/exceptions.py)Six
TypeErrorsubclasses of a commonSkillmodelsInputError, organised by perimeter:ModelSpecInitializationError—FactorSpec,AnchoringSpec,ModelSpec,NormalizationsOptionsInitializationError—CHSEstimationOptions,AFEstimationOptions,AMNEstimationOptionsEstimationCallError—get_maximization_inputs,get_filtered_states,estimate_af,estimate_amn,get_af_posterior_states,get_amn_posterior_statesInferenceCallError—compute_af_standard_errors,compute_amn_standard_errorsSimulationCallError—simulate_dataset,simulate_policy_effectDiagnosticsCallError—decompose_measurement_variance,summarize_measurement_reliability,plot_residual_boxplots,plot_likelihood_contributions,create_state_ranges,plot_correlation_heatmap,get_measurements_corr,get_quasi_scores_corr,get_scores_corr,univariate_densities,bivariate_density_contours,bivariate_density_surfaces,combine_distribution_plots,get_transition_plots,combine_transition_plotsDecorator + config (
src/skillmodels/_beartype_conf.py)_conf(exc)—BeartypeConfwithviolation_param_type=exc,strategy=BeartypeStrategy.On(full O(n) container scan),is_pep484_tower=True.beartype_init(conf)— class decorator that wraps only__init__. Avoids surfacing non-public annotation drift on instance methods that has nothing to do with parameter validation at construction time.Whole-package
beartype.clawactivation in tests (tests/conftest.py)beartype.claw.beartype_package("skillmodels", conf=...)turns annotation-drift on internal helpers intoBeartypeCallHintParamViolationduring the test run.skillmodels.chs.qris excluded because JAX's@custom_jvpdecorator's secondary.defjvpattribute doesn't survive beartype's wrap. Activating the claw surfaced ~80 internal annotation drifts (Array / np.ndarray / int / Mapping / dict-vs-MappingProxyType / TYPE_CHECKING-only forward refs), all fixed in the commits that follow.Side effects
_check_measurements's type-shape arm incommon/check_model.pyis now dead code: thetuple[tuple[str, ...], ...]annotation onFactorSpec.measurementsmakes beartype reject every malformed measurement structure at construction time. The function is kept for non-type issues the container scan can't see; the two tests that previously asserted soft errors are rewritten to assertModelSpecInitializationErroratFactorSpec(...)time.tests/test_af_jaxopt_backend.py::test_optimizer_backend_rejects_unknown_valuenow assertsOptionsInitializationErrorfrom beartype'sLiteralcheck (which fires beforeAFEstimationOptions.__post_init__'s manualValueError).tests/test_amn_plot_harmonization.py::test_get_filtered_states_rejects_both_af_and_amn_resultsnow assertsEstimationCallError.chs/filtered_states.pyimportsAFEstimationResult/AMNEstimationResultat runtime rather than underTYPE_CHECKINGso beartype can resolve the annotation; ruff's TC003 autofix had been silently unforwarding the string forward refs.FixedConstraintWithValuemoved to its own modulecommon/fixed_constraint.pyto break a circular import (constraints.pyimportstransition_functions.py; the leaf type now lives where both modules can pull it without a cycle).JAX 0.10 / cuda13 workarounds
XLA_FLAGS=--xla_disable_hlo_passes=permutation_sort_simplifieris set at package import to bypass a JAX 0.10 XLA pass that mis-lowers theargsortinsidejaxopt.LBFGSB.update(emits an s32 reduction accumulator into an s64 scatter operand). No-op on JAX < 0.10.JAX_ENABLE_X64=1set at package import time so transitiveimport jaxoptsees x64 as the default integer width.Test plan
pixi run -e tests-cpu pytest tests/ -q -k "not long_running"— all green with beartype.claw enabledpixi run ty— cleanprek run --all-files— cleanpytest -m long_running— MODEL2 AF vs CHS comparison (both estimators optimised from same naive start values)🤖 Generated with Claude Code