# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Graded Intervention Time Series Experiment

This module implements experiments for estimating the causal effects of graded
interventions (e.g., media spend, policy intensity) in single-market time series
using transfer functions that model saturation and adstock (carryover) effects.

The experiment works with the TransferFunctionOLS model class (from skl_models)
to provide a complete causal inference workflow including visualization,
diagnostics, and counterfactual effect estimation.
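
A transfer function composes a saturation curve (diminishing returns) with an
adstock filter (carryover). As an illustrative sketch of the two forms used in
this module's examples (the canonical implementations live in
``causalpy.transforms``; these standalone functions are for exposition only):

.. code-block:: python

    import numpy as np


    def hill_saturation(x, slope, kappa):
        # Diminishing returns: rises from 0 toward 1 as exposure grows
        return x**slope / (kappa**slope + x**slope)


    def geometric_adstock(x, half_life, l_max=12, normalize=True):
        # Carryover: exponentially decaying weights over the last l_max periods
        alpha = 0.5 ** (1.0 / half_life)
        weights = alpha ** np.arange(l_max + 1)
        if normalize:
            weights = weights / weights.sum()
        # Causal convolution: only past exposure contributes to the present
        return np.convolve(x, weights)[: len(x)]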
"""
import warnings
from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from patsy import dmatrix
from sklearn.base import RegressorMixin
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.stats.diagnostic import acorr_ljungbox
from causalpy.custom_exceptions import BadIndexException
from causalpy.pymc_models import PyMCModel
from causalpy.transforms import Treatment
from causalpy.utils import round_num
from .base import BaseExperiment
LEGEND_FONT_SIZE = 12
class GradedInterventionTimeSeries(BaseExperiment):
"""
Interrupted time series experiment with graded interventions and transfer functions.
This experiment class handles causal inference for time series with graded
(non-binary) interventions, incorporating saturation and adstock effects.
Following the standard CausalPy pattern, it takes data and an unfitted model,
performs transform parameter estimation, fits the model, and provides
visualization, diagnostics, and counterfactual effect estimation.
Typical workflow:
1. Create an unfitted TransferFunctionOLS model with configuration
2. Pass data + model to this experiment class
3. Experiment estimates transforms, fits model, and provides results
4. Use experiment methods for visualization and effect estimation
Fitting Procedure
-----------------
The experiment uses a nested optimization approach to estimate transform parameters
and fit the regression model:
**Outer Loop (Transform Parameter Estimation):**
The experiment searches for optimal saturation and adstock parameters either via
grid search (exhaustive evaluation of discrete parameter combinations) or continuous
optimization (gradient-based search). For grid search with N saturation parameter
combinations and M adstock parameter combinations, all N x M combinations are
evaluated.
**Inner Loop (Model Fitting):**
For each candidate set of transform parameters, the raw treatment variable is
transformed by applying saturation (diminishing returns) and adstock (carryover
effects). The transformed treatment is combined with baseline predictors to create
a full design matrix, and an OLS or ARIMAX model is fitted. The model that achieves
the lowest root mean squared error (RMSE) determines the selected parameters.
This nested approach is computationally efficient because OLS has a closed-form
solution requiring only matrix operations, making each individual model fit very
fast. For ARIMAX error models, numerical optimization is required for each fit,
increasing computational cost but providing explicit modeling of autocorrelation
structure.
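
    A minimal sketch of this nested loop (``transform`` is a hypothetical
    stand-in for the saturation/adstock pipeline; the actual search lives in
    ``causalpy.transform_optimization``):

    .. code-block:: python

        import itertools

        import numpy as np


        def grid_search(y, X_baseline, x_raw, sat_grid, ad_grid, transform):
            best_rmse, best_params = np.inf, None
            # Outer loop: every saturation x adstock parameter combination
            for sat, ad in itertools.product(sat_grid, ad_grid):
                # Inner loop: transform treatment, then closed-form OLS fit
                X = np.column_stack([X_baseline, transform(x_raw, sat, ad)])
                beta, *_ = np.linalg.lstsq(X, y, rcond=None)
                rmse = np.sqrt(np.mean((y - X @ beta) ** 2))
                if rmse < best_rmse:
                    best_rmse, best_params = rmse, (sat, ad)
            return best_params
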
Parameters
----------
data : pd.DataFrame
Time series data with datetime or numeric index.
y_column : str
Name of the outcome variable column.
treatment_names : List[str]
List of treatment variable names (e.g., ["comm_intensity"]).
base_formula : str
Patsy formula for baseline model (e.g., "1 + t + temperature").
model : TransferFunctionOLS
UNFITTED model with configuration for transform parameter estimation.
Attributes
----------
data : pd.DataFrame
Input data.
y : np.ndarray
Outcome variable values.
X_baseline : np.ndarray
Baseline design matrix.
Z_treatment : np.ndarray
Treatment design matrix.
X_full : np.ndarray
Full design matrix.
predictions : np.ndarray
Fitted values from model.
residuals : np.ndarray
Model residuals.
score : float
R-squared of the model.
Examples
--------
.. code-block:: python
import causalpy as cp
# Step 1: Create UNFITTED model with configuration
model = cp.skl_models.TransferFunctionOLS(
saturation_type="hill",
saturation_grid={"slope": [1.0, 2.0, 3.0], "kappa": [3, 5, 7]},
adstock_grid={"half_life": [2, 3, 4, 5]},
estimation_method="grid",
error_model="hac",
)
# Step 2: Pass to experiment (experiment estimates transforms and fits model)
result = cp.GradedInterventionTimeSeries(
data=df,
y_column="water_consumption",
treatment_names=["comm_intensity"],
base_formula="1 + t + temperature + rainfall",
model=model,
)
# Step 3: Use experiment methods
result.summary()
result.plot()
result.plot_diagnostics()
effect = result.effect(window=(df.index[0], df.index[-1]), scale=0.0)
"""
expt_type = "Graded Intervention Time Series"
supports_ols = True
supports_bayes = True
def __init__(
self,
data: pd.DataFrame,
y_column: str,
treatment_names: List[str],
base_formula: str,
model=None,
**kwargs,
):
"""
Initialize experiment with data and unfitted model (standard CausalPy pattern).
This method:
1. Validates inputs and builds baseline design matrix
2. Estimates transform parameters for each treatment
3. Applies transforms and builds full design matrix
4. Calls model.fit(X_full, y)
5. Extracts results for visualization and analysis
Parameters
----------
data : pd.DataFrame
Time series data.
y_column : str
Name of outcome variable.
treatment_names : List[str]
List of treatment variable names (e.g., ["comm_intensity"]).
base_formula : str
Patsy formula for baseline model.
model : TransferFunctionOLS
UNFITTED model with configuration for transform estimation.
"""
super().__init__(model=model)
# Validate inputs
self._validate_inputs(data, y_column, treatment_names)
# Store attributes
self.data = data.copy()
self.y_column = y_column
self.treatment_names = treatment_names
self.base_formula = base_formula
# Extract outcome variable
self.y = data[y_column].values
        # Build baseline design matrix once (like other experiments do)
        design_matrix = dmatrix(base_formula, data)
        self.X_baseline = np.asarray(design_matrix)
        self.baseline_labels = design_matrix.design_info.column_names
# ====================================================================
# Conditional logic: Bayesian vs OLS models
# ====================================================================
if isinstance(self.model, PyMCModel):
# ================================================================
# BAYESIAN MODEL PATH
# ================================================================
# For Bayesian models, transforms are estimated jointly within PyMC
# So we skip grid/optimize parameter estimation and pass raw data
import xarray as xr
# Convert to xarray format (like DifferenceInDifferences does)
self.X = xr.DataArray(
self.X_baseline,
dims=["obs_ind", "coeffs"],
coords={
"obs_ind": np.arange(self.X_baseline.shape[0]),
"coeffs": self.baseline_labels,
},
)
self.y = xr.DataArray(
self.y.reshape(-1, 1),
dims=["obs_ind", "treated_units"],
coords={
"obs_ind": np.arange(len(self.y)),
"treated_units": ["unit_0"],
},
)
# Get raw treatment data
treatment_raw = np.column_stack(
[data[name].values for name in treatment_names]
)
treatment_data = xr.DataArray(
treatment_raw,
dims=["obs_ind", "treatment_names"],
coords={
"obs_ind": np.arange(treatment_raw.shape[0]),
"treatment_names": treatment_names,
},
)
# Setup coordinates for PyMC
COORDS = {
"coeffs": self.baseline_labels,
"obs_ind": np.arange(self.X_baseline.shape[0]),
"treated_units": ["unit_0"],
"treatment_names": treatment_names,
}
# Fit Bayesian model
self.model.fit(
X=self.X, y=self.y, coords=COORDS, treatment_data=treatment_data
)
# Store for later use
self.treatment_labels = treatment_names
self.treatments = None # Not used for Bayesian (transforms in model)
elif isinstance(self.model, RegressorMixin):
# ================================================================
# OLS MODEL PATH
# ================================================================
# Estimate transform parameters for each treatment
from causalpy.transform_optimization import (
estimate_transform_params_grid,
estimate_transform_params_optimize,
)
            self.treatments = []
for name in treatment_names:
# Run parameter estimation using model configuration
if self.model.estimation_method == "grid":
est_results = estimate_transform_params_grid(
data=data,
y_column=y_column,
treatment_name=name,
base_formula=base_formula,
saturation_type=self.model.saturation_type,
saturation_grid=self.model.saturation_grid,
adstock_grid=self.model.adstock_grid,
coef_constraint=self.model.coef_constraint,
hac_maxlags=self.model.hac_maxlags,
error_model=self.model.error_model,
arima_order=self.model.arima_order,
)
search_space = {
"saturation_grid": self.model.saturation_grid,
"adstock_grid": self.model.adstock_grid,
}
elif self.model.estimation_method == "optimize":
est_results = estimate_transform_params_optimize(
data=data,
y_column=y_column,
treatment_name=name,
base_formula=base_formula,
saturation_type=self.model.saturation_type,
saturation_bounds=self.model.saturation_bounds,
adstock_bounds=self.model.adstock_bounds,
initial_params=None,
coef_constraint=self.model.coef_constraint,
hac_maxlags=self.model.hac_maxlags,
method="L-BFGS-B",
error_model=self.model.error_model,
arima_order=self.model.arima_order,
)
                    search_space = {
                        "saturation_bounds": self.model.saturation_bounds,
                        "adstock_bounds": self.model.adstock_bounds,
                    }
                else:
                    raise ValueError(
                        f"Unknown estimation_method "
                        f"{self.model.estimation_method!r}; expected 'grid' or 'optimize'"
                    )
# Store estimation metadata on model
self.model.transform_estimation_results = est_results
self.model.transform_search_space = search_space
# Create Treatment with estimated transforms
treatment = Treatment(
name=name,
saturation=est_results["best_saturation"],
adstock=est_results["best_adstock"],
coef_constraint=self.model.coef_constraint,
)
self.treatments.append(treatment)
            # Build the treatment design matrix by applying the estimated
            # transform pipeline (Saturation → Adstock → Lag) via the shared helper
            self.Z_treatment, self.treatment_labels = self._build_treatment_matrix(
                data, self.treatments
            )
self.X_full = np.column_stack([self.X_baseline, self.Z_treatment])
self.all_labels = self.baseline_labels + self.treatment_labels
# Store treatments on model for later use
self.model.treatments = self.treatments
# Fit the model (standard CausalPy pattern)
self.model.fit(X=self.X_full, y=self.y)
# Extract results from fitted model
self.ols_result = self.model.ols_result
self.predictions = self.model.ols_result.fittedvalues
self.residuals = self.model.ols_result.resid
self.score = self.model.score
# Extract coefficients (handling ARIMAX correctly)
if self.model.error_model == "arimax":
# ARIMAX: extract only exogenous coefficients
n_exog = self.ols_result.model.k_exog
exog_params = self.ols_result.params[:n_exog]
n_baseline = self.X_baseline.shape[1]
self.beta_baseline = exog_params[:n_baseline]
self.theta_treatment = exog_params[n_baseline:]
else:
# OLS: all params are regression coefficients
n_baseline = self.X_baseline.shape[1]
self.beta_baseline = self.ols_result.params[:n_baseline]
self.theta_treatment = self.ols_result.params[n_baseline:]
# Store model metadata for summary output
self.error_model = self.model.error_model
self.hac_maxlags = self.model.hac_maxlags
self.arima_order = self.model.arima_order
self.transform_estimation_method = self.model.estimation_method
self.transform_estimation_results = self.model.transform_estimation_results
self.transform_search_space = self.model.transform_search_space
else:
raise ValueError("Model type not recognized")
def _validate_inputs(
self,
data: pd.DataFrame,
y_column: str,
treatment_names: List[str],
) -> None:
"""Validate input data and parameters."""
# Check that y_column exists
if y_column not in data.columns:
raise ValueError(f"y_column '{y_column}' not found in data columns")
# Check that treatment columns exist
for name in treatment_names:
if name not in data.columns:
raise ValueError(f"Treatment column '{name}' not found in data columns")
# Check for missing values in outcome
if data[y_column].isna().any():
raise ValueError("Outcome variable contains missing values")
        # Warn about missing values in treatment columns
        for name in treatment_names:
            if data[name].isna().any():
                warnings.warn(
                    f"Treatment column '{name}' contains missing values. "
                    "Consider forward-filling if justified by the context.",
                    stacklevel=2,
                )
# Check that we have a time index
valid_index_types = (pd.DatetimeIndex, pd.RangeIndex)
is_valid_index = isinstance(data.index, valid_index_types) or (
isinstance(data.index, pd.Index)
and pd.api.types.is_integer_dtype(data.index)
)
if not is_valid_index:
raise BadIndexException(
"Data index must be DatetimeIndex, RangeIndex, or integer Index for time series"
)
def _build_treatment_matrix(
self, data: pd.DataFrame, treatments: List[Treatment]
) -> Tuple[np.ndarray, List[str]]:
"""Build the treatment design matrix by applying transforms.
Parameters
----------
data : pd.DataFrame
Input data with treatment columns.
treatments : List[Treatment]
Treatment specifications.
Returns
-------
Z : np.ndarray
Treatment design matrix (n_obs, n_treatments).
labels : List[str]
Column labels for treatments.
"""
Z_columns = []
labels = []
for treatment in treatments:
# Get raw exposure series
x_raw = data[treatment.name].values
# Apply transform pipeline: Saturation → Adstock → Lag
x_transformed = x_raw
if treatment.saturation is not None:
x_transformed = treatment.saturation.apply(x_transformed)
if treatment.adstock is not None:
x_transformed = treatment.adstock.apply(x_transformed)
if treatment.lag is not None:
x_transformed = treatment.lag.apply(x_transformed)
Z_columns.append(x_transformed)
labels.append(treatment.name)
Z = np.column_stack(Z_columns)
return Z, labels
def effect(
self,
window: Tuple[Union[pd.Timestamp, int], Union[pd.Timestamp, int]],
channels: Optional[List[str]] = None,
scale: float = 0.0,
) -> Dict[str, Union[pd.DataFrame, float]]:
"""Estimate the causal effect of scaling treatment channels in a time window.
This method computes a counterfactual scenario by scaling the specified
treatment channels in the given window, reapplying all transforms with
the same parameters, and comparing to the observed outcome.
For Bayesian models, returns posterior distributions of effects with credible intervals.
For OLS models, returns point estimates.
Parameters
----------
window : Tuple[Union[pd.Timestamp, int], Union[pd.Timestamp, int]]
Start and end of the effect window (inclusive).
channels : List[str], optional
List of treatment channel names to scale. If None, scales all channels.
scale : float, default=0.0
Scaling factor for the counterfactual (0.0 = remove treatment).
Returns
-------
result : Dict
Dictionary containing:
- "effect_df": DataFrame with observed, counterfactual, effect, cumulative effect
- "total_effect": Total effect in window (mean for Bayesian)
- "mean_effect": Mean effect per period in window
For Bayesian models, also includes HDI bounds for counterfactual and effect.
Examples
--------
.. code-block:: python
# Estimate effect of removing treatment completely
effect = result.effect(
window=(df.index[0], df.index[-1]),
channels=["comm_intensity"],
scale=0.0,
)
print(f"Total effect: {effect['total_effect']:.2f}")
"""
# Route to appropriate method based on model type
if isinstance(self.model, PyMCModel):
return self._bayesian_effect(window=window, channels=channels, scale=scale)
else:
return self._ols_effect(window=window, channels=channels, scale=scale)
def _ols_effect(
self,
window: Tuple[Union[pd.Timestamp, int], Union[pd.Timestamp, int]],
channels: Optional[List[str]] = None,
scale: float = 0.0,
) -> Dict[str, Union[pd.DataFrame, float]]:
"""Estimate the causal effect for OLS models (point estimates)."""
# Default to all channels if not specified
if channels is None:
channels = self.treatment_names
# Validate channels
for ch in channels:
if ch not in self.treatment_names:
raise ValueError(f"Channel '{ch}' not found in treatments")
        # Get window mask (inclusive on both ends; works for datetime or integer index)
        window_start, window_end = window
        mask = (self.data.index >= window_start) & (self.data.index <= window_end)
# Create counterfactual data by scaling specified channels in the window
data_cf = self.data.copy()
for channel in channels:
data_cf.loc[mask, channel] = scale * data_cf.loc[mask, channel]
# Reapply transforms to counterfactual data
Z_cf, _ = self._build_treatment_matrix(data_cf, self.treatments)
# Predict counterfactual
X_cf_full = np.column_stack([self.X_baseline, Z_cf])
# For ARIMAX, extract only exogenous coefficients (exclude ARIMA params)
if hasattr(self.ols_result.model, "k_exog"):
# ARIMAX: params includes exog coefficients + ARIMA params
exog_params = self.ols_result.params[: self.ols_result.model.k_exog]
y_cf = X_cf_full @ exog_params
else:
# OLS: all params are regression coefficients
y_cf = X_cf_full @ self.ols_result.params
# Compute effect
effect = self.y - y_cf
# Create result DataFrame
effect_df = pd.DataFrame(
{
"observed": self.y,
"counterfactual": y_cf,
"effect": effect,
"effect_cumulative": np.cumsum(effect),
},
index=self.data.index,
)
# Filter to window for summary statistics
window_effect = effect[mask]
result = {
"effect_df": effect_df,
"total_effect": float(np.sum(window_effect)),
"mean_effect": float(np.mean(window_effect)),
"window_start": window_start,
"window_end": window_end,
"channels": channels,
"scale": scale,
}
return result
def plot_effect(
self,
effect_result: Dict,
**kwargs,
) -> Tuple[plt.Figure, np.ndarray]:
"""Plot counterfactual effect analysis results.
Creates a 2-panel figure showing:
1. Observed vs counterfactual outcome
2. Cumulative effect over time
For Bayesian models, shows posterior mean with 94% credible intervals.
For OLS models, shows point estimates.
Parameters
----------
effect_result : dict
Result dictionary from effect() method containing:
- effect_df: DataFrame with observed, counterfactual, effect columns
- window_start, window_end: Effect window boundaries
- channels: List of scaled channels
- scale: Scaling factor used
- total_effect: Total effect in window
- mean_effect: Mean effect per period in window
Returns
-------
fig : matplotlib.figure.Figure
ax : array of matplotlib.axes.Axes
Array of 2 axes objects (top: observed vs counterfactual, bottom: cumulative effect).
Examples
--------
.. code-block:: python
# Estimate effect of removing treatment
effect_result = result.effect(
window=(df.index[0], df.index[-1]),
channels=["comm_intensity"],
scale=0.0,
)
fig, ax = result.plot_effect(effect_result)
"""
# Route to appropriate plot method based on model type
if isinstance(self.model, PyMCModel):
return self._bayesian_plot_effect(effect_result=effect_result, **kwargs)
else:
return self._ols_plot_effect(effect_result=effect_result, **kwargs)
def _ols_plot_effect(
self,
effect_result: Dict,
**kwargs,
) -> Tuple[plt.Figure, np.ndarray]:
"""Plot counterfactual effect analysis for OLS models."""
# Extract data from effect result
effect_df = effect_result["effect_df"]
window_start = effect_result.get("window_start")
window_end = effect_result.get("window_end")
# Create 2-panel subplot
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
# ============================================================================
# TOP PANEL: Observed vs Counterfactual
# ============================================================================
axes[0].plot(
effect_df.index,
effect_df["observed"],
label="Observed",
linewidth=1.5,
)
axes[0].plot(
effect_df.index,
effect_df["counterfactual"],
label="Counterfactual",
linewidth=1.5,
linestyle="--",
)
# Shade the effect region
axes[0].fill_between(
effect_df.index,
effect_df["observed"],
effect_df["counterfactual"],
alpha=0.3,
color="C2",
label="Effect",
)
axes[0].set_ylabel(self.y_column, fontsize=11)
axes[0].set_title("Observed vs Counterfactual", fontsize=12, fontweight="bold")
axes[0].legend(fontsize=LEGEND_FONT_SIZE)
axes[0].grid(True, alpha=0.3)
# Add window boundaries if specified
if window_start is not None and window_end is not None:
axes[0].axvline(x=window_start, color="red", linestyle=":", alpha=0.5)
axes[0].axvline(x=window_end, color="red", linestyle=":", alpha=0.5)
# ============================================================================
# BOTTOM PANEL: Cumulative effect
# ============================================================================
axes[1].plot(
effect_df.index,
effect_df["effect_cumulative"],
linewidth=2,
color="C2",
)
axes[1].fill_between(
effect_df.index,
0,
effect_df["effect_cumulative"],
alpha=0.3,
color="C2",
)
axes[1].axhline(y=0, color="k", linestyle="--", linewidth=1)
axes[1].set_ylabel("Cumulative Effect", fontsize=11)
axes[1].set_xlabel("Time", fontsize=11)
axes[1].set_title("Cumulative Effect Over Time", fontsize=12, fontweight="bold")
axes[1].grid(True, alpha=0.3)
# Add window boundaries
if window_start is not None and window_end is not None:
axes[1].axvline(x=window_start, color="red", linestyle=":", alpha=0.5)
axes[1].axvline(x=window_end, color="red", linestyle=":", alpha=0.5)
plt.tight_layout()
return fig, axes
def plot(
self, round_to: Optional[int] = 2, **kwargs
) -> Tuple[plt.Figure, plt.Axes]:
"""Plot the model fit and results.
Creates a 2-panel figure showing:
1. Observed vs fitted values
2. Residuals over time
Parameters
----------
round_to : int, optional
Number of decimal places for rounding displayed values.
Returns
-------
fig : matplotlib.figure.Figure
ax : array of matplotlib.axes.Axes
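
        Examples
        --------
        .. code-block:: python

            fig, ax = result.plot(round_to=3)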
"""
# Route to appropriate plot method based on model type
if isinstance(self.model, PyMCModel):
return self._bayesian_plot(round_to=round_to, **kwargs)
else:
return self._ols_plot(round_to=round_to, **kwargs)
def _ols_plot(
self, round_to: Optional[int] = 2, **kwargs
) -> Tuple[plt.Figure, plt.Axes]:
"""Generate OLS-specific plots."""
fig, ax = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
# Top panel: Observed vs fitted
ax[0].plot(
self.data.index, self.y, "o", label="Observed", alpha=0.6, markersize=4
)
ax[0].plot(
self.data.index,
self.predictions,
"-",
label="Fitted",
linewidth=2,
color="C1",
)
ax[0].set_ylabel("Outcome")
ax[0].set_title(f"Model Fit: R² = {round_num(self.score, round_to)}")
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
ax[0].grid(True, alpha=0.3)
# Bottom panel: Residuals
ax[1].plot(self.data.index, self.residuals, "o-", alpha=0.6, markersize=3)
ax[1].axhline(y=0, color="k", linestyle="--", linewidth=1)
ax[1].set_ylabel("Residuals")
ax[1].set_xlabel("Time")
ax[1].set_title("Model Residuals")
ax[1].grid(True, alpha=0.3)
plt.tight_layout()
return fig, ax
def plot_irf(self, channel: str, max_lag: Optional[int] = None) -> plt.Figure:
"""Plot the Impulse Response Function (IRF) for a treatment channel.
Shows how a one-unit impulse in the (saturated) exposure propagates over
time through the adstock transformation.
Parameters
----------
channel : str
Name of the treatment channel.
max_lag : int, optional
Maximum lag to display.
Returns
-------
fig : matplotlib.figure.Figure
Examples
--------
.. code-block:: python
fig = result.plot_irf("comm_intensity", max_lag=12)
"""
        # Transform objects are only stored on the experiment for OLS models
        if self.treatments is None:
            raise NotImplementedError(
                "plot_irf() is currently only supported for OLS models"
            )
        # Find the treatment matching the requested channel
        treatment = next((t for t in self.treatments if t.name == channel), None)
if treatment is None:
raise ValueError(f"Channel '{channel}' not found in treatments")
# Extract adstock transform
adstock = treatment.adstock
if adstock is None:
print(f"No adstock transform found for channel '{channel}'")
return None
# Get alpha parameter from adstock transform
adstock_params = adstock.get_params()
alpha = adstock_params.get("alpha")
if alpha is None:
raise ValueError(
f"Adstock transform for channel '{channel}' has alpha=None. "
"This should not happen if half_life or alpha was provided."
)
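        # Geometric adstock: the weight at lag k is alpha**k, where
        # alpha = 0.5 ** (1 / half_life) so that alpha**half_life == 0.5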
# Generate IRF (adstock weights)
if max_lag is None:
max_lag = adstock_params.get("l_max", 12)
lags = np.arange(max_lag + 1)
weights = alpha**lags
normalize = adstock_params.get("normalize", True)
if normalize:
weights = weights / weights.sum()
# Plot
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(lags, weights, alpha=0.7, color="C0")
ax.set_xlabel("Lag (periods)")
ax.set_ylabel("Weight")
        # Recover half-life from the decay rate: alpha**h = 0.5  =>  h = log(0.5) / log(alpha)
        half_life_calc = np.log(0.5) / np.log(alpha)
ax.set_title(
f"Impulse Response Function: {channel}\n"
f"(alpha={alpha:.3f}, half_life={half_life_calc:.2f}, "
f"normalize={normalize})"
)
ax.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
return fig
def _ols_plot_transforms(
self,
true_saturation=None,
true_adstock=None,
x_range=None,
**kwargs,
) -> Tuple[plt.Figure, np.ndarray]:
"""Plot estimated transformation curves for OLS models."""
# Currently only supports single treatment
if len(self.treatments) != 1:
raise NotImplementedError(
"plot_transforms() currently only supports single treatment analysis"
)
treatment = self.treatments[0]
est_saturation = treatment.saturation
est_adstock = treatment.adstock
# Check which transforms exist
has_saturation = est_saturation is not None
has_adstock = est_adstock is not None
if not has_saturation and not has_adstock:
raise ValueError(
"No transforms to plot (both saturation and adstock are None). "
"At least one transform must be specified."
)
# Determine number of panels based on available transforms
n_panels = int(has_saturation) + int(has_adstock)
# Create subplot with appropriate number of panels
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5))
# Make axes a list for consistent indexing
if n_panels == 1:
axes = [axes]
panel_idx = 0
# ============================================================================
# SATURATION PLOT (if present)
# ============================================================================
if est_saturation is not None:
ax = axes[panel_idx]
# Determine x range
if x_range is None:
# Use range from data
x_raw = self.data[treatment.name].values
x_min, x_max = x_raw.min(), x_raw.max()
# Add some padding
x_padding = (x_max - x_min) * 0.1
x_range = (max(0, x_min - x_padding), x_max + x_padding)
x_sat = np.linspace(x_range[0], x_range[1], 100)
# Plot true saturation if provided
if true_saturation is not None:
y_true_sat = true_saturation.apply(x_sat)
ax.plot(
x_sat,
y_true_sat,
"k--",
linewidth=2.5,
label="True",
alpha=0.8,
)
# Plot estimated saturation
y_est_sat = est_saturation.apply(x_sat)
ax.plot(x_sat, y_est_sat, "C0-", linewidth=2.5, label="Estimated")
ax.set_xlabel(f"{treatment.name} (raw)", fontsize=11)
ax.set_ylabel("Saturated Value", fontsize=11)
ax.set_title("Saturation Function", fontsize=12, fontweight="bold")
ax.legend(fontsize=LEGEND_FONT_SIZE, framealpha=0.9)
ax.grid(True, alpha=0.3)
# Add parameter text
est_params = est_saturation.get_params()
param_text = "Estimated:\n"
for key, val in est_params.items():
if key not in ["alpha", "l_max", "normalize"]: # Skip adstock params
param_text += f" {key}={val:.2f}\n"
if true_saturation is not None:
true_params = true_saturation.get_params()
param_text += "\nTrue:\n"
for key, val in true_params.items():
if key not in ["alpha", "l_max", "normalize"]:
param_text += f" {key}={val:.2f}\n"
ax.text(
0.05,
0.95,
param_text.strip(),
transform=ax.transAxes,
fontsize=9,
verticalalignment="top",
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
)
panel_idx += 1
# ============================================================================
# ADSTOCK PLOT (if present)
# ============================================================================
if est_adstock is not None:
ax = axes[panel_idx]
est_adstock_params = est_adstock.get_params()
l_max = est_adstock_params.get("l_max", 12)
lags = np.arange(l_max + 1)
# Compute estimated adstock weights
est_alpha = est_adstock_params["alpha"]
est_weights = est_alpha**lags
normalize = est_adstock_params.get("normalize", True)
if normalize:
est_weights = est_weights / est_weights.sum()
# Plot true adstock if provided
if true_adstock is not None:
true_adstock_params = true_adstock.get_params()
true_alpha = true_adstock_params["alpha"]
true_weights = true_alpha**lags
if true_adstock_params.get("normalize", True):
true_weights = true_weights / true_weights.sum()
# Line plot instead of bars
ax.plot(
lags,
true_weights,
"k--",
linewidth=2.5,
label="True",
alpha=0.8,
)
# Estimated adstock line
ax.plot(
lags,
est_weights,
"C0-",
linewidth=2.5,
label="Estimated",
alpha=0.8,
)
else:
# Single line for estimated only
ax.plot(
lags,
est_weights,
"C0-",
linewidth=2.5,
label="Estimated",
alpha=0.8,
)
ax.set_xlabel("Lag (periods)", fontsize=11)
ax.set_ylabel("Adstock Weight", fontsize=11)
ax.set_title(
"Adstock Function (Carryover Effect)", fontsize=12, fontweight="bold"
)
ax.legend(fontsize=LEGEND_FONT_SIZE, framealpha=0.9)
ax.grid(True, alpha=0.3, axis="y")
# Add parameter text
param_text = "Estimated:\n"
half_life_est = np.log(0.5) / np.log(est_alpha)
param_text += f" half_life={half_life_est:.2f}\n"
param_text += f" alpha={est_alpha:.3f}\n"
if true_adstock is not None:
true_alpha = true_adstock_params["alpha"]
half_life_true = np.log(0.5) / np.log(true_alpha)
param_text += "\nTrue:\n"
param_text += f" half_life={half_life_true:.2f}\n"
param_text += f" alpha={true_alpha:.3f}\n"
ax.text(
0.95,
0.95,
param_text.strip(),
transform=ax.transAxes,
fontsize=9,
verticalalignment="top",
horizontalalignment="right",
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
)
plt.tight_layout()
return fig, axes
def plot_diagnostics(self, lags: int = 20) -> None:
"""Display diagnostic plots and tests for model residuals.
Shows:
1. ACF (autocorrelation function) plot
2. PACF (partial autocorrelation function) plot
3. Ljung-Box test for residual autocorrelation
Parameters
----------
lags : int, default=20
Number of lags to display.
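
        Examples
        --------
        .. code-block:: python

            result.plot_diagnostics(lags=12)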
"""
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
# ACF plot
plot_acf(self.residuals, lags=lags, ax=axes[0], alpha=0.05)
axes[0].set_title("Residual Autocorrelation Function (ACF)")
# PACF plot
plot_pacf(self.residuals, lags=lags, ax=axes[1], alpha=0.05, method="ywm")
axes[1].set_title("Residual Partial Autocorrelation Function (PACF)")
plt.tight_layout()
plt.show()
# Ljung-Box test
lb_result = acorr_ljungbox(self.residuals, lags=lags, return_df=True)
print("\n" + "=" * 60)
print("Ljung-Box Test for Residual Autocorrelation")
print("=" * 60)
print("H0: Residuals are independently distributed (no autocorrelation)")
print("If p-value < 0.05, reject H0 (autocorrelation present)")
print("-" * 60)
        # Show summary for a few key lags (deduplicated and within range)
        key_lags = sorted({1, 5, 10, lags})
for lag in key_lags:
if lag <= len(lb_result):
row = lb_result.iloc[lag - 1]
sig = (
"***"
if row["lb_pvalue"] < 0.01
else ("*" if row["lb_pvalue"] < 0.05 else "")
)
print(
f"Lag {lag:2d}: LB statistic = {row['lb_stat']:8.3f}, "
f"p-value = {row['lb_pvalue']:.4f} {sig}"
)
print("-" * 60)
if lb_result["lb_pvalue"].min() < 0.05:
print(
"⚠ Warning: Significant residual autocorrelation detected.\n"
" - HAC standard errors (if used) account for this in coefficient inference.\n"
" - Consider adding more baseline controls or adjusting transform parameters."
)
else:
print("✓ No significant residual autocorrelation detected.")
print("=" * 60)
def summary(self, round_to: Optional[int] = None) -> None:
"""Print a summary of the model results.
Parameters
----------
round_to : int, optional
Number of decimal places for rounding.
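
        Examples
        --------
        .. code-block:: python

            result.summary(round_to=3)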
"""
import arviz as az
if round_to is None:
round_to = 2
print("=" * 80)
print("Graded Intervention Time Series Results")
print("=" * 80)
print(f"Outcome variable: {self.y_column}")
print(
f"Number of observations: {len(self.y) if isinstance(self.y, np.ndarray) else self.y.shape[0]}"
)
if isinstance(self.model, PyMCModel):
# ============================================================
# BAYESIAN MODEL SUMMARY
# ============================================================
print("Model type: Bayesian (PyMC)")
print("-" * 80)
print("Transform parameters (Posterior Mean [94% HDI]):")
            # Report each transform parameter present in the posterior
            for var in ("half_life", "slope", "kappa"):
                if var not in self.model.idata.posterior:
                    continue
                post = az.extract(self.model.idata, var_names=[var])
                post_mean = float(post.mean())
                hdi = az.hdi(self.model.idata, var_names=[var], hdi_prob=0.94)[var]
                print(
                    f"  {var}: {round_num(post_mean, round_to)} "
                    f"[{round_num(float(hdi.sel(hdi='lower').values), round_to)}, "
                    f"{round_num(float(hdi.sel(hdi='higher').values), round_to)}]"
                )
print("-" * 80)
print("Baseline coefficients (Posterior Mean [94% HDI]):")
beta_post = az.extract(self.model.idata, var_names=["beta"]).squeeze()
# Compute HDI for all betas at once - HDI uses generic dimension names
beta_hdi_all = az.hdi(self.model.idata, var_names=["beta"], hdi_prob=0.94)[
"beta"
]
# Determine the dimension structure of the HDI result
# For multi-unit case: (treated_units, coeffs, hdi)
# For single-unit case: (coeffs, hdi) or just (hdi)
for i, label in enumerate(self.baseline_labels):
if beta_post.ndim > 1:
beta_i = beta_post[0, i] # First treated unit
# HDI for multi-dimensional: select first unit and i-th coefficient
if "treated_units" in beta_hdi_all.dims:
beta_hdi = beta_hdi_all.isel(treated_units=0, coeffs=i)
else:
beta_hdi = beta_hdi_all.isel(coeffs=i)
else:
beta_i = beta_post[i]
# Single unit case
if "coeffs" in beta_hdi_all.dims:
beta_hdi = beta_hdi_all.isel(coeffs=i)
else:
# Scalar case
beta_hdi = beta_hdi_all
beta_mean = float(beta_i.mean())
print(
f" {label:20s}: {round_num(beta_mean, round_to)} [{round_num(float(beta_hdi.sel(hdi='lower').values), round_to)}, {round_num(float(beta_hdi.sel(hdi='higher').values), round_to)}]"
)
print("-" * 80)
print("Treatment coefficients (Posterior Mean [94% HDI]):")
theta_post = az.extract(
self.model.idata, var_names=["theta_treatment"]
).squeeze()
# Compute HDI for all thetas at once - HDI uses generic dimension names
theta_hdi_all = az.hdi(
self.model.idata, var_names=["theta_treatment"], hdi_prob=0.94
)["theta_treatment"]
for i, label in enumerate(self.treatment_labels):
if theta_post.ndim > 1:
theta_i = theta_post[0, i] # First treated unit
# HDI for multi-dimensional
if "treated_units" in theta_hdi_all.dims:
theta_hdi = theta_hdi_all.isel(treated_units=0, coeffs=i)
else:
theta_hdi = theta_hdi_all.isel(coeffs=i)
else:
theta_i = theta_post[i] if theta_post.ndim > 0 else theta_post
# Single unit/treatment case
if (
"coeffs" in theta_hdi_all.dims
and theta_hdi_all.sizes["coeffs"] > 1
):
theta_hdi = theta_hdi_all.isel(coeffs=i)
else:
# Scalar or single coefficient case
theta_hdi = theta_hdi_all
theta_mean = float(theta_i.mean())
print(
f" {label:20s}: {round_num(theta_mean, round_to)} [{round_num(float(theta_hdi.sel(hdi='lower').values), round_to)}, {round_num(float(theta_hdi.sel(hdi='higher').values), round_to)}]"
)
print("=" * 80)
else:
# ============================================================
# OLS MODEL SUMMARY
# ============================================================
print(f"R-squared: {round_num(self.score, round_to)}")
print(f"Error model: {self.error_model.upper()}")
if self.error_model == "hac":
print(
f" HAC max lags: {self.hac_maxlags} "
f"(robust SEs accounting for {self.hac_maxlags} periods of autocorrelation)"
)
elif self.error_model == "arimax":
p, d, q = self.arima_order
print(f" ARIMA order: ({p}, {d}, {q})")
print(f" p={p}: AR order, d={d}: differencing, q={q}: MA order")
print("-" * 80)
print("Baseline coefficients:")
for label, coef, se in zip(
self.baseline_labels,
self.beta_baseline,
self.ols_result.bse[: len(self.baseline_labels)],
):
coef_rounded = round_num(coef, round_to)
se_rounded = round_num(se, round_to)
print(f" {label:20s}: {coef_rounded:>10} (SE: {se_rounded})")
print("-" * 80)
print("Treatment coefficients:")
n_baseline = len(self.baseline_labels)
# For ARIMAX, we need to extract only the treatment SEs from exogenous params
if self.error_model == "arimax":
n_exog = self.ols_result.model.k_exog
treatment_se = self.ols_result.bse[n_baseline:n_exog]
else:
treatment_se = self.ols_result.bse[n_baseline:]
for label, coef, se in zip(
self.treatment_labels,
self.theta_treatment,
treatment_se,
):
coef_rounded = round_num(coef, round_to)
se_rounded = round_num(se, round_to)
print(f" {label:20s}: {coef_rounded:>10} (SE: {se_rounded})")
print("=" * 80)
# Methods required by BaseExperiment
def _bayesian_plot(
self, round_to: Optional[int] = 2, **kwargs
) -> Tuple[plt.Figure, plt.Axes]:
"""Generate Bayesian-specific plots with credible intervals."""
import arviz as az
fig, ax = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
# Extract posterior predictions (mu is in posterior group as Deterministic)
mu_posterior = az.extract(self.model.idata, group="posterior", var_names="mu")
# Get mean and HDI
mu_mean = mu_posterior.mean(dim="sample").values.flatten()
# Compute HDI from posterior
mu_hdi = az.hdi(self.model.idata.posterior, var_names=["mu"], hdi_prob=0.94)[
"mu"
]
mu_lower = mu_hdi.sel(hdi="lower").values.flatten()
mu_upper = mu_hdi.sel(hdi="higher").values.flatten()
# Top panel: Observed vs fitted with credible interval
ax[0].plot(
self.data.index,
self.y.values.flatten() if hasattr(self.y, "values") else self.y,
"o",
label="Observed",
alpha=0.6,
markersize=4,
)
ax[0].plot(
self.data.index,
mu_mean,
"-",
label="Posterior Mean",
linewidth=2,
color="C1",
)
ax[0].fill_between(
self.data.index,
mu_lower,
mu_upper,
alpha=0.3,
color="C1",
label="94% HDI",
)
ax[0].set_ylabel("Outcome")
ax[0].set_title("Bayesian Model Fit")
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
ax[0].grid(True, alpha=0.3)
# Bottom panel: Residuals with uncertainty
y_obs = self.y.values.flatten() if hasattr(self.y, "values") else self.y
residuals = y_obs - mu_mean
ax[1].plot(self.data.index, residuals, "o-", alpha=0.6, markersize=3)
ax[1].axhline(y=0, color="k", linestyle="--", linewidth=1)
ax[1].set_ylabel("Residuals")
ax[1].set_xlabel("Time")
ax[1].set_title("Model Residuals")
ax[1].grid(True, alpha=0.3)
plt.tight_layout()
return fig, ax
def get_plot_data_bayesian(self, *args, **kwargs):
"""Get plot data for Bayesian results.
Returns
-------
pd.DataFrame
DataFrame with observed and posterior mean fitted values.
"""
import arviz as az
        # Extract posterior predictions (mu lives in the posterior group as a
        # Deterministic, matching _bayesian_plot)
        mu_posterior = az.extract(self.model.idata, group="posterior", var_names="mu")
mu_mean = mu_posterior.mean(dim="sample").values.flatten()
y_obs = self.y.values.flatten() if hasattr(self.y, "values") else self.y
return pd.DataFrame(
{
"observed": y_obs,
"fitted": mu_mean,
"residuals": y_obs - mu_mean,
},
index=self.data.index,
)
def _bayesian_plot_transforms(
self,
true_saturation=None,
true_adstock=None,
x_range=None,
**kwargs,
) -> Tuple[plt.Figure, np.ndarray]:
"""Plot estimated transformation curves for Bayesian models with credible intervals."""
import arviz as az
# Check which transforms are present in the posterior
has_saturation = (
"slope" in self.model.idata.posterior
or "kappa" in self.model.idata.posterior
)
has_adstock = "half_life" in self.model.idata.posterior
if not has_saturation and not has_adstock:
raise ValueError(
"No transforms to plot (no transform parameters found in posterior). "
"At least one transform must be specified."
)
# Determine number of panels
n_panels = int(has_saturation) + int(has_adstock)
# Create subplot
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5))
# Make axes a list for consistent indexing
if n_panels == 1:
axes = [axes]
panel_idx = 0
# ============================================================================
# SATURATION PLOT (if present)
# ============================================================================
if has_saturation:
ax = axes[panel_idx]
# TODO: Implement Bayesian saturation plotting when needed
# For now, skip saturation (since the current example uses adstock only)
ax.text(
0.5,
0.5,
"Bayesian saturation\nplotting not yet\nimplemented",
ha="center",
va="center",
transform=ax.transAxes,
fontsize=12,
)
panel_idx += 1
# ============================================================================
# ADSTOCK PLOT (if present)
# ============================================================================
if has_adstock:
ax = axes[panel_idx]
# Extract posterior samples of half_life
half_life_post = az.extract(self.model.idata, var_names=["half_life"])
l_max = self.model.adstock_config.get("l_max", 8)
lags = np.arange(l_max + 1)
# Compute adstock weights for all posterior samples
weights_posterior = []
for half_life_sample in half_life_post.values:
alpha_sample = np.power(0.5, 1 / half_life_sample)
weights_sample = alpha_sample**lags
# Normalize if needed
if self.model.adstock_config.get("normalize", True):
weights_sample = weights_sample / weights_sample.sum()
weights_posterior.append(weights_sample)
weights_posterior = np.array(
weights_posterior
) # Shape: (n_samples, n_lags)
            # Compute a 94% equal-tailed credible interval at each lag
            weights_lower = np.percentile(weights_posterior, 3, axis=0)
            weights_upper = np.percentile(weights_posterior, 97, axis=0)
# Compute posterior mean
half_life_mean = float(half_life_post.mean())
alpha_mean = np.power(0.5, 1 / half_life_mean)
weights_mean = alpha_mean**lags
if self.model.adstock_config.get("normalize", True):
weights_mean = weights_mean / weights_mean.sum()
# Plot true adstock if provided
if true_adstock is not None:
true_adstock_params = true_adstock.get_params()
true_alpha = true_adstock_params["alpha"]
true_weights = true_alpha**lags
if true_adstock_params.get("normalize", True):
true_weights = true_weights / true_weights.sum()
ax.plot(
lags,
true_weights,
"k--",
linewidth=2.5,
label="True",
alpha=0.8,
zorder=10,
)
# Plot Bayesian uncertainty as shaded region
ax.fill_between(
lags,
weights_lower,
weights_upper,
color="C2",
alpha=0.25,
label="Bayesian 94% HDI",
zorder=1,
)
# Plot Bayesian posterior mean
ax.plot(
lags,
weights_mean,
"C2-",
linewidth=3,
label="Bayesian Mean",
alpha=1.0,
zorder=11,
)
ax.set_xlabel("Lag (periods)", fontsize=11)
ax.set_ylabel("Adstock Weight", fontsize=11)
ax.set_title(
"Adstock Function (Carryover Effect)", fontsize=12, fontweight="bold"
)
ax.legend(fontsize=LEGEND_FONT_SIZE, framealpha=0.9)
ax.grid(True, alpha=0.3, axis="y")
# Add parameter text with posterior summary
param_text = "Bayesian Posterior:\n"
half_life_hdi = az.hdi(
self.model.idata, var_names=["half_life"], hdi_prob=0.94
)["half_life"]
param_text += f" half_life={half_life_mean:.2f}\n"
param_text += f" [94% HDI: {float(half_life_hdi.sel(hdi='lower').values):.2f}, {float(half_life_hdi.sel(hdi='higher').values):.2f}]\n"
if true_adstock is not None:
true_half_life = np.log(0.5) / np.log(true_alpha)
param_text += "\nTrue:\n"
param_text += f" half_life={true_half_life:.2f}\n"
ax.text(
0.95,
0.95,
param_text.strip(),
transform=ax.transAxes,
fontsize=9,
verticalalignment="top",
horizontalalignment="right",
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
)
plt.tight_layout()
return fig, axes
def _bayesian_effect(
self,
window: Tuple[Union[pd.Timestamp, int], Union[pd.Timestamp, int]],
channels: Optional[List[str]] = None,
scale: float = 0.0,
) -> Dict[str, Union[pd.DataFrame, float]]:
"""Estimate the causal effect for Bayesian models with posterior uncertainty."""
import arviz as az
# Default to all channels if not specified
if channels is None:
channels = self.treatment_labels
# Validate channels
for ch in channels:
if ch not in self.treatment_labels:
raise ValueError(f"Channel '{ch}' not found in treatments")
        # Get window mask (inclusive on both ends; works for datetime or integer index)
        window_start, window_end = window
        mask = (self.data.index >= window_start) & (self.data.index <= window_end)
# Create counterfactual data by scaling specified channels in the window
data_cf = self.data.copy()
for channel in channels:
data_cf.loc[mask, channel] = scale * data_cf.loc[mask, channel]
# Get counterfactual treatment data
treatment_raw_cf = np.column_stack(
[data_cf[name].values for name in self.treatment_labels]
)
# Extract posterior samples
beta_post = az.extract(
self.model.idata, var_names=["beta"]
) # (sample, units, features)
theta_post = az.extract(
self.model.idata, var_names=["theta_treatment"]
) # (sample, units, treatments)
        # Extract transform parameter samples once, outside the per-sample loop
        n_samples = len(beta_post.sample)
        if "half_life" in self.model.idata.posterior:
            half_life_post = az.extract(self.model.idata, var_names=["half_life"])
        else:
            half_life_post = None
        y_cf_samples = []
        # For each posterior sample, compute counterfactual prediction
        for i in range(n_samples):
            # Get transform parameter values for this sample
            half_life_sample = (
                float(half_life_post.isel(sample=i))
                if half_life_post is not None
                else None
            )
# Apply transforms with posterior sample parameters
treatment_transformed_cf = treatment_raw_cf.copy().astype(np.float64)
if half_life_sample is not None:
# Apply adstock with this sample's half_life (numpy implementation)
alpha_sample = np.power(0.5, 1 / half_life_sample)
l_max = self.model.adstock_config.get("l_max", 8)
normalize = self.model.adstock_config.get("normalize", True)
# Compute adstock weights
lags = np.arange(l_max + 1)
weights = alpha_sample**lags
if normalize:
weights = weights / weights.sum()
                # Apply a causal convolution to each treatment column so the
                # adstock only carries past exposure forward (mode="same"
                # would center the kernel and leak future exposure backwards)
                treatment_transformed_cf = np.zeros_like(
                    treatment_raw_cf, dtype=np.float64
                )
                n_obs = treatment_raw_cf.shape[0]
                for col_idx in range(treatment_raw_cf.shape[1]):
                    treatment_transformed_cf[:, col_idx] = np.convolve(
                        treatment_raw_cf[:, col_idx], weights
                    )[:n_obs]
# Get coefficients for this sample (first unit only for single-unit case)
if beta_post.ndim > 2:
beta_sample = beta_post.isel(sample=i, treated_units=0).values.astype(
np.float64
)
else:
beta_sample = beta_post.isel(sample=i).values.astype(np.float64)
if theta_post.ndim > 2:
theta_sample = theta_post.isel(sample=i, treated_units=0).values.astype(
np.float64
)
else:
theta_sample = theta_post.isel(sample=i).values.astype(np.float64)
# Compute counterfactual prediction (ensure float64)
baseline_pred = np.asarray(self.X.values, dtype=np.float64) @ beta_sample
treatment_pred = treatment_transformed_cf @ theta_sample
y_cf_sample = baseline_pred + treatment_pred
y_cf_samples.append(y_cf_sample.flatten())
y_cf_samples = np.array(y_cf_samples) # Shape: (n_samples, n_obs)
        # Compute posterior mean and a 94% equal-tailed interval
y_cf_mean = y_cf_samples.mean(axis=0)
y_cf_lower = np.percentile(y_cf_samples, 3, axis=0)
y_cf_upper = np.percentile(y_cf_samples, 97, axis=0)
# Get observed data
y_obs = self.y.values.flatten() if hasattr(self.y, "values") else self.y
# Compute effect
effect_samples = (
y_obs[np.newaxis, :] - y_cf_samples
) # Shape: (n_samples, n_obs)
effect_mean = effect_samples.mean(axis=0)
effect_lower = np.percentile(effect_samples, 3, axis=0)
effect_upper = np.percentile(effect_samples, 97, axis=0)
# Cumulative effect
effect_cumulative_samples = np.cumsum(effect_samples, axis=1)
effect_cumulative_mean = effect_cumulative_samples.mean(axis=0)
effect_cumulative_lower = np.percentile(effect_cumulative_samples, 3, axis=0)
effect_cumulative_upper = np.percentile(effect_cumulative_samples, 97, axis=0)
# Create result DataFrame
effect_df = pd.DataFrame(
{
"observed": y_obs,
"counterfactual": y_cf_mean,
"counterfactual_lower": y_cf_lower,
"counterfactual_upper": y_cf_upper,
"effect": effect_mean,
"effect_lower": effect_lower,
"effect_upper": effect_upper,
"effect_cumulative": effect_cumulative_mean,
"effect_cumulative_lower": effect_cumulative_lower,
"effect_cumulative_upper": effect_cumulative_upper,
},
index=self.data.index,
)
# Filter to window for summary statistics
window_effect_samples = effect_samples[:, mask]
result = {
"effect_df": effect_df,
"total_effect": float(window_effect_samples.sum(axis=1).mean()),
"total_effect_lower": float(
np.percentile(window_effect_samples.sum(axis=1), 3)
),
"total_effect_upper": float(
np.percentile(window_effect_samples.sum(axis=1), 97)
),
"mean_effect": float(window_effect_samples.mean(axis=1).mean()),
"window_start": window_start,
"window_end": window_end,
"channels": channels,
"scale": scale,
}
return result
def _bayesian_plot_effect(
self,
effect_result: Dict,
**kwargs,
) -> Tuple[plt.Figure, np.ndarray]:
"""Plot counterfactual effect analysis for Bayesian models with credible intervals."""
# Extract data from effect result
effect_df = effect_result["effect_df"]
window_start = effect_result.get("window_start")
window_end = effect_result.get("window_end")
# Create 2-panel subplot
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
# ============================================================================
# TOP PANEL: Observed vs Counterfactual with credible intervals
# ============================================================================
# Plot counterfactual with uncertainty
axes[0].plot(
effect_df.index,
effect_df["counterfactual"],
label="Counterfactual (Posterior Mean)",
linewidth=1.5,
linestyle="--",
color="C1",
)
axes[0].fill_between(
effect_df.index,
effect_df["counterfactual_lower"],
effect_df["counterfactual_upper"],
alpha=0.2,
color="C1",
label="Counterfactual 94% HDI",
)
# Plot observed
axes[0].plot(
effect_df.index,
effect_df["observed"],
label="Observed",
linewidth=1.5,
color="C0",
)
# Shade the effect region
axes[0].fill_between(
effect_df.index,
effect_df["observed"],
effect_df["counterfactual"],
alpha=0.3,
color="C2",
label="Effect",
)
axes[0].set_ylabel(self.y_column, fontsize=11)
axes[0].set_title("Observed vs Counterfactual", fontsize=12, fontweight="bold")
axes[0].legend(fontsize=LEGEND_FONT_SIZE)
axes[0].grid(True, alpha=0.3)
# Add window boundaries if specified
if window_start is not None and window_end is not None:
axes[0].axvline(x=window_start, color="red", linestyle=":", alpha=0.5)
axes[0].axvline(x=window_end, color="red", linestyle=":", alpha=0.5)
# ============================================================================
# BOTTOM PANEL: Cumulative effect with credible intervals
# ============================================================================
axes[1].plot(
effect_df.index,
effect_df["effect_cumulative"],
linewidth=2,
color="C2",
label="Cumulative Effect (Posterior Mean)",
)
axes[1].fill_between(
effect_df.index,
effect_df["effect_cumulative_lower"],
effect_df["effect_cumulative_upper"],
alpha=0.3,
color="C2",
label="94% HDI",
)
axes[1].axhline(y=0, color="k", linestyle="--", linewidth=1)
axes[1].set_ylabel("Cumulative Effect", fontsize=11)
axes[1].set_xlabel("Time", fontsize=11)
axes[1].set_title("Cumulative Effect Over Time", fontsize=12, fontweight="bold")
axes[1].legend(fontsize=LEGEND_FONT_SIZE)
axes[1].grid(True, alpha=0.3)
# Add window boundaries
if window_start is not None and window_end is not None:
axes[1].axvline(x=window_start, color="red", linestyle=":", alpha=0.5)
axes[1].axvline(x=window_end, color="red", linestyle=":", alpha=0.5)
plt.tight_layout()
return fig, axes
def get_plot_data_ols(self) -> pd.DataFrame:
"""Get plot data for OLS results.
Returns
-------
pd.DataFrame
DataFrame with observed, fitted, and residual values.
"""
return pd.DataFrame(
{
"observed": self.y,
"fitted": self.predictions,
"residuals": self.residuals,
},
index=self.data.index,
)