From c9e3ce4cfde9ec1bad2720c4be381d0be13cbc67 Mon Sep 17 00:00:00 2001 From: feiyang Date: Fri, 27 Feb 2026 21:13:41 +0000 Subject: [PATCH 01/19] feat: add linfa-residual-sequence crate Implements ResidualSequence Struct and StackWith trait for composing regression models in a boosting / residual-stacking pattern. The second (and any further) model trains on the residuals left by the previous one; predictions are summed. Docs and tests were written with AI assistance. --- algorithms/linfa-residual-sequence/Cargo.toml | 18 + algorithms/linfa-residual-sequence/src/lib.rs | 337 ++++++++++++++++++ 2 files changed, 355 insertions(+) create mode 100644 algorithms/linfa-residual-sequence/Cargo.toml create mode 100644 algorithms/linfa-residual-sequence/src/lib.rs diff --git a/algorithms/linfa-residual-sequence/Cargo.toml b/algorithms/linfa-residual-sequence/Cargo.toml new file mode 100644 index 000000000..2a1641e9f --- /dev/null +++ b/algorithms/linfa-residual-sequence/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "linfa-residual-sequence" +version = "0.8.1" +edition = "2018" +description = "Model composition utilities for the linfa ML framework" +license = "MIT OR Apache-2.0" +repository = "https://github.com/rust-ml/linfa" +keywords = ["machine-learning", "linfa", "ai", "ml", "residual"] +categories = ["algorithms", "mathematics", "science"] + +[dependencies] +linfa = { version = "0.8.1", path = "../.." } +ndarray = { version = "0.16" } +thiserror = "2.0" + +[dev-dependencies] +linfa-linear = { path = "../linfa-linear" } +linfa-svm = { path = "../linfa-svm" } diff --git a/algorithms/linfa-residual-sequence/src/lib.rs b/algorithms/linfa-residual-sequence/src/lib.rs new file mode 100644 index 000000000..4a4fa4723 --- /dev/null +++ b/algorithms/linfa-residual-sequence/src/lib.rs @@ -0,0 +1,337 @@ +//! Residual sequence model composition for the linfa ML framework. +//! +//! This crate provides [`ResidualSequence`], which fits models sequentially on +//! the residuals of the previous one. Chain as many as you like via [`StackWith`]: +//! +//! 1. Fit `first` on `(X, Y)` +//! 2. Compute residuals: `R = Y - first.predict(X)` +//! 3. Fit `second` on `(X, R)` +//! 4. Repeat for any further models stacked on top +//! +//! When predicting, all models' outputs are summed. +//! +//! This is the foundation of boosting / residual stacking. +//! +//! # Examples +//! +//! ## Linear + linear +//! +//! Two [`linfa_linear::LinearRegression`] models stacked: the second fits the +//! residuals left by the first. +//! +//! ``` +//! use linfa::traits::{Fit, Predict}; +//! use linfa::DatasetBase; +//! use linfa_linear::LinearRegression; +//! use linfa_residual_sequence::{ResidualSequence, StackWith}; +//! use ndarray::{array, Array2}; +//! +//! // y = 2x: perfectly linear, so the second model should see zero residuals. +//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +//! let y = array![0., 2., 4., 6., 8.]; +//! let dataset = DatasetBase::new(x.clone(), y); +//! +//! let fitted = LinearRegression::default() +//! .stack_with(LinearRegression::default()) +//! .fit(&dataset) +//! .unwrap(); +//! +//! let _preds = fitted.predict(&x); +//! ``` +//! +//! ## The second model learns nothing when the first fits perfectly +//! +//! If the first model already captures the data exactly, the residuals are all +//! zero and the second model has nothing to learn — its parameters come out +//! at (or very near) zero. +//! +//! ``` +//! use linfa::traits::{Fit, Predict}; +//! use linfa::DatasetBase; +//! use linfa_linear::LinearRegression; +//! use linfa_residual_sequence::StackWith; +//! use ndarray::{array, Array2}; +//! +//! // y = 2x: one linear model is enough to fit this perfectly. +//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +//! let y = array![0., 2., 4., 6., 8.]; +//! let dataset = DatasetBase::new(x.clone(), y); +//! +//! let fitted = LinearRegression::default() +//! .stack_with(LinearRegression::default()) +//! .fit(&dataset) +//! .unwrap(); +//! +//! // The second model trained on zero residuals — nothing left to correct. +//! assert!(fitted.second.params().iter().all(|&c: &f64| c.abs() < 1e-10)); +//! assert!(fitted.second.intercept().abs() < 1e-10); +//! ``` +//! +//! ## Chained SVMs and linear regression +//! +//! A linear-kernel [`linfa_svm::Svm`] captures the overall trend; two +//! Gaussian-kernel SVMs and a [`linfa_linear::LinearRegression`] then fit +//! successive residuals in a four-model chain. +//! +//! ``` +//! use linfa::traits::{Fit, Predict}; +//! use linfa::DatasetBase; +//! use linfa_linear::LinearRegression; +//! use linfa_residual_sequence::{ResidualSequence, StackWith}; +//! use linfa_svm::Svm; +//! use ndarray::Array; +//! +//! // y = sin(x): the linear SVM captures the slope; the RBF SVM captures +//! // the curvature left in the residuals. +//! let x = Array::linspace(0f64, 6., 20) +//! .into_shape_with_order((20, 1)) +//! .unwrap(); +//! let y = x.column(0).mapv(f64::sin); +//! let dataset = DatasetBase::new(x.clone(), y); +//! +//! let fitted = Svm::::params() +//! .c_svr(1., None) +//! .linear_kernel() +//! .stack_with( +//! Svm::::params() +//! .c_svr(10., Some(0.1)) +//! .gaussian_kernel(1.), +//! ) +//! .stack_with(LinearRegression::default()) +//! .stack_with( +//! Svm::::params() +//! .c_svr(10., Some(0.1)) +//! .gaussian_kernel(3.), +//! ) +//! .fit(&dataset) +//! .unwrap(); +//! +//! let _preds = fitted.predict(&x); +//! ``` + +use linfa::dataset::{AsTargets, DatasetBase, Records}; +use linfa::traits::{Fit, Predict}; +use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, RawDataClone}; +use std::ops::{Add, Sub}; + +type Arr2 = ArrayBase; + +/// Error returned by [`ResidualSequence::fit`]. +/// +/// Wraps the error from whichever of the two model fits failed, keeping them +/// distinguishable without requiring both models to share the same error type. +#[derive(Debug, thiserror::Error)] +pub enum ResidualSequenceError { + #[error("first model: {0}")] + First(E1), + #[error("second model: {0}")] + Second(E2), + // Satisfies the `Fit` trait's `E: From` bound. + #[error(transparent)] + Linfa(#[from] linfa::error::Error), +} + +/// Fits two models sequentially on the residuals of the first. +/// +/// `first` is fit on the original dataset. `second` is fit on the residuals +/// `Y - first.predict(X)`. See the [crate-level docs](crate) for details. +#[derive(Debug, Clone)] +pub struct ResidualSequence { + pub first: F1, + pub second: F2, +} + +/// Extension trait that lets any model params type be composed into a [`ResidualSequence`]. +/// +/// # Example +/// +/// ``` +/// use linfa::traits::Fit; +/// use linfa::DatasetBase; +/// use linfa_linear::LinearRegression; +/// use linfa_residual_sequence::StackWith; +/// use ndarray::{array, Array2}; +/// +/// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +/// let y = array![0., 2., 4., 6., 8.]; +/// let dataset = DatasetBase::new(x.clone(), y); +/// +/// let fitted = LinearRegression::default() +/// .stack_with(LinearRegression::default()) +/// .fit(&dataset) +/// .unwrap(); +/// ``` +pub trait StackWith: Sized { + fn stack_with(self, second: F2) -> ResidualSequence; +} + +impl StackWith for F1 { + fn stack_with(self, second: F2) -> ResidualSequence { + ResidualSequence { + first: self, + second, + } + } +} + +/// Two fitted models produced by [`ResidualSequence::fit`]. +/// +/// Predicts by summing both models' outputs: `first.predict(X) + second.predict(X)`. +#[derive(Debug, Clone)] +pub struct FittedResidualSequence { + pub first: R1, + pub second: R2, +} + +impl Fit, T, ResidualSequenceError> + for ResidualSequence +where + D: Data + RawDataClone, + D::Elem: Copy + Sub, + Arr2: Records, + F1: Fit, T, E1>, + for<'a> F1::Object: Predict<&'a Arr2, Array1>, + F2: Fit, Array1, E2>, + T: AsTargets, + E1: std::error::Error + From, + E2: std::error::Error + From, +{ + type Object = FittedResidualSequence; + + fn fit( + &self, + dataset: &DatasetBase, T>, + ) -> Result> { + let first = self + .first + .fit(dataset) + .map_err(ResidualSequenceError::First)?; + + let y_pred = first.predict(dataset.records()); + let residuals = dataset + .targets() + .as_targets() + .iter() + .zip(y_pred.iter()) + .map(|(y, p)| *y - *p) + .collect::>(); + + let residual_dataset = DatasetBase::new(dataset.records().clone(), residuals); + let second = self + .second + .fit(&residual_dataset) + .map_err(ResidualSequenceError::Second)?; + + Ok(FittedResidualSequence { first, second }) + } +} + +impl<'a, R1, R2, D> Predict<&'a Arr2, Array1> for FittedResidualSequence +where + D: Data, + D::Elem: Copy + Add, + Arr2: Records, + for<'b> R1: Predict<&'b Arr2, Array1>, + for<'b> R2: Predict<&'b Arr2, Array1>, +{ + fn predict(&self, x: &'a Arr2) -> Array1 { + let pred1 = self.first.predict(x); + let pred2 = self.second.predict(x); + pred1 + pred2 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use linfa::error::Error as LinfaError; + use linfa::DatasetBase; + use ndarray::{array, Array1, Array2}; + + #[derive(thiserror::Error, Debug)] + #[error("dummy error")] + struct DummyError(#[from] LinfaError); + + // Params that fits by recording the mean of the targets. + struct MeanParams; + + // Model that predicts the mean it saw during fit. + struct MeanModel(f64); + + impl Fit, Array1, DummyError> for MeanParams { + type Object = MeanModel; + fn fit( + &self, + dataset: &DatasetBase, Array1>, + ) -> Result { + let mean = dataset.targets().iter().sum::() / dataset.targets().len() as f64; + Ok(MeanModel(mean)) + } + } + + impl<'a> Predict<&'a Array2, Array1> for MeanModel { + fn predict(&self, x: &'a Array2) -> Array1 { + Array1::from_elem(x.nrows(), self.0) + } + } + + #[test] + fn second_is_fit_on_residuals() { + // targets = [1, 3]. first sees mean=2, predicts 2 for all. + // residuals = [1-2, 3-2] = [-1, 1]. second sees mean=0. + let model = ResidualSequence { + first: MeanParams, + second: MeanParams, + }; + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + assert_eq!(fitted.first.0, 2.0); // mean of [1, 3] + assert_eq!(fitted.second.0, 0.0); // mean of residuals [-1, 1] + } + + #[test] + fn predict_sums_both_models() { + // first predicts 2.0, second predicts 0.0 → sum = 2.0 + let model = MeanParams.stack_with(MeanParams); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + let records = array![[0.0_f64], [1.0]]; + let predictions = fitted.predict(&records); + assert_eq!(predictions, array![2.0, 2.0]); + } + + #[test] + fn predict_recovers_targets_when_residuals_fit_perfectly() { + // If second perfectly fits the residuals, the combined prediction = original targets. + struct FixedParams(f64); + struct FixedModel(f64); + + impl Fit, Array1, DummyError> for FixedParams { + type Object = FixedModel; + fn fit( + &self, + _dataset: &DatasetBase, Array1>, + ) -> Result { + Ok(FixedModel(self.0)) + } + } + + impl<'a> Predict<&'a Array2, Array1> for FixedModel { + fn predict(&self, x: &'a Array2) -> Array1 { + Array1::from_elem(x.nrows(), self.0) + } + } + + // first predicts 3.0, second predicts 1.0 → sum = 4.0 + let model = FixedParams(3.0) + .stack_with(FixedParams(1.0)) + .stack_with(FixedParams(0.0)); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]); + let fitted = model.fit(&dataset).unwrap(); + + let predictions = fitted.predict(&array![[0.0_f64], [1.0]]); + assert_eq!(predictions, array![4.0, 4.0]); + } +} From edaf942cca45e05ce4ef1b14867097bcdd4bbb63 Mon Sep 17 00:00:00 2001 From: feiyang Date: Fri, 27 Feb 2026 21:43:28 +0000 Subject: [PATCH 02/19] remove doc link --- algorithms/linfa-residual-sequence/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithms/linfa-residual-sequence/src/lib.rs b/algorithms/linfa-residual-sequence/src/lib.rs index 4a4fa4723..622c6a979 100644 --- a/algorithms/linfa-residual-sequence/src/lib.rs +++ b/algorithms/linfa-residual-sequence/src/lib.rs @@ -16,7 +16,7 @@ //! //! ## Linear + linear //! -//! Two [`linfa_linear::LinearRegression`] models stacked: the second fits the +//! Two `linfa_linear::LinearRegression` models stacked: the second fits the //! residuals left by the first. //! //! ``` @@ -69,8 +69,8 @@ //! //! ## Chained SVMs and linear regression //! -//! A linear-kernel [`linfa_svm::Svm`] captures the overall trend; two -//! Gaussian-kernel SVMs and a [`linfa_linear::LinearRegression`] then fit +//! A linear-kernel `linfa_svm::Svm` captures the overall trend; two +//! Gaussian-kernel SVMs and a `linfa_linear::LinearRegression` then fit //! successive residuals in a four-model chain. //! //! ``` From 5795cff9b3d28a67253b81576ab479c34b72a6bd Mon Sep 17 00:00:00 2001 From: feiyang Date: Sun, 1 Mar 2026 18:45:13 +0000 Subject: [PATCH 03/19] move to composing/ module in linfa main crate --- Cargo.toml | 2 ++ algorithms/linfa-residual-sequence/Cargo.toml | 18 --------------- src/composing/mod.rs | 1 + .../composing/residual_sequence.rs | 22 +++++++++---------- 4 files changed, 14 insertions(+), 29 deletions(-) delete mode 100644 algorithms/linfa-residual-sequence/Cargo.toml rename algorithms/linfa-residual-sequence/src/lib.rs => src/composing/residual_sequence.rs (94%) diff --git a/Cargo.toml b/Cargo.toml index 38e6654c0..8f11fc550 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,8 @@ linfa-datasets = { path = "datasets", features = [ "diabetes", "generate", ] } +linfa-linear = { path = "algorithms/linfa-linear" } +linfa-svm = { path = "algorithms/linfa-svm" } statrs = "0.18" [target.'cfg(not(windows))'.dependencies] diff --git a/algorithms/linfa-residual-sequence/Cargo.toml b/algorithms/linfa-residual-sequence/Cargo.toml deleted file mode 100644 index 2a1641e9f..000000000 --- a/algorithms/linfa-residual-sequence/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "linfa-residual-sequence" -version = "0.8.1" -edition = "2018" -description = "Model composition utilities for the linfa ML framework" -license = "MIT OR Apache-2.0" -repository = "https://github.com/rust-ml/linfa" -keywords = ["machine-learning", "linfa", "ai", "ml", "residual"] -categories = ["algorithms", "mathematics", "science"] - -[dependencies] -linfa = { version = "0.8.1", path = "../.." } -ndarray = { version = "0.16" } -thiserror = "2.0" - -[dev-dependencies] -linfa-linear = { path = "../linfa-linear" } -linfa-svm = { path = "../linfa-svm" } diff --git a/src/composing/mod.rs b/src/composing/mod.rs index a1f2acc37..d7453014e 100644 --- a/src/composing/mod.rs +++ b/src/composing/mod.rs @@ -7,6 +7,7 @@ mod multi_class_model; mod multi_target_model; pub mod platt_scaling; +pub mod residual_sequence; pub use multi_class_model::MultiClassModel; pub use multi_target_model::MultiTargetModel; diff --git a/algorithms/linfa-residual-sequence/src/lib.rs b/src/composing/residual_sequence.rs similarity index 94% rename from algorithms/linfa-residual-sequence/src/lib.rs rename to src/composing/residual_sequence.rs index 622c6a979..704a3e4b8 100644 --- a/algorithms/linfa-residual-sequence/src/lib.rs +++ b/src/composing/residual_sequence.rs @@ -23,7 +23,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa_residual_sequence::{ResidualSequence, StackWith}; +//! use linfa::composing::residual_sequence::{ResidualSequence, StackWith}; //! use ndarray::{array, Array2}; //! //! // y = 2x: perfectly linear, so the second model should see zero residuals. @@ -49,7 +49,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa_residual_sequence::StackWith; +//! use linfa::composing::residual_sequence::StackWith; //! use ndarray::{array, Array2}; //! //! // y = 2x: one linear model is enough to fit this perfectly. @@ -77,7 +77,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa_residual_sequence::{ResidualSequence, StackWith}; +//! use linfa::composing::residual_sequence::{ResidualSequence, StackWith}; //! use linfa_svm::Svm; //! use ndarray::Array; //! @@ -109,8 +109,8 @@ //! let _preds = fitted.predict(&x); //! ``` -use linfa::dataset::{AsTargets, DatasetBase, Records}; -use linfa::traits::{Fit, Predict}; +use crate::dataset::{AsTargets, DatasetBase, Records}; +use crate::traits::{Fit, Predict}; use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, RawDataClone}; use std::ops::{Add, Sub}; @@ -128,7 +128,7 @@ pub enum ResidualSequenceError { Second(E2), // Satisfies the `Fit` trait's `E: From` bound. #[error(transparent)] - Linfa(#[from] linfa::error::Error), + Linfa(#[from] crate::error::Error), } /// Fits two models sequentially on the residuals of the first. @@ -149,7 +149,7 @@ pub struct ResidualSequence { /// use linfa::traits::Fit; /// use linfa::DatasetBase; /// use linfa_linear::LinearRegression; -/// use linfa_residual_sequence::StackWith; +/// use linfa::composing::residual_sequence::StackWith; /// use ndarray::{array, Array2}; /// /// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); @@ -193,8 +193,8 @@ where for<'a> F1::Object: Predict<&'a Arr2, Array1>, F2: Fit, Array1, E2>, T: AsTargets, - E1: std::error::Error + From, - E2: std::error::Error + From, + E1: std::error::Error + From, + E2: std::error::Error + From, { type Object = FittedResidualSequence; @@ -244,8 +244,8 @@ where #[cfg(test)] mod tests { use super::*; - use linfa::error::Error as LinfaError; - use linfa::DatasetBase; + use crate::error::Error as LinfaError; + use crate::DatasetBase; use ndarray::{array, Array1, Array2}; #[derive(thiserror::Error, Debug)] From 9495342d932a39a2bdafc7419f37ad6ac19599ce Mon Sep 17 00:00:00 2001 From: feiyang Date: Sun, 1 Mar 2026 19:01:45 +0000 Subject: [PATCH 04/19] update docs --- src/composing/mod.rs | 3 ++- src/composing/residual_sequence.rs | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/composing/mod.rs b/src/composing/mod.rs index d7453014e..fddee337d 100644 --- a/src/composing/mod.rs +++ b/src/composing/mod.rs @@ -1,9 +1,10 @@ //! Composition models //! -//! This module contains three composition models: +//! This module contains four composition models: //! * `MultiClassModel`: combine multiple binary decision models to a single multi-class model //! * `MultiTargetModel`: combine multiple univariate models to a single multi-target model //! * `Platt`: calibrate a classifier (i.e. SVC) to predicted posterior probabilities +//! * `ResidualSequence`: fit models sequentially on the residuals of the previous one (stagewise additive modeling / boosting) mod multi_class_model; mod multi_target_model; pub mod platt_scaling; diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 704a3e4b8..8604789cb 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -134,7 +134,7 @@ pub enum ResidualSequenceError { /// Fits two models sequentially on the residuals of the first. /// /// `first` is fit on the original dataset. `second` is fit on the residuals -/// `Y - first.predict(X)`. See the [crate-level docs](crate) for details. +/// `Y - first.predict(X)`. See the [module docs](self) for details. #[derive(Debug, Clone)] pub struct ResidualSequence { pub first: F1, @@ -162,6 +162,9 @@ pub struct ResidualSequence { /// .unwrap(); /// ``` pub trait StackWith: Sized { + /// Wrap `self` and `second` into a [`ResidualSequence`] that will fit + /// `second` on the residuals left by `self`. Calls can be chained to add + /// further stages. fn stack_with(self, second: F2) -> ResidualSequence; } From 9eebbaa38551030f936455b131e5e9bd6d812010 Mon Sep 17 00:00:00 2001 From: feiyang Date: Mon, 2 Mar 2026 18:25:23 +0000 Subject: [PATCH 05/19] implement PredictInplace instead --- src/composing/residual_sequence.rs | 185 ++++++++++++++++++++++++----- 1 file changed, 156 insertions(+), 29 deletions(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 8604789cb..7e4c505bc 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -110,9 +110,12 @@ //! ``` use crate::dataset::{AsTargets, DatasetBase, Records}; -use crate::traits::{Fit, Predict}; +use crate::traits::{Fit, Predict, PredictInplace}; +use crate::{Float, ParamGuard}; use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, RawDataClone}; -use std::ops::{Add, Sub}; +#[cfg(feature = "serde")] +use serde_crate::{Deserialize, Serialize}; +use std::ops::AddAssign; type Arr2 = ArrayBase; @@ -128,17 +131,42 @@ pub enum ResidualSequenceError { Second(E2), // Satisfies the `Fit` trait's `E: From` bound. #[error(transparent)] - Linfa(#[from] crate::error::Error), + BaseCrate(#[from] crate::Error), +} + +/// Error returned when checking [`ResidualSequence`] hyperparameters. +/// +/// Wraps the validation error from whichever sub-model's parameter check failed. +#[derive(Debug, thiserror::Error)] +pub enum ResidualParamError { + #[error("first model params: {0}")] + First(E1), + #[error("second model params: {0}")] + Second(E2), } /// Fits two models sequentially on the residuals of the first. /// /// `first` is fit on the original dataset. `second` is fit on the residuals /// `Y - first.predict(X)`. See the [module docs](self) for details. +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] #[derive(Debug, Clone)] pub struct ResidualSequence { - pub first: F1, - pub second: F2, + first: F1, + second: F2, +} + +impl ResidualSequence { + pub fn first(&self) -> &F1 { + &self.first + } + pub fn second(&self) -> &F2 { + &self.second + } } /// Extension trait that lets any model params type be composed into a [`ResidualSequence`]. @@ -177,6 +205,29 @@ impl StackWith for F1 { } } +impl ParamGuard for ResidualSequence +where + F1: ParamGuard, + F2: ParamGuard, +{ + type Checked = Self; + type Error = ResidualParamError; + + /// Validates both sub-model hyperparameters. + /// + /// Returns a reference to `self` if both pass, or the first error encountered. + fn check_ref(&self) -> Result<&Self::Checked, Self::Error> { + self.first.check_ref().map_err(ResidualParamError::First)?; + self.second.check_ref().map_err(ResidualParamError::Second)?; + Ok(self) + } + + fn check(self) -> Result { + self.check_ref()?; + Ok(self) + } +} + /// Two fitted models produced by [`ResidualSequence::fit`]. /// /// Predicts by summing both models' outputs: `first.predict(X) + second.predict(X)`. @@ -186,16 +237,14 @@ pub struct FittedResidualSequence { pub second: R2, } -impl Fit, T, ResidualSequenceError> - for ResidualSequence +impl + RawDataClone, T, E1, E2> + Fit, T, ResidualSequenceError> for ResidualSequence where - D: Data + RawDataClone, - D::Elem: Copy + Sub, Arr2: Records, F1: Fit, T, E1>, - for<'a> F1::Object: Predict<&'a Arr2, Array1>, - F2: Fit, Array1, E2>, - T: AsTargets, + for<'a> F1::Object: Predict<&'a Arr2, Array1>, + F2: Fit, Array1, E2>, + T: AsTargets, E1: std::error::Error + From, E2: std::error::Error + From, { @@ -211,13 +260,7 @@ where .map_err(ResidualSequenceError::First)?; let y_pred = first.predict(dataset.records()); - let residuals = dataset - .targets() - .as_targets() - .iter() - .zip(y_pred.iter()) - .map(|(y, p)| *y - *p) - .collect::>(); + let residuals = &dataset.targets().as_targets() - &y_pred; let residual_dataset = DatasetBase::new(dataset.records().clone(), residuals); let second = self @@ -229,18 +272,19 @@ where } } -impl<'a, R1, R2, D> Predict<&'a Arr2, Array1> for FittedResidualSequence +impl> PredictInplace, Array1> + for FittedResidualSequence where - D: Data, - D::Elem: Copy + Add, - Arr2: Records, - for<'b> R1: Predict<&'b Arr2, Array1>, - for<'b> R2: Predict<&'b Arr2, Array1>, + for<'a> R1: Predict<&'a Arr2, Array1>, + for<'a> R2: Predict<&'a Arr2, Array1>, { - fn predict(&self, x: &'a Arr2) -> Array1 { - let pred1 = self.first.predict(x); - let pred2 = self.second.predict(x); - pred1 + pred2 + fn predict_inplace<'a>(&'a self, x: &'a Arr2, y: &mut Array1) { + y.assign(&self.first.predict(x)); + y.add_assign(&self.second.predict(x)); + } + + fn default_target(&self, x: &Arr2) -> Array1 { + Array1::zeros(x.nrows()) } } @@ -255,6 +299,43 @@ mod tests { #[error("dummy error")] struct DummyError(#[from] LinfaError); + // --- ParamGuard helpers --- + + // Error used by test ParamGuard stubs. + #[derive(thiserror::Error, Debug, PartialEq)] + #[error("invalid params: {0}")] + struct ParamErr(String); + + // Always-valid params stub. + #[derive(Debug)] + struct OkParams; + + impl ParamGuard for OkParams { + type Checked = Self; + type Error = ParamErr; + fn check_ref(&self) -> Result<&Self, ParamErr> { + Ok(self) + } + fn check(self) -> Result { + Ok(self) + } + } + + // Always-invalid params stub. + #[derive(Debug)] + struct BadParams(String); + + impl ParamGuard for BadParams { + type Checked = Self; + type Error = ParamErr; + fn check_ref(&self) -> Result<&Self, ParamErr> { + Err(ParamErr(self.0.clone())) + } + fn check(self) -> Result { + Err(ParamErr(self.0)) + } + } + // Params that fits by recording the mean of the targets. struct MeanParams; @@ -337,4 +418,50 @@ mod tests { let predictions = fitted.predict(&array![[0.0_f64], [1.0]]); assert_eq!(predictions, array![4.0, 4.0]); } + + // --- ParamGuard tests --- + + #[test] + fn param_guard_check_ref_succeeds_when_both_params_valid() { + let seq = OkParams.stack_with(OkParams); + assert!(seq.check_ref().is_ok()); + } + + #[test] + fn param_guard_check_ref_fails_on_invalid_first() { + let seq = BadParams("bad first".into()).stack_with(OkParams); + let err = seq.check_ref().unwrap_err(); + assert!(matches!(err, ResidualParamError::First(ParamErr(_)))); + } + + #[test] + fn param_guard_check_ref_fails_on_invalid_second() { + let seq = OkParams.stack_with(BadParams("bad second".into())); + let err = seq.check_ref().unwrap_err(); + assert!(matches!(err, ResidualParamError::Second(ParamErr(_)))); + } + + #[test] + fn param_guard_check_succeeds_and_returns_self() { + let seq = OkParams.stack_with(OkParams); + assert!(seq.check().is_ok()); + } + + #[test] + fn param_guard_check_fails_on_invalid_first() { + let seq = BadParams("bad".into()).stack_with(OkParams); + assert!(matches!( + seq.check().unwrap_err(), + ResidualParamError::First(_) + )); + } + + #[test] + fn param_guard_check_fails_on_invalid_second() { + let seq = OkParams.stack_with(BadParams("bad".into())); + assert!(matches!( + seq.check().unwrap_err(), + ResidualParamError::Second(_) + )); + } } From eca068ea1d1d3cebdb8a53b727acabb72bcc7f94 Mon Sep 17 00:00:00 2001 From: feiyang Date: Mon, 2 Mar 2026 18:25:46 +0000 Subject: [PATCH 06/19] remove unused param error --- src/composing/residual_sequence.rs | 119 +---------------------------- 1 file changed, 1 insertion(+), 118 deletions(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 7e4c505bc..5e694676b 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -111,7 +111,7 @@ use crate::dataset::{AsTargets, DatasetBase, Records}; use crate::traits::{Fit, Predict, PredictInplace}; -use crate::{Float, ParamGuard}; +use crate::Float; use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, RawDataClone}; #[cfg(feature = "serde")] use serde_crate::{Deserialize, Serialize}; @@ -134,17 +134,6 @@ pub enum ResidualSequenceError { BaseCrate(#[from] crate::Error), } -/// Error returned when checking [`ResidualSequence`] hyperparameters. -/// -/// Wraps the validation error from whichever sub-model's parameter check failed. -#[derive(Debug, thiserror::Error)] -pub enum ResidualParamError { - #[error("first model params: {0}")] - First(E1), - #[error("second model params: {0}")] - Second(E2), -} - /// Fits two models sequentially on the residuals of the first. /// /// `first` is fit on the original dataset. `second` is fit on the residuals @@ -205,29 +194,6 @@ impl StackWith for F1 { } } -impl ParamGuard for ResidualSequence -where - F1: ParamGuard, - F2: ParamGuard, -{ - type Checked = Self; - type Error = ResidualParamError; - - /// Validates both sub-model hyperparameters. - /// - /// Returns a reference to `self` if both pass, or the first error encountered. - fn check_ref(&self) -> Result<&Self::Checked, Self::Error> { - self.first.check_ref().map_err(ResidualParamError::First)?; - self.second.check_ref().map_err(ResidualParamError::Second)?; - Ok(self) - } - - fn check(self) -> Result { - self.check_ref()?; - Ok(self) - } -} - /// Two fitted models produced by [`ResidualSequence::fit`]. /// /// Predicts by summing both models' outputs: `first.predict(X) + second.predict(X)`. @@ -299,43 +265,6 @@ mod tests { #[error("dummy error")] struct DummyError(#[from] LinfaError); - // --- ParamGuard helpers --- - - // Error used by test ParamGuard stubs. - #[derive(thiserror::Error, Debug, PartialEq)] - #[error("invalid params: {0}")] - struct ParamErr(String); - - // Always-valid params stub. - #[derive(Debug)] - struct OkParams; - - impl ParamGuard for OkParams { - type Checked = Self; - type Error = ParamErr; - fn check_ref(&self) -> Result<&Self, ParamErr> { - Ok(self) - } - fn check(self) -> Result { - Ok(self) - } - } - - // Always-invalid params stub. - #[derive(Debug)] - struct BadParams(String); - - impl ParamGuard for BadParams { - type Checked = Self; - type Error = ParamErr; - fn check_ref(&self) -> Result<&Self, ParamErr> { - Err(ParamErr(self.0.clone())) - } - fn check(self) -> Result { - Err(ParamErr(self.0)) - } - } - // Params that fits by recording the mean of the targets. struct MeanParams; @@ -418,50 +347,4 @@ mod tests { let predictions = fitted.predict(&array![[0.0_f64], [1.0]]); assert_eq!(predictions, array![4.0, 4.0]); } - - // --- ParamGuard tests --- - - #[test] - fn param_guard_check_ref_succeeds_when_both_params_valid() { - let seq = OkParams.stack_with(OkParams); - assert!(seq.check_ref().is_ok()); - } - - #[test] - fn param_guard_check_ref_fails_on_invalid_first() { - let seq = BadParams("bad first".into()).stack_with(OkParams); - let err = seq.check_ref().unwrap_err(); - assert!(matches!(err, ResidualParamError::First(ParamErr(_)))); - } - - #[test] - fn param_guard_check_ref_fails_on_invalid_second() { - let seq = OkParams.stack_with(BadParams("bad second".into())); - let err = seq.check_ref().unwrap_err(); - assert!(matches!(err, ResidualParamError::Second(ParamErr(_)))); - } - - #[test] - fn param_guard_check_succeeds_and_returns_self() { - let seq = OkParams.stack_with(OkParams); - assert!(seq.check().is_ok()); - } - - #[test] - fn param_guard_check_fails_on_invalid_first() { - let seq = BadParams("bad".into()).stack_with(OkParams); - assert!(matches!( - seq.check().unwrap_err(), - ResidualParamError::First(_) - )); - } - - #[test] - fn param_guard_check_fails_on_invalid_second() { - let seq = OkParams.stack_with(BadParams("bad".into())); - assert!(matches!( - seq.check().unwrap_err(), - ResidualParamError::Second(_) - )); - } } From 9e4c79dfec711cdcf034f304aed43345a2394869 Mon Sep 17 00:00:00 2001 From: feiyang Date: Mon, 2 Mar 2026 20:54:42 +0000 Subject: [PATCH 07/19] use one struct to implement stacking --- src/composing/residual_sequence.rs | 33 +++++++++--------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 5e694676b..efb1b23c3 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -63,8 +63,8 @@ //! .unwrap(); //! //! // The second model trained on zero residuals — nothing left to correct. -//! assert!(fitted.second.params().iter().all(|&c: &f64| c.abs() < 1e-10)); -//! assert!(fitted.second.intercept().abs() < 1e-10); +//! assert!(fitted.second().params().iter().all(|&c: &f64| c.abs() < 1e-10)); +//! assert!(fitted.second().intercept().abs() < 1e-10); //! ``` //! //! ## Chained SVMs and linear regression @@ -134,10 +134,9 @@ pub enum ResidualSequenceError { BaseCrate(#[from] crate::Error), } -/// Fits two models sequentially on the residuals of the first. +/// A pair of [`Fit`] params that fits sequentially on residuals, returning a pair of fitted models. /// -/// `first` is fit on the original dataset. `second` is fit on the residuals -/// `Y - first.predict(X)`. See the [module docs](self) for details. +/// The fitted pair implements [`PredictInplace`] by summing both outputs. #[cfg_attr( feature = "serde", derive(Serialize, Deserialize), @@ -194,15 +193,6 @@ impl StackWith for F1 { } } -/// Two fitted models produced by [`ResidualSequence::fit`]. -/// -/// Predicts by summing both models' outputs: `first.predict(X) + second.predict(X)`. -#[derive(Debug, Clone)] -pub struct FittedResidualSequence { - pub first: R1, - pub second: R2, -} - impl + RawDataClone, T, E1, E2> Fit, T, ResidualSequenceError> for ResidualSequence where @@ -214,7 +204,7 @@ where E1: std::error::Error + From, E2: std::error::Error + From, { - type Object = FittedResidualSequence; + type Object = ResidualSequence; fn fit( &self, @@ -234,12 +224,12 @@ where .fit(&residual_dataset) .map_err(ResidualSequenceError::Second)?; - Ok(FittedResidualSequence { first, second }) + Ok(ResidualSequence { first, second }) } } impl> PredictInplace, Array1> - for FittedResidualSequence + for ResidualSequence where for<'a> R1: Predict<&'a Arr2, Array1>, for<'a> R2: Predict<&'a Arr2, Array1>, @@ -292,15 +282,12 @@ mod tests { fn second_is_fit_on_residuals() { // targets = [1, 3]. first sees mean=2, predicts 2 for all. // residuals = [1-2, 3-2] = [-1, 1]. second sees mean=0. - let model = ResidualSequence { - first: MeanParams, - second: MeanParams, - }; + let model = MeanParams.stack_with(MeanParams); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); - assert_eq!(fitted.first.0, 2.0); // mean of [1, 3] - assert_eq!(fitted.second.0, 0.0); // mean of residuals [-1, 1] + assert_eq!(fitted.first().0, 2.0); // mean of [1, 3] + assert_eq!(fitted.second().0, 0.0); // mean of residuals [-1, 1] } #[test] From 24466cf7f97140aa5b2c0bd8f619c47a38fb7ed8 Mon Sep 17 00:00:00 2001 From: feiyang Date: Mon, 2 Mar 2026 21:00:52 +0000 Subject: [PATCH 08/19] add deep chain test --- src/composing/residual_sequence.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index efb1b23c3..0fba67b2f 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -334,4 +334,16 @@ mod tests { let predictions = fitted.predict(&array![[0.0_f64], [1.0]]); assert_eq!(predictions, array![4.0, 4.0]); } + + #[test] + fn deep_chain_accessors() { + let model = MeanParams + .stack_with(MeanParams) + .stack_with(MeanParams) + .stack_with(MeanParams); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + assert_eq!(fitted.first().first().first().0, 2.0); // params trained on original targets + } } From cd371959ddfa221c56f781d4f29dbc596131165c Mon Sep 17 00:00:00 2001 From: feiyang Date: Wed, 4 Mar 2026 20:55:35 +0000 Subject: [PATCH 09/19] Rename to ResidualChain Implement Shrinkage implement paramguard for shrinkage --- src/composing/mod.rs | 3 +- src/composing/residual_sequence.rs | 353 ++++++++++++++++++++++------- 2 files changed, 271 insertions(+), 85 deletions(-) diff --git a/src/composing/mod.rs b/src/composing/mod.rs index fddee337d..d7aea6965 100644 --- a/src/composing/mod.rs +++ b/src/composing/mod.rs @@ -4,7 +4,8 @@ //! * `MultiClassModel`: combine multiple binary decision models to a single multi-class model //! * `MultiTargetModel`: combine multiple univariate models to a single multi-target model //! * `Platt`: calibrate a classifier (i.e. SVC) to predicted posterior probabilities -//! * `ResidualSequence`: fit models sequentially on the residuals of the previous one (stagewise additive modeling / boosting) +//! * `ResidualChain`: fit models sequentially on the residuals of the previous one +//! (forward stagewise additive modeling / L2Boosting); see [`residual_sequence::Stagewise`] mod multi_class_model; mod multi_target_model; pub mod platt_scaling; diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 0fba67b2f..d5fb5bc96 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -1,38 +1,41 @@ -//! Residual sequence model composition for the linfa ML framework. +//! L2Boosting (forward stagewise additive modelling with squared-error loss) +//! for the linfa ML framework. //! -//! This crate provides [`ResidualSequence`], which fits models sequentially on -//! the residuals of the previous one. Chain as many as you like via [`StackWith`]: +//! This module provides [`ResidualChain`], which fits models sequentially on +//! residuals. Chain as many stages as you like via [`Stagewise`]: //! -//! 1. Fit `first` on `(X, Y)` -//! 2. Compute residuals: `R = Y - first.predict(X)` -//! 3. Fit `second` on `(X, R)` -//! 4. Repeat for any further models stacked on top +//! 1. Fit `base` on `(X, Y)` +//! 2. Compute residuals: `R = Y - base.predict(X)` +//! 3. Fit `corrector` on `(X, R)` +//! 4. Repeat for any further correctors stacked on top //! -//! When predicting, all models' outputs are summed. +//! When predicting, all stages' outputs are summed. //! -//! This is the foundation of boosting / residual stacking. +//! This is the special case of FSAM (Friedman 2001) where the loss is squared +//! error. Shrinkage (learning rate ν ∈ (0, 1]) can be set per corrector via +//! [`Shrunk::with_shrinkage`]; the default is ν = 1 (no scaling). //! //! # Examples //! //! ## Linear + linear //! -//! Two `linfa_linear::LinearRegression` models stacked: the second fits the -//! residuals left by the first. +//! Two `linfa_linear::LinearRegression` models stacked: the corrector fits +//! the residuals left by the base. //! //! ``` //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa::composing::residual_sequence::{ResidualSequence, StackWith}; +//! use linfa::composing::residual_sequence::{ResidualChain, Stagewise}; //! use ndarray::{array, Array2}; //! -//! // y = 2x: perfectly linear, so the second model should see zero residuals. +//! // y = 2x: perfectly linear, so the corrector should see zero residuals. //! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); //! let y = array![0., 2., 4., 6., 8.]; //! let dataset = DatasetBase::new(x.clone(), y); //! //! let fitted = LinearRegression::default() -//! .stack_with(LinearRegression::default()) +//! .stack_with(LinearRegression::default().shrink_by(1.0)) //! .fit(&dataset) //! .unwrap(); //! @@ -49,7 +52,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa::composing::residual_sequence::StackWith; +//! use linfa::composing::residual_sequence::Stagewise; //! use ndarray::{array, Array2}; //! //! // y = 2x: one linear model is enough to fit this perfectly. @@ -58,13 +61,13 @@ //! let dataset = DatasetBase::new(x.clone(), y); //! //! let fitted = LinearRegression::default() -//! .stack_with(LinearRegression::default()) +//! .stack_with(LinearRegression::default().shrink_by(1.0)) //! .fit(&dataset) //! .unwrap(); //! -//! // The second model trained on zero residuals — nothing left to correct. -//! assert!(fitted.second().params().iter().all(|&c: &f64| c.abs() < 1e-10)); -//! assert!(fitted.second().intercept().abs() < 1e-10); +//! // The corrector trained on zero residuals — nothing left to correct. +//! assert!(fitted.corrector().model().params().iter().all(|&c: &f64| c.abs() < 1e-10)); +//! assert!(fitted.corrector().model().intercept().abs() < 1e-10); //! ``` //! //! ## Chained SVMs and linear regression @@ -77,7 +80,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa::composing::residual_sequence::{ResidualSequence, StackWith}; +//! use linfa::composing::residual_sequence::{ResidualChain, Stagewise}; //! use linfa_svm::Svm; //! use ndarray::Array; //! @@ -95,13 +98,15 @@ //! .stack_with( //! Svm::::params() //! .c_svr(10., Some(0.1)) -//! .gaussian_kernel(1.), +//! .gaussian_kernel(1.) +//! .shrink_by(1.0), //! ) -//! .stack_with(LinearRegression::default()) +//! .stack_with(LinearRegression::default().shrink_by(1.0)) //! .stack_with( //! Svm::::params() //! .c_svr(10., Some(0.1)) -//! .gaussian_kernel(3.), +//! .gaussian_kernel(3.) +//! .shrink_by(1.0), //! ) //! .fit(&dataset) //! .unwrap(); @@ -110,54 +115,68 @@ //! ``` use crate::dataset::{AsTargets, DatasetBase, Records}; +use crate::param_guard::ParamGuard; use crate::traits::{Fit, Predict, PredictInplace}; use crate::Float; use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, RawDataClone}; #[cfg(feature = "serde")] use serde_crate::{Deserialize, Serialize}; -use std::ops::AddAssign; +use std::ops::{AddAssign, Mul}; type Arr2 = ArrayBase; -/// Error returned by [`ResidualSequence::fit`]. +/// Error returned by [`ResidualChain::fit`]. /// /// Wraps the error from whichever of the two model fits failed, keeping them /// distinguishable without requiring both models to share the same error type. #[derive(Debug, thiserror::Error)] -pub enum ResidualSequenceError { - #[error("first model: {0}")] - First(E1), - #[error("second model: {0}")] - Second(E2), +pub enum ResidualChainError { + #[error("base model: {0}")] + Base(E1), + #[error("corrector: {0}")] + Corrector(E2), // Satisfies the `Fit` trait's `E: From` bound. #[error(transparent)] BaseCrate(#[from] crate::Error), } -/// A pair of [`Fit`] params that fits sequentially on residuals, returning a pair of fitted models. +/// A pair of [`Fit`] params that fits sequentially on residuals. /// -/// The fitted pair implements [`PredictInplace`] by summing both outputs. +/// `base` is fit on the original targets; `corrector` (a [`Shrunk`] model) is +/// fit on the residuals left by `base` and scaled by its shrinkage factor ν. +/// Prediction sums `base` and the scaled corrector output. #[cfg_attr( feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate") )] #[derive(Debug, Clone)] -pub struct ResidualSequence { - first: F1, - second: F2, +pub struct ResidualChain { + base: B, + corrector: Shrunk, } -impl ResidualSequence { - pub fn first(&self) -> &F1 { - &self.first +impl ResidualChain { + pub fn base(&self) -> &B { + &self.base } - pub fn second(&self) -> &F2 { - &self.second + pub fn corrector(&self) -> &Shrunk { + &self.corrector } } -/// Extension trait that lets any model params type be composed into a [`ResidualSequence`]. +/// Extension trait that adds residual-chain composition methods to any type. +/// +/// Blanket-implemented for all `Sized` types, so any model params type gains +/// these methods automatically: +/// +/// - [`stack_with`](Stagewise::stack_with): compose `self` (as the base) with +/// a [`Shrunk`] corrector that will be trained on the residuals left by +/// `self`. Returns a [`ResidualChainParams`] whose `.fit()` runs both stages. +/// Calls can be chained to build arbitrarily deep sequences. +/// - [`shrink_by`](Stagewise::shrink_by): wrap `self` in a [`Shrunk`] with the +/// given learning rate ν, making it ready to pass as the `corrector` argument +/// to [`stack_with`]. /// /// # Example /// @@ -165,7 +184,7 @@ impl ResidualSequence { /// use linfa::traits::Fit; /// use linfa::DatasetBase; /// use linfa_linear::LinearRegression; -/// use linfa::composing::residual_sequence::StackWith; +/// use linfa::composing::residual_sequence::Stagewise; /// use ndarray::{array, Array2}; /// /// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); @@ -173,28 +192,37 @@ impl ResidualSequence { /// let dataset = DatasetBase::new(x.clone(), y); /// /// let fitted = LinearRegression::default() -/// .stack_with(LinearRegression::default()) +/// .stack_with(LinearRegression::default().shrink_by(1.0)) /// .fit(&dataset) /// .unwrap(); /// ``` -pub trait StackWith: Sized { - /// Wrap `self` and `second` into a [`ResidualSequence`] that will fit - /// `second` on the residuals left by `self`. Calls can be chained to add - /// further stages. - fn stack_with(self, second: F2) -> ResidualSequence; +pub trait Stagewise: Sized { + /// Compose `self` (as the base model) with `corrector`, which will be + /// trained on the residuals left by `self`. Further stages can be appended + /// by calling `.stack_with(...)` on the returned [`ResidualChainParams`]. + fn stack_with(self, corrector: Shrunk) -> ResidualChainParams; + /// Wrap `self` in a [`Shrunk`] with learning rate `shrinkage` ∈ (0, 1], + /// making it ready to pass as the `corrector` argument to [`stack_with`]. + fn shrink_by(self, shrinkage: F) -> Shrunk; } -impl StackWith for F1 { - fn stack_with(self, second: F2) -> ResidualSequence { - ResidualSequence { - first: self, - second, +impl Stagewise for B { + fn stack_with(self, corrector: Shrunk) -> ResidualChainParams { + ResidualChainParams(ResidualChain { + base: self, + corrector, + }) + } + fn shrink_by(self, shrinkage: F) -> Shrunk { + Shrunk { + model: self, + shrinkage, } } } impl + RawDataClone, T, E1, E2> - Fit, T, ResidualSequenceError> for ResidualSequence + Fit, T, ResidualChainError> for ResidualChain where Arr2: Records, F1: Fit, T, E1>, @@ -204,39 +232,49 @@ where E1: std::error::Error + From, E2: std::error::Error + From, { - type Object = ResidualSequence; + type Object = ResidualChain; fn fit( &self, dataset: &DatasetBase, T>, - ) -> Result> { - let first = self - .first - .fit(dataset) - .map_err(ResidualSequenceError::First)?; + ) -> Result> { + let base = self.base.fit(dataset).map_err(ResidualChainError::Base)?; - let y_pred = first.predict(dataset.records()); + let y_pred = base.predict(dataset.records()); let residuals = &dataset.targets().as_targets() - &y_pred; let residual_dataset = DatasetBase::new(dataset.records().clone(), residuals); - let second = self - .second + let corrector_model = self + .corrector + .model .fit(&residual_dataset) - .map_err(ResidualSequenceError::Second)?; + .map_err(ResidualChainError::Corrector)?; - Ok(ResidualSequence { first, second }) + Ok(ResidualChain { + base, + corrector: Shrunk { + model: corrector_model, + shrinkage: self.corrector.shrinkage, + }, + }) } } impl> PredictInplace, Array1> - for ResidualSequence + for ResidualChain where for<'a> R1: Predict<&'a Arr2, Array1>, for<'a> R2: Predict<&'a Arr2, Array1>, { fn predict_inplace<'a>(&'a self, x: &'a Arr2, y: &mut Array1) { - y.assign(&self.first.predict(x)); - y.add_assign(&self.second.predict(x)); + y.assign(&self.base.predict(x)); + y.add_assign( + &self + .corrector + .model + .predict(x) + .mul(self.corrector.shrinkage), + ); } fn default_target(&self, x: &Arr2) -> Array1 { @@ -244,6 +282,97 @@ where } } +/// A model (params or fitted) paired with a shrinkage factor ν ∈ (0, 1]. +/// +/// Used in two roles: +/// - **Before fitting**: `Shrunk` wraps corrector params `C`; created by +/// [`Stagewise::shrink_by`] and stored in [`ResidualChain`] / [`ResidualChainParams`]. +/// - **After fitting**: `Shrunk` wraps the fitted corrector model; +/// prediction scales the corrector's output by ν before summing with the base. +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone)] +pub struct Shrunk { + model: M, + shrinkage: F, +} + +impl Shrunk { + pub fn model(&self) -> &M { + &self.model + } + pub fn shrinkage(&self) -> F { + self.shrinkage + } + /// Set the shrinkage factor. Validation happens when the containing + /// [`ResidualChainParams`] is checked via [`ParamGuard`]. + pub fn with_shrinkage(mut self, shrinkage: F) -> Self { + self.shrinkage = shrinkage; + self + } +} + +/// Unvalidated [`ResidualChain`] parameters returned by [`Stagewise::stack_with`]. +/// +/// Call `.fit()` to validate and fit in one step — the [`ParamGuard`] blanket +/// impl runs `check_ref` first, which verifies that the outermost corrector's +/// shrinkage factor is in (0, 1]. Inner chains validate lazily when their own +/// `.fit()` is called. You can also call `.check()` / `.check_unwrap()` to +/// validate explicitly. +/// +/// To set an explicit shrinkage factor on the corrector use +/// [`Shrunk::with_shrinkage`]: +/// +/// ``` +/// use linfa::traits::{Fit, Predict}; +/// use linfa::DatasetBase; +/// use linfa_linear::LinearRegression; +/// use linfa::composing::residual_sequence::{Shrunk, Stagewise}; +/// use ndarray::{array, Array2}; +/// +/// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +/// let y = array![0., 2., 4., 6., 8.]; +/// let dataset = DatasetBase::new(x.clone(), y); +/// +/// // The corrector's contribution is scaled by 0.1. +/// let fitted = LinearRegression::default() +/// .stack_with(LinearRegression::default().shrink_by(0.1)) +/// .fit(&dataset) +/// .unwrap(); +/// ``` +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone)] +pub struct ResidualChainParams(ResidualChain); + +impl ParamGuard for ResidualChainParams { + type Checked = ResidualChain; + type Error = crate::Error; + + fn check_ref(&self) -> Result<&ResidualChain, crate::Error> { + let v = self.0.corrector.shrinkage; + let err = crate::Error::Parameters(format!("shrinkage must be in (0, 1], got {v}")); + if v.to_f32() + .map_or(Err(err.clone()), |num| Ok(num > 0.0 && num <= 1.0))? + { + Ok(&self.0) + } else { + Err(err) + } + } + + fn check(self) -> Result, crate::Error> { + self.check_ref()?; + Ok(self.0) + } +} + #[cfg(test)] mod tests { use super::*; @@ -279,21 +408,21 @@ mod tests { } #[test] - fn second_is_fit_on_residuals() { - // targets = [1, 3]. first sees mean=2, predicts 2 for all. - // residuals = [1-2, 3-2] = [-1, 1]. second sees mean=0. - let model = MeanParams.stack_with(MeanParams); + fn corrector_is_fit_on_residuals() { + // targets = [1, 3]. base sees mean=2, predicts 2 for all. + // residuals = [1-2, 3-2] = [-1, 1]. corrector sees mean=0. + let model = MeanParams.stack_with(MeanParams.shrink_by(1.0)); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); - assert_eq!(fitted.first().0, 2.0); // mean of [1, 3] - assert_eq!(fitted.second().0, 0.0); // mean of residuals [-1, 1] + assert_eq!(fitted.base().0, 2.0); // mean of [1, 3] + assert_eq!(fitted.corrector().model().0, 0.0); // mean of residuals [-1, 1] } #[test] fn predict_sums_both_models() { - // first predicts 2.0, second predicts 0.0 → sum = 2.0 - let model = MeanParams.stack_with(MeanParams); + // base predicts 2.0, corrector predicts 0.0 → sum = 2.0 + let model = MeanParams.stack_with(MeanParams.shrink_by(1.0)); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); @@ -304,7 +433,7 @@ mod tests { #[test] fn predict_recovers_targets_when_residuals_fit_perfectly() { - // If second perfectly fits the residuals, the combined prediction = original targets. + // If the corrector perfectly fits the residuals, the combined prediction = original targets. struct FixedParams(f64); struct FixedModel(f64); @@ -324,10 +453,10 @@ mod tests { } } - // first predicts 3.0, second predicts 1.0 → sum = 4.0 + // base predicts 3.0, corrector predicts 1.0 → sum = 4.0 let model = FixedParams(3.0) - .stack_with(FixedParams(1.0)) - .stack_with(FixedParams(0.0)); + .stack_with(FixedParams(1.0).shrink_by(1.0)) + .stack_with(FixedParams(0.0).shrink_by(1.0)); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]); let fitted = model.fit(&dataset).unwrap(); @@ -338,12 +467,68 @@ mod tests { #[test] fn deep_chain_accessors() { let model = MeanParams - .stack_with(MeanParams) - .stack_with(MeanParams) - .stack_with(MeanParams); + .stack_with(MeanParams.shrink_by(1.0)) + .stack_with(MeanParams.shrink_by(1.0)) + .stack_with(MeanParams.shrink_by(1.0)); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + assert_eq!(fitted.base().base().base().0, 2.0); // params trained on original targets + } + + #[test] + fn shrinkage_scales_corrector_prediction() { + // base predicts mean=2.0, corrector predicts mean=0.0 (residuals [-1,1]). + // With shrinkage=0.5, corrector contributes 0.5*0.0 = 0.0 → total = 2.0. + let model = MeanParams.stack_with(MeanParams.shrink_by(0.5)); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); - assert_eq!(fitted.first().first().first().0, 2.0); // params trained on original targets + let preds = fitted.predict(&array![[0.0_f64], [1.0]]); + assert_eq!(preds, array![2.0, 2.0]); + assert_eq!(fitted.corrector().shrinkage(), 0.5); + } + + #[test] + fn shrinkage_corrector_sees_scaled_residuals() { + // base predicts 3.0 always. targets = [4.0, 4.0]. + // residuals = [1.0, 1.0]. corrector (mean) sees mean=1.0. + // With shrinkage=0.5: prediction = 3.0 + 0.5*1.0 = 3.5. + struct FixedParams(f64); + struct FixedModel(f64); + + impl Fit, Array1, DummyError> for FixedParams { + type Object = FixedModel; + fn fit( + &self, + _dataset: &DatasetBase, Array1>, + ) -> Result { + Ok(FixedModel(self.0)) + } + } + + impl<'a> Predict<&'a Array2, Array1> for FixedModel { + fn predict(&self, x: &'a Array2) -> Array1 { + Array1::from_elem(x.nrows(), self.0) + } + } + + let model = FixedParams(3.0).stack_with(MeanParams.shrink_by(0.5)); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]); + let fitted = model.fit(&dataset).unwrap(); + + let preds = fitted.predict(&array![[0.0_f64], [1.0]]); + // corrector saw residuals [1.0, 1.0], mean=1.0, shrunk by 0.5 → 0.5 + assert!((preds[0] - 3.5_f64).abs() < 1e-10); + } + + #[test] + fn shrinkage_invalid_value_returns_error() { + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let model = MeanParams.stack_with(MeanParams.shrink_by(0.0)); + assert!(model.fit(&dataset).is_err()); + + let model = MeanParams.stack_with(MeanParams.shrink_by(1.5)); + assert!(model.fit(&dataset).is_err()); } } From a7bd119264af32d68b32eba6059d467cbb24878e Mon Sep 17 00:00:00 2001 From: feiyang Date: Wed, 4 Mar 2026 21:07:35 +0000 Subject: [PATCH 10/19] satisfy zola --- src/composing/residual_sequence.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index d5fb5bc96..9f5dd001d 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -176,7 +176,7 @@ impl ResidualChain { /// Calls can be chained to build arbitrarily deep sequences. /// - [`shrink_by`](Stagewise::shrink_by): wrap `self` in a [`Shrunk`] with the /// given learning rate ν, making it ready to pass as the `corrector` argument -/// to [`stack_with`]. +/// to [`Stagewise::stack_with`]. /// /// # Example /// From fd0aea0ae6bf87b6eef7f17ec5ee074e243aaf97 Mon Sep 17 00:00:00 2001 From: feiyang Date: Wed, 4 Mar 2026 21:14:07 +0000 Subject: [PATCH 11/19] can only shrink by if target has the same float type --- src/composing/residual_sequence.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 9f5dd001d..3df70989f 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -118,7 +118,7 @@ use crate::dataset::{AsTargets, DatasetBase, Records}; use crate::param_guard::ParamGuard; use crate::traits::{Fit, Predict, PredictInplace}; use crate::Float; -use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, RawDataClone}; +use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, RawDataClone}; #[cfg(feature = "serde")] use serde_crate::{Deserialize, Serialize}; use std::ops::{AddAssign, Mul}; @@ -203,7 +203,13 @@ pub trait Stagewise: Sized { fn stack_with(self, corrector: Shrunk) -> ResidualChainParams; /// Wrap `self` in a [`Shrunk`] with learning rate `shrinkage` ∈ (0, 1], /// making it ready to pass as the `corrector` argument to [`stack_with`]. - fn shrink_by(self, shrinkage: F) -> Shrunk; + /// + /// The bound `Self: Fit, Array1, E>` ensures at compile time + /// that the model's element type matches the shrinkage type `F`. + fn shrink_by(self, shrinkage: F) -> Shrunk + where + Self: Fit, Array1, E>, + E: std::error::Error + From; } impl Stagewise for B { @@ -213,7 +219,11 @@ impl Stagewise for B { corrector, }) } - fn shrink_by(self, shrinkage: F) -> Shrunk { + fn shrink_by(self, shrinkage: F) -> Shrunk + where + Self: Fit, Array1, E>, + E: std::error::Error + From, + { Shrunk { model: self, shrinkage, From c8adb7eced110fb189266e6b3b39d8101875ace2 Mon Sep 17 00:00:00 2001 From: feiyang Date: Wed, 4 Mar 2026 21:20:23 +0000 Subject: [PATCH 12/19] work with predict inplace only --- src/composing/residual_sequence.rs | 33 +++++++++++++++++++----------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 3df70989f..396e39324 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -273,11 +273,11 @@ where impl> PredictInplace, Array1> for ResidualChain where - for<'a> R1: Predict<&'a Arr2, Array1>, - for<'a> R2: Predict<&'a Arr2, Array1>, + R1: PredictInplace, Array1>, + R2: PredictInplace, Array1>, { fn predict_inplace<'a>(&'a self, x: &'a Arr2, y: &mut Array1) { - y.assign(&self.base.predict(x)); + self.base.predict_inplace(x, y); y.add_assign( &self .corrector @@ -411,9 +411,12 @@ mod tests { } } - impl<'a> Predict<&'a Array2, Array1> for MeanModel { - fn predict(&self, x: &'a Array2) -> Array1 { - Array1::from_elem(x.nrows(), self.0) + impl PredictInplace, Array1> for MeanModel { + fn predict_inplace(&self, x: &Array2, y: &mut Array1) { + y.assign(&Array1::from_elem(x.nrows(), self.0)); + } + fn default_target(&self, x: &Array2) -> Array1 { + Array1::zeros(x.nrows()) } } @@ -457,9 +460,12 @@ mod tests { } } - impl<'a> Predict<&'a Array2, Array1> for FixedModel { - fn predict(&self, x: &'a Array2) -> Array1 { - Array1::from_elem(x.nrows(), self.0) + impl PredictInplace, Array1> for FixedModel { + fn predict_inplace(&self, x: &Array2, y: &mut Array1) { + y.assign(&Array1::from_elem(x.nrows(), self.0)); + } + fn default_target(&self, x: &Array2) -> Array1 { + Array1::zeros(x.nrows()) } } @@ -517,9 +523,12 @@ mod tests { } } - impl<'a> Predict<&'a Array2, Array1> for FixedModel { - fn predict(&self, x: &'a Array2) -> Array1 { - Array1::from_elem(x.nrows(), self.0) + impl PredictInplace, Array1> for FixedModel { + fn predict_inplace(&self, x: &Array2, y: &mut Array1) { + y.assign(&Array1::from_elem(x.nrows(), self.0)); + } + fn default_target(&self, x: &Array2) -> Array1 { + Array1::zeros(x.nrows()) } } From 4e4b4941b0574f665ffa60399cab6f5ab9b98b04 Mon Sep 17 00:00:00 2001 From: feiyang Date: Wed, 4 Mar 2026 21:31:27 +0000 Subject: [PATCH 13/19] zola fix --- src/composing/residual_sequence.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 396e39324..bf1e2e8ee 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -202,7 +202,7 @@ pub trait Stagewise: Sized { /// by calling `.stack_with(...)` on the returned [`ResidualChainParams`]. fn stack_with(self, corrector: Shrunk) -> ResidualChainParams; /// Wrap `self` in a [`Shrunk`] with learning rate `shrinkage` ∈ (0, 1], - /// making it ready to pass as the `corrector` argument to [`stack_with`]. + /// making it ready to pass as the `corrector` argument to [`Stagewise::stack_with`]. /// /// The bound `Self: Fit, Array1, E>` ensures at compile time /// that the model's element type matches the shrinkage type `F`. From 50ed24d11a68d3bc7fff76a200d1dbc0f1df77aa Mon Sep 17 00:00:00 2001 From: feiyang Date: Wed, 4 Mar 2026 21:38:30 +0000 Subject: [PATCH 14/19] add link in docs --- src/composing/residual_sequence.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index bf1e2e8ee..7ca272187 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -15,6 +15,11 @@ //! error. Shrinkage (learning rate ν ∈ (0, 1]) can be set per corrector via //! [`Shrunk::with_shrinkage`]; the default is ν = 1 (no scaling). //! +//! # References +//! +//! - J. H. Friedman (2001). "Greedy function approximation: A gradient boosting machine." +//! +//! //! # Examples //! //! ## Linear + linear From f024a6c5224b31259a2cf225acf40edf7d0e5a9c Mon Sep 17 00:00:00 2001 From: feiyang Date: Wed, 4 Mar 2026 21:46:05 +0000 Subject: [PATCH 15/19] simplify comparison --- src/composing/residual_sequence.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_sequence.rs index 7ca272187..274503ff2 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_sequence.rs @@ -372,13 +372,12 @@ impl ParamGuard for ResidualChainParams { fn check_ref(&self) -> Result<&ResidualChain, crate::Error> { let v = self.0.corrector.shrinkage; - let err = crate::Error::Parameters(format!("shrinkage must be in (0, 1], got {v}")); - if v.to_f32() - .map_or(Err(err.clone()), |num| Ok(num > 0.0 && num <= 1.0))? - { + if v > F::zero() && v <= F::one() { Ok(&self.0) } else { - Err(err) + Err(crate::Error::Parameters(format!( + "shrinkage must be in (0, 1], got {v}" + ))) } } From 69c4805b9714bdcb08ae6b54297da52561c8ad83 Mon Sep 17 00:00:00 2001 From: feiyang Date: Thu, 5 Mar 2026 17:33:48 +0000 Subject: [PATCH 16/19] rename to residual_chain as consistent with struct --- src/composing/mod.rs | 4 ++-- .../{residual_sequence.rs => residual_chain.rs} | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) rename src/composing/{residual_sequence.rs => residual_chain.rs} (98%) diff --git a/src/composing/mod.rs b/src/composing/mod.rs index d7aea6965..bb7271889 100644 --- a/src/composing/mod.rs +++ b/src/composing/mod.rs @@ -5,11 +5,11 @@ //! * `MultiTargetModel`: combine multiple univariate models to a single multi-target model //! * `Platt`: calibrate a classifier (i.e. SVC) to predicted posterior probabilities //! * `ResidualChain`: fit models sequentially on the residuals of the previous one -//! (forward stagewise additive modeling / L2Boosting); see [`residual_sequence::Stagewise`] +//! (forward stagewise additive modeling / L2Boosting); see [`residual_chain::Stagewise`] mod multi_class_model; mod multi_target_model; pub mod platt_scaling; -pub mod residual_sequence; +pub mod residual_chain; pub use multi_class_model::MultiClassModel; pub use multi_target_model::MultiTargetModel; diff --git a/src/composing/residual_sequence.rs b/src/composing/residual_chain.rs similarity index 98% rename from src/composing/residual_sequence.rs rename to src/composing/residual_chain.rs index 274503ff2..6b9b40ca4 100644 --- a/src/composing/residual_sequence.rs +++ b/src/composing/residual_chain.rs @@ -31,7 +31,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa::composing::residual_sequence::{ResidualChain, Stagewise}; +//! use linfa::composing::residual_chain::{ResidualChain, Stagewise}; //! use ndarray::{array, Array2}; //! //! // y = 2x: perfectly linear, so the corrector should see zero residuals. @@ -57,7 +57,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa::composing::residual_sequence::Stagewise; +//! use linfa::composing::residual_chain::Stagewise; //! use ndarray::{array, Array2}; //! //! // y = 2x: one linear model is enough to fit this perfectly. @@ -85,7 +85,7 @@ //! use linfa::traits::{Fit, Predict}; //! use linfa::DatasetBase; //! use linfa_linear::LinearRegression; -//! use linfa::composing::residual_sequence::{ResidualChain, Stagewise}; +//! use linfa::composing::residual_chain::{ResidualChain, Stagewise}; //! use linfa_svm::Svm; //! use ndarray::Array; //! @@ -189,7 +189,7 @@ impl ResidualChain { /// use linfa::traits::Fit; /// use linfa::DatasetBase; /// use linfa_linear::LinearRegression; -/// use linfa::composing::residual_sequence::Stagewise; +/// use linfa::composing::residual_chain::Stagewise; /// use ndarray::{array, Array2}; /// /// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); @@ -345,7 +345,7 @@ impl Shrunk { /// use linfa::traits::{Fit, Predict}; /// use linfa::DatasetBase; /// use linfa_linear::LinearRegression; -/// use linfa::composing::residual_sequence::{Shrunk, Stagewise}; +/// use linfa::composing::residual_chain::{Shrunk, Stagewise}; /// use ndarray::{array, Array2}; /// /// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); From 54d148ee6d7dcd9f734d8412cc46f202ec4be7e1 Mon Sep 17 00:00:00 2001 From: feiyang Date: Sun, 8 Mar 2026 16:30:53 +0000 Subject: [PATCH 17/19] implement copy trait --- src/composing/residual_chain.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/composing/residual_chain.rs b/src/composing/residual_chain.rs index 6b9b40ca4..01cb828c6 100644 --- a/src/composing/residual_chain.rs +++ b/src/composing/residual_chain.rs @@ -155,7 +155,7 @@ pub enum ResidualChainError { derive(Serialize, Deserialize), serde(crate = "serde_crate") )] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct ResidualChain { base: B, corrector: Shrunk, @@ -309,7 +309,7 @@ where derive(Serialize, Deserialize), serde(crate = "serde_crate") )] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct Shrunk { model: M, shrinkage: F, @@ -363,7 +363,7 @@ impl Shrunk { derive(Serialize, Deserialize), serde(crate = "serde_crate") )] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct ResidualChainParams(ResidualChain); impl ParamGuard for ResidualChainParams { From ce5809382720d02c96e07e746536d4d8205b0b85 Mon Sep 17 00:00:00 2001 From: feiyang Date: Sun, 8 Mar 2026 17:17:57 +0000 Subject: [PATCH 18/19] add method `chain` which just chains self with an unshrunk corrector. rename stack_with -> chain_shrunk --- src/composing/residual_chain.rs | 69 +++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/src/composing/residual_chain.rs b/src/composing/residual_chain.rs index 01cb828c6..9ed102061 100644 --- a/src/composing/residual_chain.rs +++ b/src/composing/residual_chain.rs @@ -40,7 +40,7 @@ //! let dataset = DatasetBase::new(x.clone(), y); //! //! let fitted = LinearRegression::default() -//! .stack_with(LinearRegression::default().shrink_by(1.0)) +//! .chain(LinearRegression::default()) //! .fit(&dataset) //! .unwrap(); //! @@ -66,7 +66,7 @@ //! let dataset = DatasetBase::new(x.clone(), y); //! //! let fitted = LinearRegression::default() -//! .stack_with(LinearRegression::default().shrink_by(1.0)) +//! .chain(LinearRegression::default()) //! .fit(&dataset) //! .unwrap(); //! @@ -100,18 +100,16 @@ //! let fitted = Svm::::params() //! .c_svr(1., None) //! .linear_kernel() -//! .stack_with( +//! .chain( //! Svm::::params() //! .c_svr(10., Some(0.1)) -//! .gaussian_kernel(1.) -//! .shrink_by(1.0), +//! .gaussian_kernel(1.), //! ) -//! .stack_with(LinearRegression::default().shrink_by(1.0)) -//! .stack_with( +//! .chain(LinearRegression::default()) +//! .chain( //! Svm::::params() //! .c_svr(10., Some(0.1)) -//! .gaussian_kernel(3.) -//! .shrink_by(1.0), +//! .gaussian_kernel(3.), //! ) //! .fit(&dataset) //! .unwrap(); @@ -175,13 +173,13 @@ impl ResidualChain { /// Blanket-implemented for all `Sized` types, so any model params type gains /// these methods automatically: /// -/// - [`stack_with`](Stagewise::stack_with): compose `self` (as the base) with +/// - [`chain`](Stagewise::chain): compose `self` (as the base) with /// a [`Shrunk`] corrector that will be trained on the residuals left by /// `self`. Returns a [`ResidualChainParams`] whose `.fit()` runs both stages. /// Calls can be chained to build arbitrarily deep sequences. /// - [`shrink_by`](Stagewise::shrink_by): wrap `self` in a [`Shrunk`] with the /// given learning rate ν, making it ready to pass as the `corrector` argument -/// to [`Stagewise::stack_with`]. +/// to [`Stagewise::chain`]. /// /// # Example /// @@ -197,17 +195,23 @@ impl ResidualChain { /// let dataset = DatasetBase::new(x.clone(), y); /// /// let fitted = LinearRegression::default() -/// .stack_with(LinearRegression::default().shrink_by(1.0)) +/// .chain(LinearRegression::default()) /// .fit(&dataset) /// .unwrap(); /// ``` pub trait Stagewise: Sized { /// Compose `self` (as the base model) with `corrector`, which will be /// trained on the residuals left by `self`. Further stages can be appended - /// by calling `.stack_with(...)` on the returned [`ResidualChainParams`]. - fn stack_with(self, corrector: Shrunk) -> ResidualChainParams; + /// by calling `.chain_shrunk(...)` on the returned [`ResidualChainParams`]. + fn chain_shrunk(self, corrector: Shrunk) -> ResidualChainParams; + + fn chain(self, corrector: C) -> ResidualChainParams + where + C: Fit, Array1, E>, + E: std::error::Error + From; + /// Wrap `self` in a [`Shrunk`] with learning rate `shrinkage` ∈ (0, 1], - /// making it ready to pass as the `corrector` argument to [`Stagewise::stack_with`]. + /// making it ready to pass as the `corrector` argument to [`Stagewise::chain_shrunk`]. /// /// The bound `Self: Fit, Array1, E>` ensures at compile time /// that the model's element type matches the shrinkage type `F`. @@ -218,12 +222,19 @@ pub trait Stagewise: Sized { } impl Stagewise for B { - fn stack_with(self, corrector: Shrunk) -> ResidualChainParams { + fn chain_shrunk(self, corrector: Shrunk) -> ResidualChainParams { ResidualChainParams(ResidualChain { base: self, corrector, }) } + fn chain(self, corrector: C) -> ResidualChainParams + where + C: Fit, Array1, E>, + E: std::error::Error + From, + { + self.chain_shrunk(corrector.shrink_by(F::one())) + } fn shrink_by(self, shrinkage: F) -> Shrunk where Self: Fit, Array1, E>, @@ -330,7 +341,7 @@ impl Shrunk { } } -/// Unvalidated [`ResidualChain`] parameters returned by [`Stagewise::stack_with`]. +/// Unvalidated [`ResidualChain`] parameters returned by [`Stagewise::chain_shrunk`]. /// /// Call `.fit()` to validate and fit in one step — the [`ParamGuard`] blanket /// impl runs `check_ref` first, which verifies that the outermost corrector's @@ -354,7 +365,7 @@ impl Shrunk { /// /// // The corrector's contribution is scaled by 0.1. /// let fitted = LinearRegression::default() -/// .stack_with(LinearRegression::default().shrink_by(0.1)) +/// .chain_shrunk(LinearRegression::default().shrink_by(0.1)) /// .fit(&dataset) /// .unwrap(); /// ``` @@ -428,7 +439,7 @@ mod tests { fn corrector_is_fit_on_residuals() { // targets = [1, 3]. base sees mean=2, predicts 2 for all. // residuals = [1-2, 3-2] = [-1, 1]. corrector sees mean=0. - let model = MeanParams.stack_with(MeanParams.shrink_by(1.0)); + let model = MeanParams.chain(MeanParams); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); @@ -439,7 +450,7 @@ mod tests { #[test] fn predict_sums_both_models() { // base predicts 2.0, corrector predicts 0.0 → sum = 2.0 - let model = MeanParams.stack_with(MeanParams.shrink_by(1.0)); + let model = MeanParams.chain(MeanParams); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); @@ -475,8 +486,8 @@ mod tests { // base predicts 3.0, corrector predicts 1.0 → sum = 4.0 let model = FixedParams(3.0) - .stack_with(FixedParams(1.0).shrink_by(1.0)) - .stack_with(FixedParams(0.0).shrink_by(1.0)); + .chain(FixedParams(1.0)) + .chain(FixedParams(0.0)); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]); let fitted = model.fit(&dataset).unwrap(); @@ -487,9 +498,9 @@ mod tests { #[test] fn deep_chain_accessors() { let model = MeanParams - .stack_with(MeanParams.shrink_by(1.0)) - .stack_with(MeanParams.shrink_by(1.0)) - .stack_with(MeanParams.shrink_by(1.0)); + .chain(MeanParams) + .chain(MeanParams) + .chain(MeanParams); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); @@ -500,7 +511,7 @@ mod tests { fn shrinkage_scales_corrector_prediction() { // base predicts mean=2.0, corrector predicts mean=0.0 (residuals [-1,1]). // With shrinkage=0.5, corrector contributes 0.5*0.0 = 0.0 → total = 2.0. - let model = MeanParams.stack_with(MeanParams.shrink_by(0.5)); + let model = MeanParams.chain_shrunk(MeanParams.shrink_by(0.5)); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); let fitted = model.fit(&dataset).unwrap(); @@ -536,7 +547,7 @@ mod tests { } } - let model = FixedParams(3.0).stack_with(MeanParams.shrink_by(0.5)); + let model = FixedParams(3.0).chain_shrunk(MeanParams.shrink_by(0.5)); let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]); let fitted = model.fit(&dataset).unwrap(); @@ -548,10 +559,10 @@ mod tests { #[test] fn shrinkage_invalid_value_returns_error() { let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); - let model = MeanParams.stack_with(MeanParams.shrink_by(0.0)); + let model = MeanParams.chain_shrunk(MeanParams.shrink_by(0.0)); assert!(model.fit(&dataset).is_err()); - let model = MeanParams.stack_with(MeanParams.shrink_by(1.5)); + let model = MeanParams.chain_shrunk(MeanParams.shrink_by(1.5)); assert!(model.fit(&dataset).is_err()); } } From 4a712e69986e1678e30d2b58603380565d1d7335 Mon Sep 17 00:00:00 2001 From: feiyang Date: Sun, 8 Mar 2026 17:23:06 +0000 Subject: [PATCH 19/19] add doc --- src/composing/residual_chain.rs | 34 ++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/composing/residual_chain.rs b/src/composing/residual_chain.rs index 9ed102061..3713703d1 100644 --- a/src/composing/residual_chain.rs +++ b/src/composing/residual_chain.rs @@ -173,13 +173,16 @@ impl ResidualChain { /// Blanket-implemented for all `Sized` types, so any model params type gains /// these methods automatically: /// -/// - [`chain`](Stagewise::chain): compose `self` (as the base) with -/// a [`Shrunk`] corrector that will be trained on the residuals left by -/// `self`. Returns a [`ResidualChainParams`] whose `.fit()` runs both stages. -/// Calls can be chained to build arbitrarily deep sequences. +/// - [`chain`](Stagewise::chain): compose `self` (as the base) with a corrector +/// that will be trained on the residuals left by `self`. The corrector is used +/// without shrinkage (ν = 1). Returns a [`ResidualChainParams`] whose `.fit()` +/// runs both stages. Calls can be chained to build arbitrarily deep sequences. +/// - [`chain_shrunk`](Stagewise::chain_shrunk): like `chain`, but accepts a +/// [`Shrunk`]-wrapped corrector so you can control the learning rate ν +/// explicitly via [`shrink_by`](Stagewise::shrink_by). /// - [`shrink_by`](Stagewise::shrink_by): wrap `self` in a [`Shrunk`] with the -/// given learning rate ν, making it ready to pass as the `corrector` argument -/// to [`Stagewise::chain`]. +/// given learning rate ν ∈ (0, 1], making it ready to pass as the `corrector` +/// argument to [`Stagewise::chain_shrunk`]. /// /// # Example /// @@ -200,11 +203,24 @@ impl ResidualChain { /// .unwrap(); /// ``` pub trait Stagewise: Sized { - /// Compose `self` (as the base model) with `corrector`, which will be - /// trained on the residuals left by `self`. Further stages can be appended - /// by calling `.chain_shrunk(...)` on the returned [`ResidualChainParams`]. + /// Compose `self` (as the base model) with a [`Shrunk`]-wrapped `corrector`, + /// which will be trained on the residuals left by `self`. Further stages can + /// be appended by calling `.chain(...)` or `.chain_shrunk(...)` on the + /// returned [`ResidualChainParams`]. + /// + /// Use [`chain`](Stagewise::chain) instead when you don't need to shrink + /// the corrector. fn chain_shrunk(self, corrector: Shrunk) -> ResidualChainParams; + /// Compose `self` (as the base model) with `corrector`, which will be + /// trained on the residuals left by `self`. The corrector is used without + /// shrinkage (equivalent to `shrink_by(1.0)`). Further stages can be + /// appended by calling `.chain(...)` or `.chain_shrunk(...)` on the + /// returned [`ResidualChainParams`]. + /// + /// Use [`chain_shrunk`](Stagewise::chain_shrunk) together with + /// [`shrink_by`](Stagewise::shrink_by) when you need to control the + /// learning rate ν of the corrector explicitly. fn chain(self, corrector: C) -> ResidualChainParams where C: Fit, Array1, E>,