Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,25 @@ jobs:
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
sklearn-version: ["1.5.0", "1.6.1", "1.8.0", ""]
exclude:
# sklearn 1.5+ requires Python >=3.9
- python-version: "3.7"
sklearn-version: "1.5.0"
- python-version: "3.8"
sklearn-version: "1.5.0"
- python-version: "3.7"
sklearn-version: "1.6.1"
- python-version: "3.8"
sklearn-version: "1.6.1"
- python-version: "3.7"
sklearn-version: "1.8.0"
- python-version: "3.8"
sklearn-version: "1.8.0"
# sklearn 1.5 doesn't have wheels for Python 3.13
- python-version: "3.13"
sklearn-version: "1.5.0"

steps:
- uses: actions/checkout@v3
Expand All @@ -21,6 +39,10 @@ jobs:
python -m pip install --upgrade pip
pip install .[dev]
pip install flake8
- name: Pin scikit-learn version (if specified)
if: matrix.sklearn-version != ''
run: |
pip install scikit-learn==${{ matrix.sklearn-version }}
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ keywords = ["clustering", "mixtures", "lca", "em", "latent-class-analysis", "exp
dependencies = [
"numpy",
"pandas",
"scikit-learn >= 1.0.0, <=1.5.0",
"scikit-learn >= 1.0.0",
"scipy",
"tqdm",
]
Expand Down
24 changes: 24 additions & 0 deletions stepmix/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Compatibility layer for different scikit-learn versions."""

try:
    # Modern scikit-learn (>= 1.6, approximately) exposes a public
    # validate_data helper; prefer it when available.
    from sklearn.utils.validation import validate_data
except ImportError:

    def validate_data(estimator, X="no_validation", y="no_validation", **kwargs):
        """Fallback for scikit-learn versions without a public ``validate_data``.

        Mirrors the modern call signature (estimator first, then X/y and
        validation keyword arguments) and delegates to the estimator's
        private ``_validate_data`` method, which older scikit-learn
        releases provide with equivalent semantics.
        """
        return estimator._validate_data(X, y, **kwargs)


import inspect

from sklearn.utils.validation import check_array as _check_array

# scikit-learn renamed check_array's ``force_all_finite`` keyword to
# ``ensure_all_finite`` (new name introduced around 1.6; the old name was
# deprecated and later removed).  Detect which spelling this installation
# accepts by inspecting the signature rather than making a trial call:
# this avoids validating a dummy array at import time and cannot
# mis-attribute an unrelated TypeError to a missing keyword.
_USES_ENSURE_ALL_FINITE = (
    "ensure_all_finite" in inspect.signature(_check_array).parameters
)


def _finite_param(ensure_all_finite):
    """Return the correct kwarg dict for the current sklearn version."""
    # Pick the keyword spelling this scikit-learn release understands,
    # then forward the caller's value under that name.
    key = "ensure_all_finite" if _USES_ENSURE_ALL_FINITE else "force_all_finite"
    return {key: ensure_all_finite}
30 changes: 20 additions & 10 deletions stepmix/stepmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
import numpy as np

from scipy.special import logsumexp
from sklearn.mixture._base import BaseEstimator
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import (
check_random_state,
check_is_fitted,
check_array,
_check_sample_weight,
)
from sklearn.cluster import KMeans

from ._compat import validate_data, _finite_param
import tqdm

from . import utils
Expand Down Expand Up @@ -464,24 +467,24 @@ def _check_x_y(self, X=None, Y=None, reset=False):
Validated structural data or None if not provided.

"""
# We use reset True since we take care of dimensions in this class (and not in the parent)
if X is not None:
X_names = utils.extract_column_names(X)
X = self._validate_data(
X = validate_data(
self,
X,
dtype=[np.float64, np.float32],
reset=True,
force_all_finite=self._force_all_finite_mm,
reset=reset,
**_finite_param(self._force_all_finite_mm),
)
if Y is not None:
# Handle 1D Y array
Y_names = utils.extract_column_names(Y)
Y = self._validate_data(
# check_array (not validate_data) to avoid overwriting n_features_in_
Y = check_array(
Y,
dtype=[np.float64, np.float32],
reset=True,
ensure_2d=False,
force_all_finite=self._force_all_finite_sm,
**_finite_param(self._force_all_finite_sm),
)

# Force a matrix format
Expand Down Expand Up @@ -1237,7 +1240,7 @@ def _pivot_cw(self, df, aggfunc=np.std):

########################################################################################################################
# INFERENCE
def score(self, X, Y=None, sample_weight=None):
def score(self, X, Y=None, sample_weight=None, y=None):
"""Compute the average log-likelihood over samples.

Setting Y=None will ignore the structural likelihood.
Expand All @@ -1261,6 +1264,9 @@ def score(self, X, Y=None, sample_weight=None):
avg_ll: float
Average log likelihood over samples.
"""
if y is not None and Y is None:
Y = y

check_is_fitted(self)
X, Y = self._check_x_y(X, Y)
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True)
Expand Down Expand Up @@ -1495,7 +1501,7 @@ def predict_proba_Y(self, X):

return self._sm.predict_proba(log_resp)

def predict(self, X, Y=None):
def predict(self, X, Y=None, y=None):
"""Predict the cluster/latent class/component labels for the data samples in X.

Optionally, an array-like Y can be provided to predict the labels based on both the measurement and structural
Expand All @@ -1511,11 +1517,15 @@ def predict(self, X, Y=None):
List of n_features-dimensional data points to fit the structural model. Each row
corresponds to a single data point. If the data is categorical, by default it should be
0-indexed and integer encoded (not one-hot encoded).
y : array-like of shape (n_samples, n_features_structural), default=None
Alias for Y to maintain scikit-learn API compatibility.
Returns
-------
labels : array, shape (n_samples,)
Component labels.
"""
if y is not None and Y is None:
Y = y
return self.predict_class(X, Y)

def predict_proba(self, X, Y=None):
Expand Down