Source code for causalpy.skl_models

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Custom scikit-learn models for causal inference"""

from functools import partial
from typing import Optional, Tuple

import numpy as np
import statsmodels.api as sm
from scipy.optimize import fmin_slsqp
from sklearn.base import RegressorMixin
from sklearn.linear_model._base import LinearModel

from causalpy.utils import round_num


class ScikitLearnAdaptor:
    """Base class for scikit-learn models that can be used for causal inference."""

    def calculate_impact(self, y_true, y_pred):
        """Calculate the causal impact of the intervention."""
        return y_true - y_pred

    def calculate_cumulative_impact(self, impact):
        """Calculate the cumulative impact of the intervention."""
        return np.cumsum(impact)

    def print_coefficients(self, labels, round_to=None) -> None:
        """Print the coefficients of the model with the corresponding labels."""
        print("Model coefficients:")
        coef_ = self.get_coeffs()
        # Determine the width of the longest label
        max_label_length = max(len(name) for name in labels)
        # Print each coefficient with formatted alignment
        for name, val in zip(labels, coef_):
            # Left-align the name
            formatted_name = f"{name:<{max_label_length}}"
            # Right-align the value with width 10
            formatted_val = f"{round_num(val, round_to):>10}"
            print(f" {formatted_name}\t{formatted_val}")

    def get_coeffs(self):
        """Get the coefficients of the model as a numpy array."""
        return np.squeeze(self.coef_)
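
# Usage note (illustrative sketch): these adaptor helpers are mixed into fitted
# regressors, so they assume instance state such as ``coef_`` already exists.
# With a hypothetical fitted model ``m`` and outcomes ``y`` / ``y_hat``:
#
#   impact = m.calculate_impact(y, y_hat)               # pointwise y - y_hat
#   cumulative = m.calculate_cumulative_impact(impact)  # running sum over time
#   m.print_coefficients(labels=["Intercept", "x1"], round_to=2)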


class WeightedProportion(ScikitLearnAdaptor, LinearModel, RegressorMixin):
    """Weighted proportion model for causal inference.

    Used, for example, in synthetic control methods.
    """

    def loss(self, W, X, y):
        """Compute root mean squared loss with data X, weights W, and target y."""
        return np.sqrt(np.mean((y - np.dot(X, W.T)) ** 2))

    def fit(self, X, y):
        """Fit model on data X with target y."""
        w_start = [1 / X.shape[1]] * X.shape[1]
        coef_ = fmin_slsqp(
            partial(self.loss, X=X, y=y),
            np.array(w_start),
            f_eqcons=lambda w: np.sum(w) - 1,
            bounds=[(0.0, 1.0)] * len(w_start),
            disp=False,
        )
        self.coef_ = np.atleast_2d(coef_)  # store as 2D row vector (1, n_features)
        self.mse = self.loss(W=self.coef_, X=X, y=y)
        return self

    def predict(self, X):
        """Predict results for data X."""
        return np.dot(X, self.coef_.T)
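
# Example (illustrative sketch with made-up data): recover convex unit weights
# for a synthetic control, where X holds the control units' pre-intervention
# outcomes and y the treated unit's outcome over the same period.
#
#   rng = np.random.default_rng(seed=42)
#   X = rng.normal(size=(50, 3))          # 50 time points, 3 control units
#   y = X @ np.array([0.2, 0.5, 0.3])     # treated unit as a convex combination
#   model = WeightedProportion().fit(X, y)
#   model.coef_                           # weights are non-negative and sum to 1
#   model.predict(X)                      # synthetic control prediction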


class TransferFunctionOLS(ScikitLearnAdaptor, LinearModel, RegressorMixin):
    """
    OLS model with transfer functions for graded interventions.

    This model supports:

    - HAC (Newey-West) standard errors for robust inference (default)
    - ARIMAX error models for explicit autocorrelation modeling
    - Saturation and adstock transforms for treatment effects

    This model is designed to work with the GradedInterventionTimeSeries
    experiment class following the standard CausalPy pattern where the
    experiment handles data preparation and calls model.fit().

    Parameters
    ----------
    saturation_type : str, default="hill"
        Type of saturation function: "hill", "logistic", or "michaelis_menten".
    saturation_grid : dict, optional
        For grid search: dict mapping parameter names to lists of values.
        E.g., {"slope": [1.0, 2.0], "kappa": [3, 5]}.
    saturation_bounds : dict, optional
        For optimization: dict mapping parameter names to (min, max) tuples.
        E.g., {"slope": (0.5, 5.0), "kappa": (2, 10)}.
    adstock_grid : dict, optional
        For grid search: dict mapping parameter names to lists of values.
        E.g., {"half_life": [2, 3, 4]}.
    adstock_bounds : dict, optional
        For optimization: dict mapping parameter names to (min, max) tuples.
        E.g., {"half_life": (1, 10)}.
    estimation_method : str, default="grid"
        Method for parameter estimation: "grid" or "optimize".
    error_model : str, default="hac"
        Error model specification: "hac" or "arimax".
    arima_order : tuple of (int, int, int), optional
        ARIMA order (p, d, q) when error_model="arimax".
    hac_maxlags : int, optional
        Maximum lags for HAC standard errors.
    coef_constraint : str, default="nonnegative"
        Constraint on treatment coefficients.

    Attributes
    ----------
    ols_result : statsmodels regression result
        Fitted OLS or ARIMAX model result.
    treatments : List[Treatment]
        Treatment specifications with transform objects.
    score : float
        R-squared of the model.
    coef_ : np.ndarray
        Model coefficients (for sklearn compatibility).

    Examples
    --------
    .. code-block:: python

        # Create unfitted model with configuration
        model = cp.skl_models.TransferFunctionOLS(
            saturation_type="hill",
            saturation_grid={"slope": [1.0, 2.0], "kappa": [3, 5]},
            adstock_grid={"half_life": [2, 3, 4]},
            estimation_method="grid",
            error_model="hac",
        )

        # Use with experiment class (experiment calls fit())
        result = cp.GradedInterventionTimeSeries(
            data=df,
            y_column="outcome",
            treatment_names=["exposure"],
            base_formula="1 + t",
            model=model,
        )
    """

    def __init__(
        self,
        saturation_type: str = "hill",
        saturation_grid: Optional[dict] = None,
        saturation_bounds: Optional[dict] = None,
        adstock_grid: Optional[dict] = None,
        adstock_bounds: Optional[dict] = None,
        estimation_method: str = "grid",
        error_model: str = "hac",
        arima_order: Optional[Tuple[int, int, int]] = None,
        hac_maxlags: Optional[int] = None,
        coef_constraint: str = "nonnegative",
    ):
        """Initialize model with configuration parameters."""
        # Store configuration
        self.saturation_type = saturation_type
        self.saturation_grid = saturation_grid
        self.saturation_bounds = saturation_bounds
        self.adstock_grid = adstock_grid
        self.adstock_bounds = adstock_bounds
        self.estimation_method = estimation_method
        self.error_model = error_model
        self.arima_order = arima_order
        self.hac_maxlags = hac_maxlags
        self.coef_constraint = coef_constraint

        # Validate error model
        if error_model not in ["hac", "arimax"]:
            raise ValueError(
                f"error_model must be 'hac' or 'arimax', got '{error_model}'"
            )
        if error_model == "arimax" and arima_order is None:
            raise ValueError(
                "arima_order must be provided when error_model='arimax'. "
                "E.g., arima_order=(1, 0, 0) for AR(1) errors"
            )

        # Validate estimation method and required parameters
        if estimation_method == "grid":
            if saturation_grid is None:
                raise ValueError(
                    "saturation_grid is required for grid search method. "
                    "E.g., saturation_grid={'slope': [1.0, 2.0], 'kappa': [3, 5]}"
                )
            if adstock_grid is None:
                raise ValueError(
                    "adstock_grid is required for grid search method. "
                    "E.g., adstock_grid={'half_life': [2, 3, 4]}"
                )
        elif estimation_method == "optimize":
            if saturation_bounds is None:
                raise ValueError(
                    "saturation_bounds is required for optimize method. "
                    "E.g., saturation_bounds={'slope': (0.5, 5.0), 'kappa': (2, 10)}"
                )
            if adstock_bounds is None:
                raise ValueError(
                    "adstock_bounds is required for optimize method. "
                    "E.g., adstock_bounds={'half_life': (1, 10)}"
                )
        else:
            raise ValueError(
                f"estimation_method must be 'grid' or 'optimize', got '{estimation_method}'"
            )

        # Initialize attributes (set by fit())
        self.ols_result = None
        self.treatments = None
        self.score = None
        self.coef_ = None  # For sklearn compatibility
        self.arimax_model = None

        # Transform estimation metadata (set by fit())
        self.transform_estimation_results = None
        self.transform_search_space = None

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Fit OLS model with HAC/ARIMAX errors.

        Note: This method expects X to already contain the transformed
        treatment variables. Transform parameter estimation is handled by the
        experiment class.

        Parameters
        ----------
        X : np.ndarray
            Full design matrix (n_obs, n_features) including baseline AND
            transformed treatment variables.
        y : np.ndarray
            Outcome variable (n_obs,).

        Returns
        -------
        self : TransferFunctionOLS
            Fitted model.
        """
        # Fit model with chosen error structure
        if self.error_model == "hac":
            # Fit OLS with HAC standard errors
            if self.hac_maxlags is None:
                # Newey & West (1994) rule of thumb
                n = len(y)
                self.hac_maxlags = int(np.floor(4 * (n / 100) ** (2 / 9)))
            self.ols_result = sm.OLS(y, X).fit(
                cov_type="HAC", cov_kwds={"maxlags": self.hac_maxlags}
            )
        elif self.error_model == "arimax":
            # Fit ARIMAX model
            import warnings

            from statsmodels.tsa.statespace.sarimax import SARIMAX

            # Suppress convergence warnings
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.arimax_model = SARIMAX(y, exog=X, order=self.arima_order)
                self.ols_result = self.arimax_model.fit(
                    disp=0,
                    maxiter=200,
                    method="lbfgs",
                )

        # Compute R-squared
        if hasattr(self.ols_result, "rsquared"):
            self.score = self.ols_result.rsquared
        else:
            # For ARIMAX, compute R-squared manually
            residuals = self.ols_result.resid
            ss_res = np.sum(residuals**2)
            ss_tot = np.sum((y - np.mean(y)) ** 2)
            self.score = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0

        # Store coefficients for sklearn compatibility
        self.coef_ = self.ols_result.params.reshape(1, -1)

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the fitted model.

        Parameters
        ----------
        X : np.ndarray
            Design matrix (n_obs, n_features).

        Returns
        -------
        y_pred : np.ndarray
            Predicted values (n_obs,).
        """
        if self.ols_result is None:
            raise ValueError("Model has not been fitted yet. Call fit() first.")
        return X @ self.ols_result.params
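
# Illustrative sketch (hypothetical arrays): fitting directly on a design
# matrix whose treatment columns have already been transformed. With the
# default error_model="hac" and hac_maxlags=None, fit() picks maxlags via the
# Newey-West rule of thumb floor(4 * (n / 100) ** (2 / 9)).
#
#   model = TransferFunctionOLS(
#       saturation_grid={"slope": [1.0, 2.0], "kappa": [3, 5]},
#       adstock_grid={"half_life": [2, 3, 4]},
#   )
#   model.fit(X_design, y)           # X_design: (n_obs, n_features), y: (n_obs,)
#   y_hat = model.predict(X_design)
#   model.score                      # R-squared of the fitted model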


def create_causalpy_compatible_class(
    estimator: type[RegressorMixin],
) -> type[RegressorMixin]:
    """Take a scikit-learn estimator and return it with the ScikitLearnAdaptor
    methods bound to it, making it compatible with CausalPy."""
    _add_mixin_methods(estimator, ScikitLearnAdaptor)
    return estimator


def _add_mixin_methods(model_instance, mixin_class):
    """Utility function to bind mixin methods to an existing model instance."""
    for attr_name in dir(mixin_class):
        attr = getattr(mixin_class, attr_name)
        if callable(attr) and not attr_name.startswith("__"):
            # Bind the method to the instance
            method = attr.__get__(model_instance, model_instance.__class__)
            setattr(model_instance, attr_name, method)
    return model_instance
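
# Illustrative sketch (one plausible usage, assuming an estimator *instance* is
# passed, since the adaptor helpers read instance state such as ``coef_``):
#
#   from sklearn.linear_model import LinearRegression
#
#   lin_reg = create_causalpy_compatible_class(LinearRegression())
#   lin_reg.fit(X, y)                    # X, y are hypothetical training data
#   lin_reg.print_coefficients(labels)   # helper bound from ScikitLearnAdaptor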