Model comparison

Here, we will look at how to fit multiple datasets simultaneously with some parameters shared between them, and how to compare models that use different sharing arrangements.

First, we will import the necessary packages:

import pymc as pm
import arviz as az

import pylater
import pylater.data

Datasets

We will use all of the conditions from participant ‘b’ in the example data provided in pylater:

datasets = [
    dataset
    for dataset in pylater.data.cw1995.values()
    if dataset.name.startswith("b")
]

A ‘shift’ model

The first model that we will fit to the data uses a ‘shift’ sharing arrangement: the datasets will have a common standard deviation (\(\sigma\)) parameter.

shift_model = pylater.build_default_model(datasets=datasets, share_type="shift")

We then fit the model:

with shift_model:
    shift_idata = pm.sample()

Note that there is a lot of data here, so sampling can take a little while. We are using fewer samples than is typical in this example to allow for faster execution.

A ‘swivel’ model

The second model will use a ‘swivel’ sharing arrangement: the datasets will have a common intercept (\(k\)) parameter.

swivel_model = pylater.build_default_model(datasets=datasets, share_type="swivel")

We then also fit this model:

with swivel_model:
    swivel_idata = pm.sample()
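To make the two sharing arrangements concrete, the following is a minimal sketch in plain Python. It assumes the usual LATER parameterisation, in which each dataset has a mean (`mu`) and standard deviation (`sigma`) and the intercept is `k = mu / sigma`. The `LaterParams` class, the two helper functions, and the numerical values are all hypothetical and for illustration only; this is not how pylater builds its models internally.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LaterParams:
    """Hypothetical container for one dataset's parameters."""

    mu: float
    sigma: float

    @property
    def k(self) -> float:
        # Assumed relationship between the intercept and the other parameters
        return self.mu / self.sigma


def shift_params(mus, shared_sigma):
    """'Shift' arrangement: each dataset has its own mu; sigma is common."""
    return [LaterParams(mu=mu, sigma=shared_sigma) for mu in mus]


def swivel_params(mus, shared_k):
    """'Swivel' arrangement: each dataset has its own mu; k is common."""
    return [LaterParams(mu=mu, sigma=mu / shared_k) for mu in mus]


mus = [4.0, 5.0, 6.0]  # illustrative values, one per dataset
shift = shift_params(mus, shared_sigma=1.0)
swivel = swivel_params(mus, shared_k=5.0)

for name, params in (("shift", shift), ("swivel", swivel)):
    for p in params:
        print(f"{name}: mu={p.mu:.2f} sigma={p.sigma:.2f} k={p.k:.2f}")
```

Under the shift arrangement the datasets differ in `k` but not in `sigma`, whereas under the swivel arrangement they differ in `sigma` but not in `k`.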

Calculating log-likelihoods

To compare the models, we first need to compute their pointwise log-likelihoods, which pm.sample does not store by default.

with shift_model:
    shift_idata = pm.compute_log_likelihood(idata=shift_idata)
with swivel_model:
    swivel_idata = pm.compute_log_likelihood(idata=swivel_idata)

However, this computes the log-likelihoods separately for each dataset. We can use the helper function pylater.combine_multiple_likelihoods to gather them together:

(shift_idata, swivel_idata) = (
    pylater.combine_multiple_likelihoods(idata=idata)
    for idata in (shift_idata, swivel_idata)
)
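As a rough illustration of what ‘gathering together’ means, the sketch below concatenates per-dataset pointwise log-likelihoods along the observation axis, so that each posterior draw has one entry per observation across all of the datasets. This is only an assumption about the general idea, not the implementation of pylater.combine_multiple_likelihoods; the variable names, draws, and observations are all made up, and a plain normal density stands in for the model's actual likelihood.

```python
import math


def normal_logpdf(x, mu, sigma):
    """Log density of a normal distribution (stand-in likelihood)."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)


# Hypothetical posterior draws of (mu, sigma)
draws = [(4.9, 1.0), (5.1, 1.1)]

# Hypothetical observations for two datasets
observations = {"obs_a": [4.5, 5.2, 5.0], "obs_b": [6.1, 5.8]}

# Per-dataset pointwise log-likelihoods: one row per draw, one column per observation
per_dataset = {
    name: [[normal_logpdf(x, mu, sigma) for x in values] for (mu, sigma) in draws]
    for name, values in observations.items()
}

# Combined log-likelihood: for each draw, join the rows from every dataset
combined = [
    [value for rows in per_dataset.values() for value in rows[i_draw]]
    for i_draw in range(len(draws))
]

print(len(combined), "draws x", len(combined[0]), "observations")
```

Having a single set of pointwise values is what allows one variable name to be passed to the comparison function.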

Comparing models

We can then use the ArviZ function az.compare to do the model comparison:

comparison = az.compare(
    compare_dict={"shift": shift_idata, "swivel": swivel_idata},
    var_name="obs",
)
comparison
        rank       elpd_loo      p_loo   elpd_diff        weight          se        dse  warning  scale
swivel     0  -36622.945463  13.898394    0.000000  1.000000e+00  135.994097   0.000000    False    log
shift      1  -36813.623295  11.367739  190.677832  1.609408e-09  136.091176  14.423685    False    log

See the documentation for az.compare for details on interpreting this output.
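As a quick sanity check on the table, the short sketch below reproduces the elpd_diff column from the reported elpd_loo values and expresses the difference in units of its standard error (dse). The numbers are copied from the output above; the variable names are our own.

```python
# Values copied from the comparison table above
elpd_loo = {"swivel": -36622.945463, "shift": -36813.623295}
dse_shift = 14.423685  # standard error of the shift model's elpd difference

# The top-ranked model is the one with the highest elpd_loo (on the log scale)
best = max(elpd_loo, key=elpd_loo.get)

# elpd_diff is each model's shortfall relative to the top-ranked model
elpd_diff = {name: elpd_loo[best] - value for name, value in elpd_loo.items()}

# How many standard errors separate the two models
diff_in_se = elpd_diff["shift"] / dse_shift

print(f"best model: {best}")
print(f"elpd_diff for shift: {elpd_diff['shift']:.6f}")
print(f"difference / dse: {diff_in_se:.1f}")
```

The difference is many times larger than its standard error, which is consistent with the swivel model receiving essentially all of the weight.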

Last updated: Thu May 23 2024

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.24.0

arviz  : 0.18.0
pymc   : 5.15.0
pylater: 0.1