Merged

77 commits
71e9d59
Refactor model configuration to use frozen dataclasses instead of dicts
hmgaudecker Jan 8, 2026
7766672
Use frozendict for immutable dict fields in dataclasses
hmgaudecker Jan 8, 2026
562cfaa
Ignore ty false positive.
hmgaudecker Jan 8, 2026
4d2c271
Return ProcessedModel dataclass from process_model() instead of dict
hmgaudecker Jan 8, 2026
264016c
Rename FactorEndogenousInfo to FactorInfo
hmgaudecker Jan 9, 2026
3bb334b
Ensure complete typing in src/skillmodels.
hmgaudecker Jan 9, 2026
e2d687a
Require Python 3.14 before fixing type annotations.
hmgaudecker Jan 9, 2026
921c612
Move TESTS_DIR -> TEST_DATA_DIR, which points to a subdirectory of sr…
hmgaudecker Jan 9, 2026
dce66ad
Fix more linting issues.
hmgaudecker Jan 9, 2026
e0f59d0
Make ruff rules much stricter.
hmgaudecker Jan 9, 2026
b15372a
Further tighten type annotations. Fix a missing run-time annotation o…
hmgaudecker Jan 11, 2026
74b2f18
Move unsafe_fixes = false from .pre-commit config to pyproject.
hmgaudecker Jan 11, 2026
04aaa10
Fix query in data simulation.
hmgaudecker Jan 12, 2026
493ead0
More fixes to imports, add test.
hmgaudecker Jan 12, 2026
51209ad
Use more tuples in place of lists to prevent errors.
hmgaudecker Jan 12, 2026
21da176
Fix typing.
hmgaudecker Jan 12, 2026
5925303
Dataclasses for user input.
hmgaudecker Jan 19, 2026
9590c46
Simplify.
hmgaudecker Jan 21, 2026
69f6eab
Use modern rng everywhere.
hmgaudecker Jan 28, 2026
f8dac75
Update CLAUDE.md
hmgaudecker Jan 28, 2026
846534b
Get rid of if TYPE_CHECKING blocks
hmgaudecker Jan 28, 2026
97d84b8
Update hooks and clean up
hmgaudecker Jan 28, 2026
b8a9fbc
Call by name throughout.
hmgaudecker Jan 29, 2026
5012a8b
Autogenerated docs, harmonised hooks / project configuration.
hmgaudecker Jan 29, 2026
7e7784e
Get rid of model_dict.
hmgaudecker Jan 29, 2026
6934f79
Replace yaml model specifications by ModelSpec-s.
hmgaudecker Jan 30, 2026
7cf471a
Next shot at fixing pickling.
hmgaudecker Feb 1, 2026
4380af8
Add improved output formatting.
hmgaudecker Feb 4, 2026
f7a8aa3
Add variance decompositions.
hmgaudecker Feb 4, 2026
348b6c6
Fix variance decomposition to use aug_period instead of period.
hmgaudecker Feb 4, 2026
7662600
Fix diagnostic_plots double-processing of debug data.
hmgaudecker Feb 4, 2026
7c9105c
Review comments 1: single EstimationOptions; improved dataclasses wit…
hmgaudecker Feb 15, 2026
fcb459e
Review comments on docs; tighten typing of Mapping to MappingProxyTyp…
hmgaudecker Feb 16, 2026
8e4ac7d
Back to general implementation of ensure_containers_are_immutable.
hmgaudecker Feb 16, 2026
eda32cd
Merge branch 'strong-typing' into output-formatting
hmgaudecker Feb 16, 2026
a148495
Small fixes.
hmgaudecker Feb 16, 2026
ea067ce
Latest boilerplate version.
hmgaudecker Mar 4, 2026
8207186
Make optimagic a package dependency.
hmgaudecker Mar 5, 2026
8e860d3
Merge branch 'main' into output-formatting
hmgaudecker Mar 9, 2026
f3ffdec
Latest boilerplate.
hmgaudecker Mar 9, 2026
b5f907f
Make jupyterbook and pytask required deps.
hmgaudecker Mar 10, 2026
77ad629
Use dags 0.5.0 dev branch.
hmgaudecker Mar 10, 2026
8b25ba6
Update CLAUDE.md with complete API docs and actual application usage
hmgaudecker Mar 10, 2026
0810e3c
Increase test coverage to close to 100%.
hmgaudecker Mar 10, 2026
c29505e
Fix bugs caught by review agent.
hmgaudecker Mar 10, 2026
24bfdac
Call ty environment 'type-checking'.
hmgaudecker Mar 11, 2026
baf2538
Merge branch 'output-formatting' of github.com:OpenSourceEconomics/sk…
hmgaudecker Mar 11, 2026
da0e140
Require dags 0.5, bump pandas, numpy and hooks, fix resulting issues.
hmgaudecker Mar 11, 2026
b2b68d6
Remove suppression of deprecation warning, which is unnecessary in Pa…
hmgaudecker Mar 11, 2026
fc3dbdd
Make ruff stricter.
hmgaudecker Mar 11, 2026
2f1e14d
Update boilerplate.
hmgaudecker Mar 11, 2026
8f0fc9e
Update ty and remove a couple of unused-ignores.
hmgaudecker Mar 14, 2026
4a836ae
Rename _sel to select_by_loc; fix stale docstring
hmgaudecker Mar 14, 2026
67d7b44
Allow passing states directly to get_transition_plots (health-cogniti…
hmgaudecker Mar 15, 2026
c9519cf
Use optimagic constraints directly. Need a wrapper because we are set…
hmgaudecker Mar 14, 2026
1df0d98
Fix FixedConstraintWithValue.loc type and test expectations
hmgaudecker Mar 14, 2026
2be39aa
Validate not-None input directly.
hmgaudecker Mar 14, 2026
ee4869a
Simplify API by setting .selectors via .loc in __post_init__
hmgaudecker Mar 15, 2026
a32a177
Require dags 0.5.1
hmgaudecker Mar 15, 2026
ede0017
First shot at fixing #36.
hmgaudecker Mar 15, 2026
e3f91d0
Bug fix in kalman_filters.py — added [:n_latent] slice to s_in and c_…
hmgaudecker Mar 15, 2026
0c6cc5a
Fix ty errors and add docs on linear predict.
hmgaudecker Mar 16, 2026
9a7e401
Fix aug_period/period bug in transition equation visualization
hmgaudecker Mar 16, 2026
d23f80c
Remove aug_periods from public-facing functions.
hmgaudecker Mar 16, 2026
c1256f9
Merge branch 'output-formatting' into om.Constraints
hmgaudecker Mar 16, 2026
705488a
Merge branch 'om.Constraints' into linear-predict
hmgaudecker Mar 16, 2026
447719a
Previous commit was too greedy.
hmgaudecker Mar 16, 2026
350c3b9
Merge branch 'output-formatting' into om.Constraints
hmgaudecker Mar 16, 2026
6388c0f
Merge branch 'om.Constraints' into linear-predict
hmgaudecker Mar 16, 2026
220a9b0
Update docs based on benchmark results.
hmgaudecker Mar 18, 2026
94e8704
Use pixi 0.66 in CI; prek autoupdate.
hmgaudecker Mar 18, 2026
b1606b8
CHORE: Approximation-tolerant floating point comparisons, remove pand…
hmgaudecker Mar 18, 2026
980ada3
Get rid of unnecessary block-comments.
hmgaudecker Mar 18, 2026
c0fd717
Merge branch 'output-formatting' into om.Constraints
hmgaudecker Mar 18, 2026
7eb35a9
Merge branch 'om.Constraints' into linear-predict
hmgaudecker Mar 18, 2026
d7ba962
Idiomatic Jax for linear filter, though no speed difference. Added a …
hmgaudecker Mar 18, 2026
5c63203
Merge branch 'main' into linear-predict
hmgaudecker Mar 18, 2026
135 changes: 135 additions & 0 deletions docs/explanations/linear_predict.md
@@ -0,0 +1,135 @@
# Linear predict optimization

## When the linear predict is used

At model setup time, `is_all_linear` checks whether every latent factor's transition
function name belongs to `{"linear", "constant"}`. This is an all-or-nothing decision:
if even one factor uses a nonlinear transition (e.g. `translog`), the entire model falls
back to the unscented predict.

The check happens in `get_maximization_inputs`, where the predict function is selected
via `functools.partial`. When the linear path is chosen, extra keyword arguments
(`latent_factors`, `constant_factor_indices`, `n_all_factors`) are bound at setup time so
the predict function has the same call signature as the unscented variant.
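A minimal sketch of the check, mirroring `is_all_linear` as added in `kalman_filters.py` in this PR (the factor names in the example are hypothetical):

```python
LINEAR_FUNCTION_NAMES = frozenset({"linear", "constant"})


def is_all_linear(function_names: dict[str, str]) -> bool:
    """True if every factor's transition function is linear or constant."""
    return all(name in LINEAR_FUNCTION_NAMES for name in function_names.values())


# One nonlinear transition is enough to force the unscented fallback:
print(is_all_linear({"cog": "linear", "health": "constant"}))  # True
print(is_all_linear({"cog": "linear", "health": "translog"}))  # False
```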

## Why it is faster and uses less memory

The unscented predict generates $2n + 1$ sigma points (where $n$ is the number of latent
factors), transforms each one through the transition function, then recovers predicted
means and covariances from weighted statistics. Its QR decomposition operates on a matrix
of shape $(3n + 1) \times n$: the $2n + 1$ weighted deviation rows plus $n$ rows for the
shock standard deviations.

The linear predict skips sigma-point generation entirely. Because the transition is
linear, the predicted mean is just a matrix-vector product, and the predicted covariance
follows from the standard linear Gaussian formula. Its QR decomposition operates on a
$(2n) \times n$ matrix: $n$ rows from the propagated Cholesky factor and $n$ rows for the
shocks. The reduction from $3n + 1$ to $2n$ rows speeds up the QR step and removes all
sigma-point overhead.

The memory savings can be more important than the speed gains. The unscented path
materialises $2n + 1$ sigma points for every observation and mixture component, and
JAX's automatic differentiation retains intermediate buffers for the backward pass. The
linear path replaces all of this with a single matrix multiply whose memory footprint
scales with $n^2$ rather than with the number of sigma points times the number of
observations. On memory-constrained GPUs this can be the difference between fitting the
model and running out of memory.

## Building F and c

The linear predict assembles a transition matrix $F$ of shape
$(n_\text{latent}, n_\text{all})$ and a constant vector $c$ of length $n_\text{latent}$
from the `trans_coeffs` dictionary. Here $n_\text{all}$ includes both latent and observed
factors.

For each latent factor $i$:

- **Linear factor**: `trans_coeffs[factor]` is a 1-d array whose last element is the
intercept and whose preceding elements are the coefficients on all factors (latent and
observed). Row $i$ of $F$ is set to `coeffs[:-1]` and $c_i$ is set to `coeffs[-1]`.
- **Constant factor**: row $i$ of $F$ is the unit vector $e_i$ (identity row) and
$c_i = 0$, so the factor value is simply carried forward.

The implementation uses a stack-then-mask approach: all coefficient arrays are stacked
into a single matrix (with zero-padded rows for constant factors), an identity matrix
provides the constant-factor rows, and `jnp.where` selects between them using a boolean
mask. This avoids per-element `.at[i].set()` calls and conditional branching, producing
a cleaner trace for JAX's compiler.

Three construction strategies were benchmarked (loop with conditional `.at[i].set()`,
stack-then-mask with `jnp.where`, and index-scatter with pre-separated sub-matrices).
All three produced identical XLA graphs and showed no meaningful runtime difference
(~6.3–6.7 ms per call on CPU, 4-factor model, 5000 observations), confirming that the
construction is fully resolved at trace time. The stack-then-mask variant was kept for
its cleaner, more idiomatic JAX style.
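The stack-then-mask construction can be sketched in NumPy (the shipped code uses `jax.numpy`, whose API is identical for these calls; the two-factor model in the usage example is hypothetical):

```python
import numpy as np


def build_f_and_c(latent_factors, constant_factor_indices, n_all_factors, trans_coeffs):
    """Sketch of the stack-then-mask build of F and c."""
    n_latent = len(latent_factors)
    identity = np.eye(n_latent, n_all_factors)
    # Stack coefficient rows; constant factors get a zero-padded placeholder row.
    all_coeffs = np.stack(
        [
            trans_coeffs[f]
            if i not in constant_factor_indices
            else np.zeros(n_all_factors + 1)
            for i, f in enumerate(latent_factors)
        ]
    )
    f_from_coeffs = all_coeffs[:, :-1]  # coefficients on all factors
    c_from_coeffs = all_coeffs[:, -1]  # intercepts
    is_constant = np.array([i in constant_factor_indices for i in range(n_latent)])
    # Select identity rows for constant factors, coefficient rows otherwise.
    f_mat = np.where(is_constant[:, None], identity, f_from_coeffs)
    c_vec = np.where(is_constant, 0.0, c_from_coeffs)
    return f_mat, c_vec


# Two latent factors, no observed factors: "cog" is linear, "health" constant.
f, c = build_f_and_c(
    ("cog", "health"),
    frozenset({1}),
    2,
    {"cog": np.array([0.8, 0.3, 0.1])},  # coeffs on (cog, health), then intercept
)
# f == [[0.8, 0.3], [0.0, 1.0]], c == [0.1, 0.0]
```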

## Mean prediction

The mean prediction incorporates anchoring, which rescales factors to a common metric
across periods. Let $s^{\text{in}}$ and $c^{\text{in}}$ be the input-period scaling
factors and constants, and $s^{\text{out}}$ and $c^{\text{out}}$ the output-period
counterparts. The steps are:

1. **Anchor** the input states: $x^a = x \odot s^{\text{in}} + c^{\text{in}}$.
2. **Concatenate** observed factors to form the full state vector
$\tilde{x} = [x^a, x^{\text{obs}}]$.
3. **Apply the linear transition**: $y^a = \tilde{x}\, F^\top + c$.
4. **Un-anchor** to get the predicted states:
$\hat{x} = (y^a - c^{\text{out}}) \oslash s^{\text{out}}$.
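The four steps translate directly into array operations. A NumPy sketch for illustration (the real implementation batches over observations and mixture components with `jax.numpy`):

```python
import numpy as np


def linear_mean_predict(states, observed, f_mat, c_vec, s_in, c_in, s_out, c_out):
    """Sketch of the mean-prediction steps; all arguments are arrays."""
    anchored = states * s_in + c_in                       # 1. anchor input states
    full = np.concatenate([anchored, observed], axis=-1)  # 2. append observed factors
    transitioned = full @ f_mat.T + c_vec                 # 3. linear transition
    return (transitioned - c_out) / s_out                 # 4. un-anchor
```

With identity anchoring ($s = 1$, $c = 0$) this reduces to the textbook linear predict $\hat{x} = F\tilde{x} + c$.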

## Covariance prediction (square-root form)

skillmodels maintains covariances in square-root (upper Cholesky) form throughout. Let
$R$ denote the current upper Cholesky factor so that $P = R^\top R$. The linear predict
propagates $R$ as follows.

Define the effective transition matrix

$$
G = \operatorname{diag}(1 / s^{\text{out}})\; F_{\text{latent}}\;
\operatorname{diag}(s^{\text{in}})
$$

where $F_{\text{latent}}$ is the first $n_\text{latent}$ columns of $F$ (the columns
corresponding to latent factors). $G$ folds the anchoring scales into the transition so
that the covariance update works directly in the un-anchored (internal) scale.

The predicted covariance satisfies

$$
\hat{P} = G\, P\, G^\top + Q
$$

where $Q = \operatorname{diag}(\sigma / s^{\text{out}})^2$ and $\sigma$ is the vector of
shock standard deviations. In square-root form, the upper Cholesky factor $\hat{R}$ of
$\hat{P}$ is obtained via a single QR decomposition of the stacked matrix

$$
S = \begin{bmatrix} R\, G^\top \\ \operatorname{diag}(\sigma / s^{\text{out}})
\end{bmatrix}
$$

which has shape $(2n) \times n$. The upper-triangular $R$-factor of $S$ (its first $n$
rows) gives $\hat{R}$.
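A NumPy sketch of this update for a single covariance (the shipped version applies a batched QR over observations and mixture components):

```python
import numpy as np


def linear_cov_predict(upper_chol, f_latent, s_in, s_out, shock_sds):
    """Propagate an upper Cholesky factor R (with P = R' R) through G."""
    n = f_latent.shape[0]
    # G = diag(1/s_out) @ F_latent @ diag(s_in), folded into one expression
    g_mat = (f_latent * s_in) / s_out[:, np.newaxis]
    # Stack [R @ G' ; diag(shock_sds / s_out)] and take the QR R-factor
    stacked = np.vstack([upper_chol @ g_mat.T, np.diag(shock_sds / s_out)])
    return np.linalg.qr(stacked, mode="r")[:n]
```

Up to row signs, which do not affect $\hat{R}^\top \hat{R}$, the result satisfies $\hat{R}^\top \hat{R} = G P G^\top + Q$.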

## Observed factors

Observed factors (e.g. investment measures whose values are known from data) appear as
columns in $F$ and therefore influence the predicted mean through the matrix-vector
product. However, they carry no uncertainty: their columns are excluded from the
covariance propagation. This is why $G$ uses only the first $n_\text{latent}$ columns of
$F$ rather than the full matrix.

## Practical impact

Benchmarks on a 4-factor linear model (`health-cognition`,
`no_feedback_to_investments_linear`, 8 GiB GPU) show a modest ~6 % speed-up on GPU
(8.4 vs 8.9 s per optimizer iteration) and negligible difference on CPU. The speed gain
is small because with only 4 latent factors the unscented transform generates just 9
sigma points — a trivially cheap operation on modern hardware.

The memory reduction is the more significant benefit. Under the same conditions the
unscented path ran out of GPU memory when only ~5 GiB was free, while the linear path
ran without issues. For models with more latent factors both advantages grow: the
sigma-point count scales as $2n + 1$ and the QR matrix shrinks from $(3n + 1) \times n$
to $(2n) \times n$.
1 change: 1 addition & 0 deletions docs/myst.yml
@@ -40,6 +40,7 @@ project:
children:
- file: explanations/names_and_concepts.md
- file: explanations/notes_on_factor_scales.md
- file: explanations/linear_predict.md
- title: Reference Guides
children:
- file: reference_guides/transition_functions.md
2 changes: 1 addition & 1 deletion pixi.lock


145 changes: 144 additions & 1 deletion src/skillmodels/kalman_filters.py
@@ -1,13 +1,21 @@
"""Kalman filter operations for state estimation using the square-root form."""

from collections.abc import Callable
from collections.abc import Callable, Mapping

import jax
import jax.numpy as jnp
from jax import Array

from skillmodels.qr import qr_gpu

LINEAR_FUNCTION_NAMES = frozenset({"linear", "constant"})


def is_all_linear(function_names: Mapping[str, str]) -> bool:
"""Return True if every factor uses a linear or constant transition function."""
return all(name in LINEAR_FUNCTION_NAMES for name in function_names.values())


array_qr_jax = (
jax.vmap(jax.vmap(qr_gpu))
if jax.default_backend() == "gpu"
@@ -228,6 +236,141 @@ def kalman_predict(
return predicted_states, predicted_covs


def linear_kalman_predict(
transition_func: Callable | None, # noqa: ARG001
states: Array,
upper_chols: Array,
sigma_scaling_factor: float, # noqa: ARG001
sigma_weights: Array, # noqa: ARG001
trans_coeffs: dict[str, Array],
shock_sds: Array,
anchoring_scaling_factors: Array,
anchoring_constants: Array,
observed_factors: Array,
*,
latent_factors: tuple[str, ...],
constant_factor_indices: frozenset[int],
n_all_factors: int,
) -> tuple[Array, Array]:
"""Make a linear Kalman predict (square-root form).

Much cheaper than the unscented predict because it avoids sigma point
generation and transformation. Only valid when every factor uses a `linear`
or `constant` transition function.

The positional parameters `transition_func`, `sigma_scaling_factor` and
`sigma_weights` are accepted for signature compatibility with
`kalman_predict` but are ignored.

Args:
transition_func: Ignored (kept for signature compatibility).
states: Array of shape (n_obs, n_mixtures, n_states).
upper_chols: Array of shape (n_obs, n_mixtures, n_states, n_states).
sigma_scaling_factor: Ignored.
sigma_weights: Ignored.
trans_coeffs: Dict mapping factor name to 1d coefficient array.
shock_sds: 1d array of length n_states.
anchoring_scaling_factors: Array of shape (2, n_states).
anchoring_constants: Array of shape (2, n_states).
observed_factors: Array of shape (n_obs, n_observed_factors).
latent_factors: Tuple of latent factor names.
constant_factor_indices: Indices of factors with `constant` transition.
n_all_factors: Total number of factors (latent + observed).

Returns:
Predicted states, same shape as states.
Predicted upper_chols, same shape as upper_chols.

"""
n_latent = len(latent_factors)

f_mat, c_vec = _build_f_and_c(
latent_factors, constant_factor_indices, n_all_factors, trans_coeffs
)

s_in = anchoring_scaling_factors[0][:n_latent] # (n_latent,) for input period
s_out = anchoring_scaling_factors[1][:n_latent] # (n_latent,) for output period
c_in = anchoring_constants[0][:n_latent] # (n_latent,)
c_out = anchoring_constants[1][:n_latent] # (n_latent,)

# Mean prediction
anchored_states = states * s_in + c_in # (n_obs, n_mix, n_latent)
# Concatenate with observed factors to get full state vector
n_obs, n_mix, _ = states.shape
obs_expanded = jnp.broadcast_to(
observed_factors[:, jnp.newaxis, :], (n_obs, n_mix, observed_factors.shape[1])
)
full_states = jnp.concatenate([anchored_states, obs_expanded], axis=-1)

predicted_anchored = full_states @ f_mat.T + c_vec # (n_obs, n_mix, n_latent)
predicted_states = (predicted_anchored - c_out) / s_out

# Covariance prediction (square-root form)
# G = diag(1/s_out) @ F_latent @ diag(s_in) where F_latent is the first
# n_latent columns of F
f_latent = f_mat[:, :n_latent] # (n_latent, n_latent)
g_mat = (f_latent * s_in) / s_out[:, jnp.newaxis] # (n_latent, n_latent)

# Stack: [upper_chol @ G.T ; diag(shock_sds / s_out)]
chol_g = upper_chols @ g_mat.T # (n_obs, n_mix, n_latent, n_latent)
shock_diag = jnp.diag(shock_sds / s_out) # (n_latent, n_latent)

stack = jnp.concatenate(
[chol_g, jnp.broadcast_to(shock_diag, chol_g.shape)], axis=-2
) # (n_obs, n_mix, 2*n_latent, n_latent)

predicted_covs = array_qr_jax(stack)[1][:, :, :n_latent]

return predicted_states, predicted_covs


def _build_f_and_c(
latent_factors: tuple[str, ...],
constant_factor_indices: frozenset[int],
n_all_factors: int,
trans_coeffs: dict[str, Array],
) -> tuple[Array, Array]:
"""Build F matrix and c vector from transition coefficients.

Stack all coefficient arrays, build identity rows for constant factors,
and select via a boolean mask.

Args:
latent_factors: Tuple of latent factor names.
constant_factor_indices: Indices of factors with `constant` transition.
n_all_factors: Total number of factors (latent + observed).
trans_coeffs: Dict mapping factor name to 1d coefficient array.

Returns:
f_mat: Array of shape (n_latent, n_all_factors).
c_vec: Array of shape (n_latent,).

"""
n_latent = len(latent_factors)
identity = jnp.eye(n_latent, n_all_factors)

# Will be of shape (n_latent, n_all+1)
all_coeffs = jnp.stack(
[
trans_coeffs[f]
if i not in constant_factor_indices
else jnp.zeros(n_all_factors + 1)
for i, f in enumerate(latent_factors)
]
)

f_from_coeffs = all_coeffs[:, :-1]
c_from_coeffs = all_coeffs[:, -1]

is_constant = jnp.array([i in constant_factor_indices for i in range(n_latent)])
mask = is_constant[:, None] # (n_latent, 1)

f_mat = jnp.where(mask, identity, f_from_coeffs)
c_vec = jnp.where(is_constant, 0.0, c_from_coeffs)

return f_mat, c_vec


def _calculate_sigma_points(
states: Array,
upper_chols: Array,