Model comparison
Here, we look at how to fit multiple datasets simultaneously, with some parameters shared across them, and how to compare models that use different sharing arrangements.
First, we will import the necessary packages:
import pymc as pm
import arviz as az
import pylater
import pylater.data
Datasets
We will use all of the conditions from participant ‘b’ in the example data provided in pylater:
datasets = [
    dataset
    for dataset in pylater.data.cw1995.values()
    if dataset.name.startswith("b")
]
A ‘shift’ model
Our first model uses a ‘shift’ sharing arrangement, in which all of the datasets share a common standard deviation (\(\sigma\)) parameter.
shift_model = pylater.build_default_model(datasets=datasets, share_type="shift")
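To see why a shared \(\sigma\) is called a ‘shift’, assume the standard LATER formulation in which the reciprocal of each latency \(T\) is normally distributed with mean \(\mu\) and standard deviation \(\sigma\) (this parameterisation is an assumption here, not read from pylater's code), and treat the probability of a non-positive reciprocal latency as negligible. The cumulative latency distribution is then a straight line on a reciprobit plot:

```latex
\begin{aligned}
P(T \le t) &= P\!\left(\frac{1}{T} \ge \frac{1}{t}\right)
            = \Phi\!\left(\frac{\mu - 1/t}{\sigma}\right), \qquad t > 0, \\
\Phi^{-1}\!\big(P(T \le t)\big) &= \frac{x + \mu}{\sigma}
            = \frac{1}{\sigma}\,x + \frac{\mu}{\sigma},
            \qquad x = -\frac{1}{t}.
\end{aligned}
```

Each condition is therefore a line with slope \(1/\sigma\) and intercept \(\mu/\sigma\). With \(\sigma\) shared, every condition has the same slope, so the lines are parallel and are displaced horizontally from one another by their differences in \(\mu\).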
We then fit the model:
with shift_model:
    shift_idata = pm.sample()
Note that there is a lot of data here, so sampling can take a while. In this example, we use fewer samples than is typical so that it runs faster.
A ‘swivel’ model
The second model uses a ‘swivel’ sharing arrangement, in which all of the datasets share a common intercept (\(k\)) parameter.
swivel_model = pylater.build_default_model(datasets=datasets, share_type="swivel")
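A minimal numerical sketch of the ‘swivel’ idea follows. It assumes the standard LATER formulation (reciprocal latency distributed as Normal(\(\mu\), \(\sigma\))), under which the reciprobit line has slope \(1/\sigma\) and intercept \(\mu/\sigma\); we assume that this intercept is what pylater calls \(k\). Holding \(k\) fixed while \(\sigma\) varies (so \(\mu = k\sigma\)) makes the lines for different conditions pivot about a common point. The function `reciprobit_probit` and all parameter values are hypothetical, chosen only for illustration.

```python
import numpy as np
from scipy.stats import norm


def reciprobit_probit(mu, sigma, t):
    """Probit of the cumulative latency distribution at latencies t (seconds).

    Assumes reciprocal latency 1/T ~ Normal(mu, sigma), so that
    P(T <= t) = Phi((mu - 1/t) / sigma) for t > 0.
    """
    cumulative_prob = norm.cdf((mu - 1.0 / t) / sigma)
    return norm.ppf(cumulative_prob)


t = np.array([0.15, 0.2, 0.3, 0.5])  # illustrative latencies in seconds
x = -1.0 / t  # horizontal axis of a reciprobit plot

k = 2.0  # hypothetical shared intercept
for sigma in (1.0, 1.5, 2.0):
    mu = k * sigma  # hold mu / sigma fixed at k while sigma varies
    y = reciprobit_probit(mu, sigma, t)
    slope, intercept = np.polyfit(x, y, deg=1)
    print(f"sigma={sigma}: slope={slope:.3f}, intercept={intercept:.3f}")
```

The fitted intercept stays at \(k\) for every \(\sigma\) while the slope \(1/\sigma\) changes, so the lines swivel about \(x = 0\), which corresponds to infinitely long latencies.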
We then also fit this model:
with swivel_model:
    swivel_idata = pm.sample()
Calculating log-likelihoods
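To compare the models, we first need the pointwise log-likelihood of each observation under each posterior draw. PyMC does not store these by default when sampling, so we compute them: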
with shift_model:
    shift_idata = pm.compute_log_likelihood(idata=shift_idata)
with swivel_model:
    swivel_idata = pm.compute_log_likelihood(idata=swivel_idata)
However, this computes the log-likelihoods separately for each dataset. We can use the helper function pylater.combine_multiple_likelihoods to combine them across datasets:
shift_idata = pylater.combine_multiple_likelihoods(idata=shift_idata)
swivel_idata = pylater.combine_multiple_likelihoods(idata=swivel_idata)
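pylater's implementation of combine_multiple_likelihoods is not shown here, so the following is only a hypothetical NumPy/SciPy sketch of the general idea. For each dataset, the pointwise log-likelihood holds one log-density per posterior draw and per observation, laid out as (chain, draw, observation). Combining datasets then amounts to concatenating these arrays along the observation axis, so that every trial from every condition contributes to a single variable. The model (reciprocal latency distributed as Normal(mu, sigma)), the data, the posterior draws, and the helper `pointwise_log_lik` are all made up for illustration.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
n_chains, n_draws = 4, 250


def pointwise_log_lik(recip_latency, mu_draws, sigma_draws):
    """Log-density of every observation under every posterior draw.

    Assumes a hypothetical model with recip_latency ~ Normal(mu, sigma).
    Draw arrays have shape (chain, draw); the result has shape
    (chain, draw, n_obs).
    """
    return norm.logpdf(
        recip_latency,
        loc=mu_draws[..., np.newaxis],
        scale=sigma_draws[..., np.newaxis],
    )


# Three hypothetical conditions with different trial counts and means,
# all sharing one set of sigma draws (a 'shift'-style arrangement).
sigma_draws = np.abs(rng.normal(1.0, 0.05, size=(n_chains, n_draws)))
per_dataset = []
for true_mu, n_obs in ((4.0, 40), (5.0, 55), (6.0, 62)):
    data = rng.normal(true_mu, 1.0, size=n_obs)
    mu_draws = rng.normal(true_mu, 0.1, size=(n_chains, n_draws))
    per_dataset.append(pointwise_log_lik(data, mu_draws, sigma_draws))

# Combine: concatenate along the observation axis (chains and draws must match).
combined = np.concatenate(per_dataset, axis=-1)
print([a.shape for a in per_dataset])  # [(4, 250, 40), (4, 250, 55), (4, 250, 62)]
print(combined.shape)  # (4, 250, 157)
```

Because leave-one-out comparison statistics are sums over observations, a single combined variable lets each model be scored on all trials at once rather than one dataset at a time.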
Comparing models
We can then use the ArviZ function az.compare to compare the models:
comparison = az.compare(
    compare_dict={"shift": shift_idata, "swivel": swivel_idata},
    var_name="obs",
)
comparison
| model | rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale |
|---|---|---|---|---|---|---|---|---|---|
| swivel | 0 | -36622.945463 | 13.898394 | 0.000000 | 1.000000e+00 | 135.994097 | 0.000000 | False | log |
| shift | 1 | -36813.623295 | 11.367739 | 190.677832 | 1.609408e-09 | 136.091176 | 14.423685 | False | log |
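Here, the ‘swivel’ model is ranked first. Its elpd_loo (the estimated expected log pointwise predictive density, where higher is better) exceeds that of the ‘shift’ model by about 190.7 (elpd_diff), which is roughly 13 times the standard error of that difference (dse ≈ 14.4), and it receives almost all of the weight. See the documentation for az.compare for details on interpreting this output.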
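To make the elpd_loo, se, elpd_diff and dse columns less opaque, the following sketch computes analogous quantities from pointwise log-likelihood arrays. It uses plain importance-sampling LOO rather than the Pareto-smoothed importance sampling (PSIS) that az.compare uses by default, so it is a simplified stand-in rather than ArviZ's implementation. The standard errors use the usual formulas \(\mathrm{se} = \sqrt{n\,\mathrm{Var}(\mathrm{elpd}_i)}\) for one model and \(\mathrm{dse} = \sqrt{n\,\mathrm{Var}(\mathrm{elpd}^A_i - \mathrm{elpd}^B_i)}\) for a difference. The data, the two "posteriors", and the helpers `is_loo_pointwise` and `fake_log_lik` are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def is_loo_pointwise(log_lik):
    """Plain importance-sampling LOO estimate of elpd for each observation.

    log_lik has shape (chain, draw, obs). Draws are pooled, and
    p(y_i | y_-i) is estimated by the harmonic mean of p(y_i | theta_s).
    Unlike ArviZ, no Pareto smoothing is applied to the weights.
    """
    ll = log_lik.reshape(-1, log_lik.shape[-1])  # (samples, obs)
    return -(logsumexp(-ll, axis=0) - np.log(ll.shape[0]))


rng = np.random.default_rng(2)
y = rng.normal(5.0, 1.0, size=200)  # hypothetical observations


def fake_log_lik(mu_centre):
    """Pointwise log-likelihood for hypothetical posterior draws of mu."""
    mu_draws = rng.normal(mu_centre, 0.07, size=(4, 250, 1))
    return norm.logpdf(y, loc=mu_draws, scale=1.0)


log_lik_a = fake_log_lik(5.0)  # posterior near the data-generating mean
log_lik_b = fake_log_lik(6.0)  # posterior offset from it
elpd_i_a = is_loo_pointwise(log_lik_a)
elpd_i_b = is_loo_pointwise(log_lik_b)

n = y.size
se_a = np.sqrt(n * np.var(elpd_i_a))  # standard error of model A's elpd_loo
diff_i = elpd_i_a - elpd_i_b  # pointwise difference between the models
elpd_diff = diff_i.sum()
dse = np.sqrt(n * np.var(diff_i))  # standard error of the difference
print(f"elpd_loo: A={elpd_i_a.sum():.1f}, B={elpd_i_b.sum():.1f} (se A={se_a:.1f})")
print(f"elpd_diff={elpd_diff:.1f}, dse={dse:.1f}, ratio={elpd_diff / dse:.1f}")
```

The ratio of elpd_diff to dse gives a rough sense of how clearly one model is preferred. With these made-up inputs, model A should come out ahead, because model B's posterior is centred away from the mean that generated the data.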
Last updated: Thu May 23 2024
Python implementation: CPython
Python version : 3.10.14
IPython version : 8.24.0
arviz : 0.18.0
pymc : 5.15.0
pylater: 0.1