MMM

class pymc_marketing.mmm.multidimensional.MMM(*, date_column, channel_columns, target_column='y', adstock, saturation, time_varying_intercept=False, time_varying_media=False, dims=None, scaling=None, model_config=None, sampler_config=None, control_columns=None, yearly_seasonality=None, adstock_first=True, dag=None, treatment_nodes=None, outcome_node=None)

Marketing Mix Model class for estimating the impact of marketing channels on a target variable.

Given a target variable \(y_{t}\) (e.g. sales or conversions), media variables \(x_{m, t}\) (e.g. impressions, clicks, or costs), and a set of control covariates \(z_{c, t}\) (e.g. holidays, pricing), we consider a Bayesian linear model of the form:

\[y_{t} = \alpha + \sum_{m=1}^{M}\beta_{m}\,f_{m}\!\bigl( \{x_{m,s}\}_{s \leq t}\bigr) + \sum_{c=1}^{C}\gamma_{c}\, z_{c, t} + \varepsilon_{t},\]

where \(\alpha\) is the intercept, \(f_{m}\) is a media transformation function that maps the history of channel \(m\) up to time \(t\) to a scalar contribution, capturing adstock (carry-over) and saturation effects, and \(\varepsilon_{t} \sim \mathcal{N}(0, \sigma^{2})\).
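The structure of the equation above can be illustrated with a small, dependency-free sketch. It assumes a geometric adstock followed by a logistic saturation as the media transformation \(f_{m}\); these are only one possible choice, since the class accepts any adstock and saturation transformation. The function names, the truncation length l_max, and the parameter names alpha and lam are illustrative and are not the library's internals. The sketch computes the noise-free mean of \(y_{t}\).

```python
import math


def geometric_adstock(x, alpha, l_max):
    """Carry-over: weighted sum of the current and up to l_max - 1 previous values."""
    out = []
    for t in range(len(x)):
        total = 0.0
        for lag in range(min(l_max, t + 1)):
            total += (alpha ** lag) * x[t - lag]
        out.append(total)
    return out


def logistic_saturation(x, lam):
    """Diminishing returns: maps non-negative inputs into [0, 1)."""
    return [(1 - math.exp(-lam * v)) / (1 + math.exp(-lam * v)) for v in x]


def expected_target(intercept, betas, channels, gammas, controls, alpha, lam, l_max):
    """Mean of y_t: intercept + sum_m beta_m * f_m(x_m) + sum_c gamma_c * z_c."""
    n = len(next(iter(channels.values())))
    mu = [intercept] * n
    for name, x in channels.items():
        # f_m: adstock first, then saturation (illustrative assumption)
        contrib = logistic_saturation(geometric_adstock(x, alpha[name], l_max), lam[name])
        mu = [m + betas[name] * c for m, c in zip(mu, contrib)]
    for name, z in controls.items():
        mu = [m + gammas[name] * v for m, v in zip(mu, z)]
    return mu


# A single unit of spend decays geometrically over the following periods.
impulse = geometric_adstock([1.0, 0.0, 0.0], alpha=0.5, l_max=3)  # [1.0, 0.5, 0.25]
```

In the Bayesian model the quantities passed in here as fixed numbers are random variables with priors, and the observation noise \(\varepsilon_{t}\) is added on top of this mean.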

The model supports \(K \geq 0\) additional panel dimensions (e.g. geography, brand) specified via the dims parameter. When \(K > 0\), every variable (the target, media inputs, and controls) and every parameter (\(\alpha\), \(\beta_{m}\), \(\gamma_{c}\), \(\sigma\), and the parameters of \(f_{m}\)) is implicitly indexed over the Cartesian product of those dimensions. For example, with dims=("geo",) each quantity is geo-specific (\(y_{t,g}\), \(\alpha_{g}\), \(\beta_{m,g}\), etc.), but the geo-specific parameters share hierarchical priors, so information is partially pooled across geographies. With dims=("geo", "brand"), every quantity is indexed by \((t, g, b)\). The equation above is written for a single slice of these dimensions; the full model applies it to every combination of dimension values.
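As a rough illustration of the indexing described above, the following sketch enumerates the slices implied by a set of panel dimensions. The helper name dimension_slices and the coordinate values are hypothetical; the sketch only shows how the Cartesian product determines the number of slices, and does not reproduce how the library stores coordinates.

```python
from itertools import product


def dimension_slices(coords):
    """Return every combination of coordinate values, one dict per slice."""
    names = list(coords)
    return [dict(zip(names, combo)) for combo in product(*(coords[n] for n in names))]


# Two geographies and three brands give 2 * 3 = 6 slices,
# each with its own alpha, beta_m, gamma_c, and sigma.
coords = {"geo": ["north", "south"], "brand": ["a", "b", "c"]}
slices = dimension_slices(coords)

# With no extra dimensions (K = 0) there is exactly one slice: the plain time series.
single = dimension_slices({})
```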

Attributes:
date_column : str

The name of the column representing the date in the dataset.

channel_columns : list[str]

A list of column names representing the marketing channels.

target_column : str, optional

The name of the column representing the target variable in the dataset. Defaults to "y".

adstock : AdstockTransformation

The adstock transformation to apply to the channel data.

saturation : SaturationTransformation

The saturation transformation to apply to the channel data.

time_varying_intercept : bool or HSGPBase

Whether to use a time-varying intercept in the model, or an HSGPBase instance specifying dims and priors.

time_varying_media : bool or HSGPBase

Whether to use time-varying effects for media channels, or an HSGPBase instance specifying dims and priors.

dims : tuple[str, …] or None

Additional panel dimensions for the model (e.g. ("geo",)). One categorical column per dimension must be present in the dataset. Data must be rectangular across these dimensions (i.e. the same dates for every combination).

scaling : Scaling or dict or None

Scaling methods for the target variable and the marketing channels. Defaults to max-absolute scaling for both.

model_config : dict or None

Configuration settings for the model priors and likelihood.

sampler_config : dict or None

Configuration settings for the sampler.

control_columns : list[str] or None

Column names of control covariates to include in the model.

yearly_seasonality : int or None

Number of Fourier modes for yearly seasonality.

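adstock_first : bool

Whether to apply adstock before saturation (default True).

dag : str or None

Optional DAG, given as a string in DOT format, used for causal identification.

treatment_nodes : list[str] or None

Column names of the variables of interest whose causal effects on the outcome are to be identified.

outcome_node : str or None

Name of the outcome variable.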
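The dims attribute requires rectangular data, that is, the same dates for every combination of dimension values. A check along the following lines could be run before fitting. This is a minimal sketch, not part of the library: the helper name missing_panel_rows is hypothetical, and it assumes the data is available as a list of row dictionaries keyed by column name.

```python
def missing_panel_rows(rows, date_column, dims):
    """List the (date, *dim values) combinations that are absent from the data."""
    dates = sorted({r[date_column] for r in rows})
    combos = sorted({tuple(r[d] for d in dims) for r in rows})
    present = {(r[date_column], *(r[d] for d in dims)) for r in rows}
    # Rectangular data has one row for every date and every observed combination.
    return [(d, *c) for d in dates for c in combos if (d, *c) not in present]


rows = [
    {"date": "2024-01-01", "geo": "north"},
    {"date": "2024-01-01", "geo": "south"},
    {"date": "2024-01-08", "geo": "north"},  # "south" is missing for this date
]
gaps = missing_panel_rows(rows, date_column="date", dims=("geo",))
```

An empty result means the panel is rectangular; otherwise each returned tuple names a row that would have to be filled in or dropped.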

Notes

  1. Before fitting, the target variable and media channels are scaled (by default using max-absolute scaling). Control variables are not scaled automatically — apply your own preprocessing if needed.

  2. Yearly seasonality can be added as Fourier modes via the yearly_seasonality parameter.

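  3. The model can be calibrated with, for example, lift-test measurements (add_lift_test_measurements) and cost-per-target constraints (add_cost_per_target_calibration).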

For details on a vanilla implementation in PyMC see [2].
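Notes 1 and 2 can be made concrete with a short sketch. It shows max-absolute scaling of a single series and the sine and cosine features for a given number of Fourier modes. The function names, the 365.25-day period, and the ordering of the sine and cosine terms are assumptions made for illustration; the library's own scaling and seasonality code may differ in detail.

```python
import math


def max_abs_scale(values):
    """Divide a series by its largest absolute value; return the scaled series and the scale."""
    scale = max(abs(v) for v in values)
    if scale == 0:
        scale = 1.0  # an all-zero series is left unchanged
    return [v / scale for v in values], scale


def yearly_fourier_features(day_of_year, n_order, period=365.25):
    """Return [sin_1, cos_1, ..., sin_n, cos_n] for one point in the year."""
    features = []
    for k in range(1, n_order + 1):
        angle = 2 * math.pi * k * day_of_year / period
        features.append(math.sin(angle))
        features.append(math.cos(angle))
    return features


scaled, scale = max_abs_scale([2.0, -4.0, 1.0])  # ([0.5, -1.0, 0.25], 4.0)
features = yearly_fourier_features(day_of_year=10, n_order=3)  # 2 * 3 = 6 columns
```

Keeping the scale alongside the scaled series is what allows contributions to be mapped back to the original units after fitting.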

References

[1] Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017).

Methods

MMM.__init__(*[, date_column, ...])

Define the constructor method.

MMM.add_cost_per_target_calibration(data, ...)

Calibrate cost-per-target using constraints via pm.Potential.

MMM.add_events(df_events, prefix, effect)

Add event effects to the model.

MMM.add_lift_test_measurements(df_lift_test)

Add lift tests to the model.

MMM.add_original_scale_contribution_variable(var)

Add a pm.Deterministic variable to the model that multiplies by the scaler.

MMM.approximate_fit(X[, y, progressbar, ...])

Fit a model using Variational Inference and return InferenceData.

MMM.attrs_to_init_kwargs(attrs)

Convert the idata attributes to the model initialization kwargs.

MMM.build_from_idata(idata)

Rebuild the model from an InferenceData object.

MMM.build_model(X, y, **kwargs)

Build a probabilistic model using PyMC for marketing mix modeling.

MMM.create_fit_data(X, y)

Create a fit dataset aligned on date and present dimensions.

MMM.create_idata_attrs()

Return the idata attributes for the model.

MMM.fit(X[, y, progressbar, random_seed])

Fit a model using the data passed as a parameter.

MMM.forward_pass(x, dims)

Transform channel input into target contributions of each channel.

MMM.get_scales_as_xarray()

Return the saved scaling factors as xarray DataArrays.

MMM.graphviz(**kwargs)

Get the graphviz representation of the model.

MMM.idata_to_init_kwargs(idata)

Convert the model configuration and sampler configuration stored in the InferenceData into keyword arguments.

MMM.load(fname[, check])

Create a ModelBuilder instance from a file.

MMM.load_from_idata(idata[, check])

Create a ModelBuilder instance from an InferenceData object.

MMM.post_sample_model_transformation()

Post-sample model transformation in order to store the HSGP state from fit.

MMM.predict([X, extend_idata])

Use a model to predict on unseen data and return point prediction of all the samples.

MMM.predict_posterior([X, extend_idata, ...])

Generate posterior predictive samples on unseen data.

MMM.predict_proba([X, extend_idata, combined])

Alias for predict_posterior, for consistency with scikit-learn probabilistic estimators.

MMM.sample_adstock_curve([amount, ...])

Sample adstock curves from posterior parameters.

MMM.sample_posterior_predictive([X, ...])

Sample from the model's posterior predictive distribution.

MMM.sample_prior_predictive([X, y, samples, ...])

Sample from the model's prior predictive distribution.

MMM.sample_saturation_curve([max_value, ...])

Sample saturation curves from posterior parameters.

MMM.save(fname, **kwargs)

Save the model's inference data to a file.

MMM.set_idata_attrs([idata])

Set attributes on an InferenceData object.

MMM.table(**model_table_kwargs)

Get the summary table of the model.

Attributes

data

Get data wrapper for InferenceData access and manipulation.

default_model_config

Define the default model configuration.

default_sampler_config

Default sampler configuration.

fit_result

Get the posterior fit_result.

id

Generate a unique hash value for the model.

incrementality

Access incrementality and counterfactual analysis functionality.

output_var

plot

Use the MMMPlotSuite to plot the results.

plot_interactive

Access interactive Plotly plotting functionality.

posterior

posterior_predictive

predictions

prior

prior_predictive

sensitivity

Access sensitivity analysis functionality.

summary

Access summary DataFrame generation functionality.

version

idata

sampler_config

model_config