diff --git a/docs/explanations/linear_predict.md b/docs/explanations/linear_predict.md new file mode 100644 index 0000000..bd9de52 --- /dev/null +++ b/docs/explanations/linear_predict.md @@ -0,0 +1,135 @@ +# Linear predict optimization + +## When the linear predict is used + +At model setup time, `is_all_linear` checks whether every latent factor's transition +function name belongs to `{"linear", "constant"}`. This is an all-or-nothing decision: +if even one factor uses a nonlinear transition (e.g. `translog`), the entire model falls +back to the unscented predict. + +The check happens in `get_maximization_inputs`, where the predict function is selected +via `functools.partial`. When the linear path is chosen, extra keyword arguments +(`latent_factors`, `constant_factor_indices`, `n_all_factors`) are bound at setup time so +the predict function has the same call signature as the unscented variant. + +## Why it is faster and uses less memory + +The unscented predict generates $2n + 1$ sigma points (where $n$ is the number of latent +factors), transforms each one through the transition function, then recovers predicted +means and covariances from weighted statistics. Its QR decomposition operates on a matrix +of shape $(3n + 1) \times n$: the $2n + 1$ weighted deviation rows plus $n$ rows for the +shock standard deviations. + +The linear predict skips sigma-point generation entirely. Because the transition is +linear, the predicted mean is just a matrix--vector product, and the predicted covariance +follows from the standard linear Gaussian formula. Its QR decomposition operates on a +$(2n) \times n$ matrix: $n$ rows from the propagated Cholesky factor and $n$ rows for the +shocks. The reduction from $3n + 1$ to $2n$ rows speeds up the QR step and removes all +sigma-point overhead. + +The memory savings can be more important than the speed gains. 
The unscented path +materialises $2n + 1$ sigma points for every observation and mixture component, and +JAX's automatic differentiation retains intermediate buffers for the backward pass. The +linear path replaces all of this with a single matrix multiply whose memory footprint +scales with $n^2$ rather than with the number of sigma points times the number of +observations. On memory-constrained GPUs this can be the difference between fitting the +model and running out of memory. + +## Building F and c + +The linear predict assembles a transition matrix $F$ of shape +$(n_\text{latent}, n_\text{all})$ and a constant vector $c$ of length $n_\text{latent}$ +from the `trans_coeffs` dictionary. Here $n_\text{all}$ includes both latent and observed +factors. + +For each latent factor $i$: + +- **Linear factor**: `trans_coeffs[factor]` is a 1-d array whose last element is the + intercept and whose preceding elements are the coefficients on all factors (latent and + observed). Row $i$ of $F$ is set to `coeffs[:-1]` and $c_i$ is set to `coeffs[-1]`. +- **Constant factor**: row $i$ of $F$ is the unit vector $e_i$ (identity row) and + $c_i = 0$, so the factor value is simply carried forward. + +The implementation uses a stack-then-mask approach: all coefficient arrays are stacked +into a single matrix (with zero-padded rows for constant factors), an identity matrix +provides the constant-factor rows, and `jnp.where` selects between them using a boolean +mask. This avoids per-element `.at[i].set()` calls and conditional branching, producing +a cleaner trace for JAX's compiler. + +Three construction strategies were benchmarked (loop with conditional `.at[i].set()`, +stack-then-mask with `jnp.where`, and index-scatter with pre-separated sub-matrices). +All three produced identical XLA graphs and showed no meaningful runtime difference +(~6.3--6.7 ms per call on CPU, 4-factor model, 5000 observations), confirming that the +construction is fully resolved at trace time. 
The stack-then-mask variant was kept for +its cleaner, more idiomatic JAX style. + +## Mean prediction + +The mean prediction incorporates anchoring, which rescales factors to a common metric +across periods. Let $s^{\text{in}}$ and $c^{\text{in}}$ be the input-period scaling +factors and constants, and $s^{\text{out}}$ and $c^{\text{out}}$ the output-period +counterparts. The steps are: + +1. **Anchor** the input states: $x^a = x \odot s^{\text{in}} + c^{\text{in}}$. +2. **Concatenate** observed factors to form the full state vector + $\tilde{x} = [x^a, x^{\text{obs}}]$. +3. **Apply the linear transition**: $y^a = \tilde{x}\, F^\top + c$. +4. **Un-anchor** to get the predicted states: + $\hat{x} = (y^a - c^{\text{out}}) \oslash s^{\text{out}}$. + +## Covariance prediction (square-root form) + +skillmodels maintains covariances in square-root (upper Cholesky) form throughout. Let +$R$ denote the current upper Cholesky factor so that $P = R^\top R$. The linear predict +propagates $R$ as follows. + +Define the effective transition matrix + +$$ +G = \operatorname{diag}(1 / s^{\text{out}})\; F_{\text{latent}}\; + \operatorname{diag}(s^{\text{in}}) +$$ + +where $F_{\text{latent}}$ is the first $n_\text{latent}$ columns of $F$ (the columns +corresponding to latent factors). $G$ folds the anchoring scales into the transition so +that the covariance update works directly in the un-anchored (internal) scale. + +The predicted covariance satisfies + +$$ +\hat{P} = G\, P\, G^\top + Q +$$ + +where $Q = \operatorname{diag}(\sigma / s^{\text{out}})^2$ and $\sigma$ is the vector of +shock standard deviations. In square-root form, the upper Cholesky factor $\hat{R}$ of +$\hat{P}$ is obtained via a single QR decomposition of the stacked matrix + +$$ +S = \begin{bmatrix} R\, G^\top \\ \operatorname{diag}(\sigma / s^{\text{out}}) + \end{bmatrix} +$$ + +which has shape $(2n) \times n$. The upper-triangular $R$-factor of $S$ (its first $n$ +rows) gives $\hat{R}$. 
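
The square-root update above can be checked numerically: the $R$-factor of the stacked matrix $S$ reproduces $G\,P\,G^\top + Q$ without ever forming $\hat{P}$ explicitly. A minimal NumPy sketch with made-up 3-factor inputs (all values are illustrative, not part of skillmodels):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
A = rng.normal(size=(n, n))
P = A @ A.T + n * np.eye(n)        # SPD covariance
R = np.linalg.cholesky(P).T        # upper Cholesky: P = R.T @ R
G = rng.normal(size=(n, n))        # effective transition matrix
sigma = np.array([0.1, 0.2, 0.3])  # shock standard deviations
s_out = np.ones(n)                 # output-period anchoring scales

# Stack [R G^T ; diag(sigma / s_out)] and take the QR R-factor.
S = np.vstack([R @ G.T, np.diag(sigma / s_out)])  # shape (2n, n)
R_hat = np.linalg.qr(S, mode="r")                 # upper triangular, (n, n)

# R_hat.T @ R_hat equals the predicted covariance G P G^T + Q.
P_hat = G @ P @ G.T + np.diag((sigma / s_out) ** 2)
```

The check works because $S^\top S = G R^\top R\, G^\top + Q = G P G^\top + Q$, and QR leaves $S^\top S = \hat{R}^\top \hat{R}$ invariant.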
+ +## Observed factors + +Observed factors (e.g. investment measures whose values are known from data) appear as +columns in $F$ and therefore influence the predicted mean through the matrix--vector +product. However, they carry no uncertainty: their columns are excluded from the +covariance propagation. This is why $G$ uses only the first $n_\text{latent}$ columns of +$F$ rather than the full matrix. + +## Practical impact + +Benchmarks on a 4-factor linear model (`health-cognition`, +`no_feedback_to_investments_linear`, 8 GiB GPU) show a modest ~6 % speed-up on GPU +(8.4 vs 8.9 s per optimizer iteration) and negligible difference on CPU. The speed gain +is small because with only 4 latent factors the unscented transform generates just 9 +sigma points — a trivially cheap operation on modern hardware. + +The memory reduction is the more significant benefit. Under the same conditions the +unscented path ran out of GPU memory when only ~5 GiB was free, while the linear path +ran without issues. For models with more latent factors both advantages grow: the +sigma-point count scales as $2n + 1$ and the QR matrix shrinks from $(3n + 1) \times n$ +to $(2n) \times n$. 
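
For concreteness, the stack-then-mask construction from the "Building F and c" section can be restated with explicit numbers. The sketch below mirrors `_build_f_and_c` (NumPy stands in for `jax.numpy`; the factor names and coefficient values are made up, and observed factors are omitted):

```python
import numpy as np

latent_factors = ("fac0", "fac1", "fac2")
constant_factor_indices = frozenset({1})  # fac1 uses a constant transition
n_latent = len(latent_factors)
n_all = 3  # no observed factors in this example

trans_coeffs = {
    "fac0": np.array([0.5, 0.3, 0.1, 0.2]),  # 3 coefficients + intercept
    "fac1": np.array([]),                    # constant factor: no params
    "fac2": np.array([0.1, 0.2, 0.8, -0.1]),
}

# Stack coefficient rows, zero-padding the constant factor's row.
all_coeffs = np.stack(
    [
        trans_coeffs[f]
        if i not in constant_factor_indices
        else np.zeros(n_all + 1)
        for i, f in enumerate(latent_factors)
    ]
)
is_constant = np.array([i in constant_factor_indices for i in range(n_latent)])

# Constant factors get identity rows in F and a zero intercept in c.
f_mat = np.where(is_constant[:, None], np.eye(n_latent, n_all), all_coeffs[:, :-1])
c_vec = np.where(is_constant, 0.0, all_coeffs[:, -1])
```

Row 1 of `f_mat` comes out as the unit vector $e_1$ with `c_vec[1] == 0`, so `fac1` is carried forward unchanged, while rows 0 and 2 are the supplied coefficients with their intercepts in `c_vec`.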
diff --git a/docs/myst.yml b/docs/myst.yml index 9050842..0f443a4 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -40,6 +40,7 @@ project: children: - file: explanations/names_and_concepts.md - file: explanations/notes_on_factor_scales.md + - file: explanations/linear_predict.md - title: Reference Guides children: - file: reference_guides/transition_functions.md diff --git a/pixi.lock b/pixi.lock index 031ce22..1129890 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9497,7 +9497,7 @@ packages: timestamp: 1753199211006 - pypi: ./ name: skillmodels - version: 0.0.24.dev302+g350c3b9e5.d20260318 + version: 0.0.24.dev308+g220a9b077.d20260318 sha256: c3e259ba2e68a9ccf4593eecfac180a8d2db8d26b757a19d3784ca917060e6f6 requires_dist: - dags>=0.5.1 diff --git a/src/skillmodels/kalman_filters.py b/src/skillmodels/kalman_filters.py index 069fa9d..7f9549f 100644 --- a/src/skillmodels/kalman_filters.py +++ b/src/skillmodels/kalman_filters.py @@ -1,6 +1,6 @@ """Kalman filter operations for state estimation using the square-root form.""" -from collections.abc import Callable +from collections.abc import Callable, Mapping import jax import jax.numpy as jnp @@ -8,6 +8,14 @@ from skillmodels.qr import qr_gpu +LINEAR_FUNCTION_NAMES = frozenset({"linear", "constant"}) + + +def is_all_linear(function_names: Mapping[str, str]) -> bool: + """Return True if every factor uses a linear or constant transition function.""" + return all(name in LINEAR_FUNCTION_NAMES for name in function_names.values()) + + array_qr_jax = ( jax.vmap(jax.vmap(qr_gpu)) if jax.default_backend() == "gpu" @@ -228,6 +236,141 @@ def kalman_predict( return predicted_states, predicted_covs +def linear_kalman_predict( + transition_func: Callable | None, # noqa: ARG001 + states: Array, + upper_chols: Array, + sigma_scaling_factor: float, # noqa: ARG001 + sigma_weights: Array, # noqa: ARG001 + trans_coeffs: dict[str, Array], + shock_sds: Array, + anchoring_scaling_factors: Array, + anchoring_constants: Array, + observed_factors: 
Array, + *, + latent_factors: tuple[str, ...], + constant_factor_indices: frozenset[int], + n_all_factors: int, +) -> tuple[Array, Array]: + """Make a linear Kalman predict (square-root form). + + Much cheaper than the unscented predict because it avoids sigma point + generation and transformation. Only valid when every factor uses a `linear` + or `constant` transition function. + + The positional parameters `transition_func`, `sigma_scaling_factor` and + `sigma_weights` are accepted for signature compatibility with + `kalman_predict` but are ignored. + + Args: + transition_func: Ignored (kept for signature compatibility). + states: Array of shape (n_obs, n_mixtures, n_states). + upper_chols: Array of shape (n_obs, n_mixtures, n_states, n_states). + sigma_scaling_factor: Ignored. + sigma_weights: Ignored. + trans_coeffs: Dict mapping factor name to 1d coefficient array. + shock_sds: 1d array of length n_states. + anchoring_scaling_factors: Array of shape (2, n_states). + anchoring_constants: Array of shape (2, n_states). + observed_factors: Array of shape (n_obs, n_observed_factors). + latent_factors: Tuple of latent factor names. + constant_factor_indices: Indices of factors with `constant` transition. + n_all_factors: Total number of factors (latent + observed). + + Returns: + Predicted states, same shape as states. + Predicted upper_chols, same shape as upper_chols. 
+ + """ + n_latent = len(latent_factors) + + f_mat, c_vec = _build_f_and_c( + latent_factors, constant_factor_indices, n_all_factors, trans_coeffs + ) + + s_in = anchoring_scaling_factors[0][:n_latent] # (n_latent,) for input period + s_out = anchoring_scaling_factors[1][:n_latent] # (n_latent,) for output period + c_in = anchoring_constants[0][:n_latent] # (n_latent,) + c_out = anchoring_constants[1][:n_latent] # (n_latent,) + + # Mean prediction + anchored_states = states * s_in + c_in # (n_obs, n_mix, n_latent) + # Concatenate with observed factors to get full state vector + n_obs, n_mix, _ = states.shape + obs_expanded = jnp.broadcast_to( + observed_factors[:, jnp.newaxis, :], (n_obs, n_mix, observed_factors.shape[1]) + ) + full_states = jnp.concatenate([anchored_states, obs_expanded], axis=-1) + + predicted_anchored = full_states @ f_mat.T + c_vec # (n_obs, n_mix, n_latent) + predicted_states = (predicted_anchored - c_out) / s_out + + # Covariance prediction (square-root form) + # G = diag(1/s_out) @ F_latent @ diag(s_in) where F_latent is the first + # n_latent columns of F + f_latent = f_mat[:, :n_latent] # (n_latent, n_latent) + g_mat = (f_latent * s_in) / s_out[:, jnp.newaxis] # (n_latent, n_latent) + + # Stack: [upper_chol @ G.T ; diag(shock_sds / s_out)] + chol_g = upper_chols @ g_mat.T # (n_obs, n_mix, n_latent, n_latent) + shock_diag = jnp.diag(shock_sds / s_out) # (n_latent, n_latent) + + stack = jnp.concatenate( + [chol_g, jnp.broadcast_to(shock_diag, chol_g.shape)], axis=-2 + ) # (n_obs, n_mix, 2*n_latent, n_latent) + + predicted_covs = array_qr_jax(stack)[1][:, :, :n_latent] + + return predicted_states, predicted_covs + + +def _build_f_and_c( + latent_factors: tuple[str, ...], + constant_factor_indices: frozenset[int], + n_all_factors: int, + trans_coeffs: dict[str, Array], +) -> tuple[Array, Array]: + """Build F matrix and c vector from transition coefficients. 
+ + Stack all coefficient arrays, build identity rows for constant factors, + and select via a boolean mask. + + Args: + latent_factors: Tuple of latent factor names. + constant_factor_indices: Indices of factors with `constant` transition. + n_all_factors: Total number of factors (latent + observed). + trans_coeffs: Dict mapping factor name to 1d coefficient array. + + Returns: + f_mat: Array of shape (n_latent, n_all_factors). + c_vec: Array of shape (n_latent,). + + """ + n_latent = len(latent_factors) + identity = jnp.eye(n_latent, n_all_factors) + + # Will be of shape (n_latent, n_all+1) + all_coeffs = jnp.stack( + [ + trans_coeffs[f] + if i not in constant_factor_indices + else jnp.zeros(n_all_factors + 1) + for i, f in enumerate(latent_factors) + ] + ) + + f_from_coeffs = all_coeffs[:, :-1] + c_from_coeffs = all_coeffs[:, -1] + + is_constant = jnp.array([i in constant_factor_indices for i in range(n_latent)]) + mask = is_constant[:, None] # (n_latent, 1) + + f_mat = jnp.where(mask, identity, f_from_coeffs) + c_vec = jnp.where(is_constant, 0.0, c_from_coeffs) + + return f_mat, c_vec + + def _calculate_sigma_points( states: Array, upper_chols: Array, diff --git a/src/skillmodels/likelihood_function.py b/src/skillmodels/likelihood_function.py index 5d6ec5c..ee89f01 100644 --- a/src/skillmodels/likelihood_function.py +++ b/src/skillmodels/likelihood_function.py @@ -9,10 +9,7 @@ from jax import Array from skillmodels.clipping import soft_clipping -from skillmodels.kalman_filters import ( - kalman_predict, - kalman_update, -) +from skillmodels.kalman_filters import kalman_update from skillmodels.parse_params import parse_params from skillmodels.types import ( Dimensions, @@ -28,7 +25,7 @@ def log_likelihood( parsing_info: ParsingInfo, measurements: Array, controls: Array, - transition_func: Callable, + predict_func: Callable, sigma_scaling_factor: float, sigma_weights: Array, dimensions: Dimensions, @@ -50,7 +47,9 @@ def log_likelihood( observed measurements. 
NaN if the measurement was not observed. controls: Array of shape (n_periods, n_obs, n_controls) with observed control variables for the measurement equations. - transition_func: The transition function. + predict_func: Callable that performs the predict step. Either + `kalman_predict` (partialed with `transition_func`) or + `linear_kalman_predict` (partialed with factor metadata). sigma_scaling_factor: A scaling factor that controls the spread of the sigma points. sigma_weights: 1d array of length n_sigma with non-negative sigma weights. @@ -76,7 +75,7 @@ def log_likelihood( parsing_info=parsing_info, measurements=measurements, controls=controls, - transition_func=transition_func, + predict_func=predict_func, sigma_scaling_factor=sigma_scaling_factor, sigma_weights=sigma_weights, dimensions=dimensions, @@ -94,7 +93,7 @@ def log_likelihood_obs( parsing_info: ParsingInfo, measurements: Array, controls: Array, - transition_func: Callable, + predict_func: Callable, sigma_scaling_factor: float, sigma_weights: Array, dimensions: Dimensions, @@ -123,7 +122,9 @@ def log_likelihood_obs( observed measurements. NaN if the measurement was not observed. controls: Array of shape (n_periods, n_obs, n_controls) with observed control variables for the measurement equations. - transition_func: The transition function. + predict_func: Callable that performs the predict step. Either + `kalman_predict` (partialed with `transition_func`) or + `linear_kalman_predict` (partialed with factor metadata). sigma_scaling_factor: A scaling factor that controls the spread of the sigma points. Bigger means that sigma points are further apart. Depends on the sigma_point algorithm chosen. 
@@ -177,7 +178,7 @@ def log_likelihood_obs( parsed_params=parsed_params, sigma_scaling_factor=sigma_scaling_factor, sigma_weights=sigma_weights, - transition_func=transition_func, + predict_func=predict_func, observed_factors=observed_factors, ) _body = jax.checkpoint(_body, prevent_cse=False) @@ -201,7 +202,7 @@ def _scan_body( parsed_params: ParsedParams, sigma_scaling_factor: float, sigma_weights: Array, - transition_func: Callable, + predict_func: Callable, observed_factors: Array, ) -> tuple[dict[str, Array], dict[str, Array]]: # ================================================================================== @@ -250,7 +251,7 @@ def _scan_body( "observed_factors": observed_factors[t], } - fixed_kwargs = {"transition_func": transition_func} + fixed_kwargs = {"predict_func": predict_func} # ================================================================================== # Do a predict step or a do-nothing fake predict step @@ -292,7 +293,7 @@ def _one_arg_anchoring_update( def _one_arg_no_predict( kwargs: dict[str, Any], - transition_func: Callable, # noqa: ARG001 + predict_func: Callable, # noqa: ARG001 ) -> tuple[Array, Array, Array]: """Just return the states cond chols without any changes.""" return kwargs["states"], kwargs["upper_chols"], kwargs["states"] @@ -300,11 +301,8 @@ def _one_arg_no_predict( def _one_arg_predict( kwargs: dict[str, Any], - transition_func: Callable, + predict_func: Callable, ) -> tuple[Array, Array, Array]: """Do a predict step but also return the input states as filtered states.""" - new_states, new_upper_chols = kalman_predict( - transition_func, - **kwargs, - ) + new_states, new_upper_chols = predict_func(**kwargs) return new_states, new_upper_chols, kwargs["states"] diff --git a/src/skillmodels/likelihood_function_debug.py b/src/skillmodels/likelihood_function_debug.py index 203d90f..6391695 100644 --- a/src/skillmodels/likelihood_function_debug.py +++ b/src/skillmodels/likelihood_function_debug.py @@ -9,7 +9,6 @@ from jax 
import Array from skillmodels.clipping import soft_clipping -from skillmodels.kalman_filters import kalman_predict from skillmodels.kalman_filters_debug import kalman_update from skillmodels.parse_params import parse_params from skillmodels.types import ( @@ -26,7 +25,7 @@ def log_likelihood( parsing_info: ParsingInfo, measurements: Array, controls: Array, - transition_func: Callable[..., Array], + predict_func: Callable[..., tuple[Array, Array]], sigma_scaling_factor: float, sigma_weights: Array, dimensions: Dimensions, @@ -49,7 +48,7 @@ def log_likelihood( measurements. NaN if the measurement was not observed. controls: Array of shape (n_periods, n_obs, n_controls) with observed control variables for the measurement equations. - transition_func: The transition function. + predict_func: Callable that performs the predict step. sigma_scaling_factor: A scaling factor that controls the spread of the sigma points. Bigger means that sigma points are further apart. sigma_weights: 1d array of length n_sigma with non-negative sigma weights. 
@@ -102,7 +101,7 @@ def log_likelihood( parsed_params=parsed_params, sigma_scaling_factor=sigma_scaling_factor, sigma_weights=sigma_weights, - transition_func=transition_func, + predict_func=predict_func, observed_factors=observed_factors, ) @@ -154,7 +153,7 @@ def _scan_body( parsed_params: ParsedParams, sigma_scaling_factor: float, sigma_weights: Array, - transition_func: Callable[..., Array], + predict_func: Callable[..., tuple[Array, Array]], observed_factors: Array, ) -> tuple[dict[str, Array], dict[str, Any]]: # ================================================================================== @@ -203,7 +202,7 @@ def _scan_body( "observed_factors": observed_factors[t], } - fixed_kwargs = {"transition_func": transition_func} + fixed_kwargs = {"predict_func": predict_func} # ================================================================================== # Do a predict step or a do-nothing fake predict step @@ -246,7 +245,7 @@ def _one_arg_anchoring_update( def _one_arg_no_predict( kwargs: dict[str, Any], - transition_func: Callable[..., Array], # noqa: ARG001 + predict_func: Callable[..., tuple[Array, Array]], # noqa: ARG001 ) -> tuple[Array, Array, Array]: """Just return the states cond chols without any changes.""" return kwargs["states"], kwargs["upper_chols"], kwargs["states"] @@ -254,11 +253,8 @@ def _one_arg_no_predict( def _one_arg_predict( kwargs: dict[str, Any], - transition_func: Callable[..., Array], + predict_func: Callable[..., tuple[Array, Array]], ) -> tuple[Array, Array, Array]: """Do a predict step but also return the input states as filtered states.""" - new_states, new_upper_chols = kalman_predict( - transition_func, - **kwargs, - ) + new_states, new_upper_chols = predict_func(**kwargs) return new_states, new_upper_chols, kwargs["states"] diff --git a/src/skillmodels/maximization_inputs.py b/src/skillmodels/maximization_inputs.py index 4d3dcd3..a01c55c 100644 --- a/src/skillmodels/maximization_inputs.py +++ 
b/src/skillmodels/maximization_inputs.py @@ -18,7 +18,12 @@ enforce_fixed_constraints, get_constraints, ) -from skillmodels.kalman_filters import calculate_sigma_scaling_factor_and_weights +from skillmodels.kalman_filters import ( + calculate_sigma_scaling_factor_and_weights, + is_all_linear, + kalman_predict, + linear_kalman_predict, +) from skillmodels.model_spec import ModelSpec from skillmodels.params_index import get_params_index from skillmodels.parse_params import create_parsing_info @@ -225,12 +230,28 @@ def _partial_some_log_likelihood( if max(iteration_to_period) != last_aug_period - 1: raise ValueError("Unexpected iteration_to_period configuration") + if is_all_linear(model.transition_info.function_names): + constant_factor_indices = frozenset( + i + for i, f in enumerate(model.labels.latent_factors) + if model.transition_info.function_names[f] == "constant" + ) + predict_func = functools.partial( + linear_kalman_predict, + model.transition_info.func, + latent_factors=model.labels.latent_factors, + constant_factor_indices=constant_factor_indices, + n_all_factors=model.dimensions.n_all_factors, + ) + else: + predict_func = functools.partial(kalman_predict, model.transition_info.func) + return functools.partial( fun, parsing_info=parsing_info, measurements=measurements, controls=controls, - transition_func=model.transition_info.func, + predict_func=predict_func, sigma_scaling_factor=sigma_scaling_factor, sigma_weights=sigma_weights, dimensions=model.dimensions, diff --git a/tests/test_kalman_filters.py b/tests/test_kalman_filters.py index eae8a12..9099696 100644 --- a/tests/test_kalman_filters.py +++ b/tests/test_kalman_filters.py @@ -15,6 +15,7 @@ calculate_sigma_scaling_factor_and_weights, kalman_predict, kalman_update, + linear_kalman_predict, transform_sigma_points, ) from skillmodels.kalman_filters_debug import kalman_update as kalman_update_debug @@ -232,6 +233,272 @@ def transition_function(params, states): aaae(calc_chols[0, 0].T @ calc_chols[0, 
0], expected_cov) +@pytest.mark.parametrize("seed", SEEDS) +def test_linear_kalman_predict_against_filterpy(seed) -> None: + """Test linear_kalman_predict gives same result as filterpy's linear predict.""" + rng = np.random.default_rng(seed) + state, cov = _random_state_and_covariance(rng) + dim = len(state) + trans_mat = rng.uniform(low=-1, high=1, size=(dim, dim + 1)) + # last column is the constant + f_mat = trans_mat[:, :-1] + c_vec = trans_mat[:, -1] + + shock_sds = 0.5 * np.arange(dim) / max(dim, 1) + + fp_filter = KalmanFilter(dim_x=dim, dim_z=1) + fp_filter.x = state.reshape(dim, 1) + fp_filter.F = f_mat + fp_filter.B = np.eye(dim) + fp_filter.P = cov + fp_filter.Q = np.diag(shock_sds**2) + + fp_filter.predict(u=c_vec.reshape(dim, 1)) + expected_state = fp_filter.x + expected_cov = fp_filter.P + + sm_state, sm_chol = _convert_predict_inputs_from_filterpy_to_skillmodels(state, cov) + scaling_factor, weights = calculate_sigma_scaling_factor_and_weights(dim, 2) + + latent_factors = tuple(f"fac{i}" for i in range(dim)) + trans_coeffs = { + f"fac{i}": jnp.array(np.append(trans_mat[i, :-1], trans_mat[i, -1])) + for i in range(dim) + } + anch_scaling = jnp.ones((2, dim)) + anch_constants = jnp.zeros((2, dim)) + observed_factors = jnp.zeros((1, 0)) + + calc_states, calc_chols = linear_kalman_predict( + None, # transition_func (ignored) + sm_state, + sm_chol, + float(scaling_factor), + weights, + trans_coeffs, + jnp.array(shock_sds), + anch_scaling, + anch_constants, + observed_factors, + latent_factors=latent_factors, + constant_factor_indices=frozenset(), + n_all_factors=dim, + ) + + aaae(calc_states.flatten(), expected_state.flatten()) + aaae(calc_chols[0, 0].T @ calc_chols[0, 0], expected_cov) + + +@pytest.mark.parametrize("seed", SEEDS) +def test_linear_predict_matches_unscented_for_linear_model(seed) -> None: + """Linear predict should give identical results to unscented for linear models.""" + rng = np.random.default_rng(seed) + state, cov = 
_random_state_and_covariance(rng) + dim = len(state) + trans_mat = rng.uniform(low=-1, high=1, size=(dim, dim)) + + shock_sds = 0.5 * np.arange(dim) / max(dim, 1) + + def linear_func(params, states): + return jnp.dot(states, params) + + def transition_function(params, states): + return jnp.column_stack( + [linear_func(params[f"fac{i}"], states) for i in range(dim)] + ) + + sm_state, sm_chol = _convert_predict_inputs_from_filterpy_to_skillmodels(state, cov) + scaling_factor, weights = calculate_sigma_scaling_factor_and_weights(dim, 2) + # For unscented: trans_coeffs values are just the row of the transition matrix + trans_coeffs_unscented = {f"fac{i}": jnp.array(trans_mat[i]) for i in range(dim)} + # For linear: trans_coeffs values have constant appended (0 for pure linear) + trans_coeffs_linear = { + f"fac{i}": jnp.array(np.append(trans_mat[i], 0.0)) for i in range(dim) + } + anch_scaling = jnp.ones((2, dim)) + anch_constants = jnp.zeros((2, dim)) + observed_factors = jnp.zeros((1, 0)) + latent_factors = tuple(f"fac{i}" for i in range(dim)) + + unscented_states, unscented_chols = kalman_predict( + transition_function, + sm_state, + sm_chol, + float(scaling_factor), + weights, + trans_coeffs_unscented, + jnp.array(shock_sds), + anch_scaling, + anch_constants, + observed_factors, + ) + + linear_states, linear_chols = linear_kalman_predict( + None, + sm_state, + sm_chol, + float(scaling_factor), + weights, + trans_coeffs_linear, + jnp.array(shock_sds), + anch_scaling, + anch_constants, + observed_factors, + latent_factors=latent_factors, + constant_factor_indices=frozenset(), + n_all_factors=dim, + ) + + aaae(linear_states, unscented_states, decimal=5) + aaae( + linear_chols[0, 0].T @ linear_chols[0, 0], + unscented_chols[0, 0].T @ unscented_chols[0, 0], + decimal=5, + ) + + +def test_linear_predict_with_constant_factors() -> None: + """Test that constant factors produce identity rows in F.""" + rng = np.random.default_rng(42) + dim = 3 + state, cov = 
_random_state_and_covariance(rng, dim=dim) + shock_sds = np.array([0.1, 0.0, 0.2]) + + sm_state, sm_chol = _convert_predict_inputs_from_filterpy_to_skillmodels(state, cov) + scaling_factor, weights = calculate_sigma_scaling_factor_and_weights(dim, 2) + + # fac0: linear, fac1: constant, fac2: linear + trans_coeffs = { + "fac0": jnp.array([0.5, 0.3, 0.1, 0.2]), # 3 coeffs + constant + "fac1": jnp.array([]), # constant factor has no params + "fac2": jnp.array([0.1, 0.2, 0.8, -0.1]), + } + anch_scaling = jnp.ones((2, dim)) + anch_constants = jnp.zeros((2, dim)) + observed_factors = jnp.zeros((1, 0)) + latent_factors = ("fac0", "fac1", "fac2") + + calc_states, _calc_chols = linear_kalman_predict( + None, + sm_state, + sm_chol, + float(scaling_factor), + weights, + trans_coeffs, + jnp.array(shock_sds), + anch_scaling, + anch_constants, + observed_factors, + latent_factors=latent_factors, + constant_factor_indices=frozenset({1}), + n_all_factors=dim, + ) + + # fac1 (constant) should remain unchanged + aaae(calc_states[0, 0, 1], state[1]) + + # fac0 should be linear combination + constant + expected_fac0 = 0.5 * state[0] + 0.3 * state[1] + 0.1 * state[2] + 0.2 + aaae(calc_states[0, 0, 0], expected_fac0) + + +def test_linear_predict_with_observed_factors() -> None: + """Test that observed factors are used correctly as extra columns in F.""" + rng = np.random.default_rng(42) + n_latent = 2 + n_observed = 1 + state, cov = _random_state_and_covariance(rng, dim=n_latent) + shock_sds = np.array([0.1, 0.2]) + + sm_state, sm_chol = _convert_predict_inputs_from_filterpy_to_skillmodels(state, cov) + scaling_factor, weights = calculate_sigma_scaling_factor_and_weights(n_latent, 2) + + observed_val = 3.0 + observed_factors = jnp.array([[observed_val]]) + + # fac0 depends on both latent + observed, fac1 depends only on latent + trans_coeffs = { + "fac0": jnp.array([0.5, 0.3, 0.2, 0.1]), # 2 latent + 1 observed + constant + "fac1": jnp.array([0.1, 0.9, 0.0, 0.0]), + } + anch_scaling = 
jnp.ones((2, n_latent + n_observed)) + anch_constants = jnp.zeros((2, n_latent + n_observed)) + latent_factors = ("fac0", "fac1") + + calc_states, _calc_chols = linear_kalman_predict( + None, + sm_state, + sm_chol, + float(scaling_factor), + weights, + trans_coeffs, + jnp.array(shock_sds), + anch_scaling, + anch_constants, + observed_factors, + latent_factors=latent_factors, + constant_factor_indices=frozenset(), + n_all_factors=n_latent + n_observed, + ) + + expected_fac0 = 0.5 * state[0] + 0.3 * state[1] + 0.2 * observed_val + 0.1 + expected_fac1 = 0.1 * state[0] + 0.9 * state[1] + 0.0 * observed_val + 0.0 + aaae(calc_states[0, 0, 0], expected_fac0) + aaae(calc_states[0, 0, 1], expected_fac1) + + +def test_linear_predict_with_wide_anchoring_arrays() -> None: + """Regression: anchoring arrays have n_all columns, not just n_latent. + + At runtime, `parse_params` produces anchoring arrays of shape + `(n_aug_periods, n_all_factors)` — latent columns followed by observed-factor + columns (scaling=1, constant=0). This test uses that shape to verify + `linear_kalman_predict` slices correctly. 
+ """ + rng = np.random.default_rng(42) + n_latent = 2 + n_observed = 1 + n_all = n_latent + n_observed + state, cov = _random_state_and_covariance(rng, dim=n_latent) + shock_sds = np.array([0.1, 0.2]) + + sm_state, sm_chol = _convert_predict_inputs_from_filterpy_to_skillmodels(state, cov) + scaling_factor, weights = calculate_sigma_scaling_factor_and_weights(n_latent, 2) + + observed_val = 3.0 + observed_factors = jnp.array([[observed_val]]) + + trans_coeffs = { + "fac0": jnp.array([0.5, 0.3, 0.2, 0.1]), + "fac1": jnp.array([0.1, 0.9, 0.0, 0.0]), + } + # Shape (2, n_all) — matches what parse_params returns at runtime + anch_scaling = jnp.ones((2, n_all)) + anch_constants = jnp.zeros((2, n_all)) + latent_factors = ("fac0", "fac1") + + calc_states, _calc_chols = linear_kalman_predict( + None, + sm_state, + sm_chol, + float(scaling_factor), + weights, + trans_coeffs, + jnp.array(shock_sds), + anch_scaling, + anch_constants, + observed_factors, + latent_factors=latent_factors, + constant_factor_indices=frozenset(), + n_all_factors=n_all, + ) + + expected_fac0 = 0.5 * state[0] + 0.3 * state[1] + 0.2 * observed_val + 0.1 + expected_fac1 = 0.1 * state[0] + 0.9 * state[1] + 0.0 * observed_val + 0.0 + aaae(calc_states[0, 0, 0], expected_fac0) + aaae(calc_states[0, 0, 1], expected_fac1) + + def _random_state_and_covariance(rng, dim=None): if dim is None: dim = rng.integers(low=1, high=10)