Draft
Changes from all commits (105 commits)
5e3ba87
add DDM task wrapped from DiffModel.jl.
janfb Apr 26, 2021
091b971
move parameter batch for loop to julia, add potential fun.
janfb Apr 26, 2021
77a2bcf
reference posterior sampling for ddm.
janfb Apr 27, 2021
7af30fb
adapt mcmc to work with either potential function or model.
janfb Apr 27, 2021
d9fe569
refactor log to logprob, loop over parameter batch.
janfb Apr 28, 2021
1c3f428
add transform, get posterior samples from grid.
janfb Apr 28, 2021
2694a5b
formatting
janfb Apr 29, 2021
1f4c276
fix summation bug and get reference via mcmc
janfb Apr 29, 2021
386d0ae
introduce 4 param version via asymmetric bounds.
janfb Apr 29, 2021
409dc8d
change to 3 parameters, removing ndt.
janfb Apr 30, 2021
b91e54c
add seeding to Julia wrapper.
janfb May 4, 2021
b0b8979
add seeding, add repeated seeds to compare refs.
janfb May 4, 2021
37d4eab
added analogous python functions for ddm simulator and likelihood usi…
rdgao May 4, 2021
e4d5929
add seeding, add repeated seeds to compare refs.
janfb May 4, 2021
3afdc55
add refs for 5 obs, 4 ref-methods.
janfb May 4, 2021
34864e6
custom _setup for ddm to allow seeds with different num_trials.
janfb May 5, 2021
88003d9
references for 5 observations with 4 increasing num_trials.
janfb May 5, 2021
a52db2a
add mcmc reference for 10 obs, 3d, 1024 trials.
janfb May 12, 2021
8d4e498
change references to single trial.
janfb May 14, 2021
5fdd646
change reference to obs 1 for 4 different #trials.
janfb May 14, 2021
b04b449
encode down choices as negative rts.
janfb May 17, 2021
c6e8e2e
adapt to new rt encoding, add hyperparams.
janfb May 17, 2021
b52ef02
fix pairplot imports.
janfb May 17, 2021
34c7148
add moments metrics.
janfb May 17, 2021
911ec20
add references with new rts coding.
janfb May 17, 2021
1fe72b3
add args to snle.
janfb May 19, 2021
30652a6
change reference to 10 obs with 1024 trials.
janfb May 19, 2021
253b467
refactor task and add 4 parameter version with ndt.
janfb May 20, 2021
e19e522
fix seeds and bug for 4 param model.
janfb May 21, 2021
946676f
add new 4 param refs.
janfb May 21, 2021
03f6544
run script with potential function for LAN.
janfb May 27, 2021
512c14f
fix prior log prob bug, add lower bound.
janfb Jun 9, 2021
5cf0c5f
add potential fn to utils.
janfb Jun 9, 2021
a2bf02d
add lan scripts and models.
janfb Jun 9, 2021
2126123
add lan budget in return
janfb Jun 9, 2021
d04ce31
single trial reference.
janfb Jun 11, 2021
8863ae5
add custom mcmc and train in untransformed space.
janfb Jun 15, 2021
f842c53
single trial ref.
janfb Jun 15, 2021
bb96de6
100 trials ref.
janfb Jun 15, 2021
2e2d4ee
adapt scripts.
janfb Jun 16, 2021
88580ed
add julia task for double check.
janfb Jun 16, 2021
61866c4
10 trial ref.
janfb Jun 16, 2021
9cbf4e6
100 trial ref.
janfb Jun 16, 2021
4d73769
1000 trial ref.
janfb Jun 17, 2021
4ec9fca
normalize moments with posterior var.
janfb Jun 18, 2021
6699a55
load and save density estimators.
janfb Jun 18, 2021
a68d533
apply rt>tau trick for lan. add mixed-model to utils.
janfb Jun 18, 2021
29d618f
fix bugs in pfs and lps.
janfb Jun 18, 2021
235bd78
undo loading of pretrained nets.
janfb Jun 21, 2021
e73e04a
100 trial ref.
janfb Jun 16, 2021
6f2c8d4
add mixed model as algorithm.
janfb Jun 28, 2021
1cfdeae
add mm as pretrained alg, refactor.
janfb Jul 5, 2021
02dadf2
10 references, 1, 10, 100, 1000 trials each.
janfb Jul 5, 2021
78aae4d
refactor and documentation.
janfb Jul 8, 2021
34bd050
depend on sbi branch with changes for ddm.
janfb Jul 8, 2021
3594ab6
add notebook for LAN-NLE likelihood comparison and benchmark instruct…
janfb Jul 8, 2021
8859400
apply ladj to posterior potential.
janfb Jul 14, 2021
c6e6edf
update references.
janfb Jul 14, 2021
8728b73
update potential functions, use original prior, refactor.
janfb Jul 14, 2021
c489be1
Update README.md
janfb Aug 3, 2021
366e73a
precalculate logprob for 0 1 choices to save time
janfb Aug 25, 2021
9879311
bugfix lower bound variable name.
janfb Sep 8, 2021
b4039fb
script for fig 5.
janfb Sep 10, 2021
d348791
change to 10 1024-trial obs with refs.
janfb Sep 11, 2021
473bed8
add 100 1024-trial refs.
janfb Sep 13, 2021
147e51b
add class full ddm in julia.
janfb Sep 11, 2021
1b75e38
hard code to nle model index 315.
janfb Sep 13, 2021
8d792cf
remove 1024 trials refs.
janfb Sep 28, 2021
14aafc0
switch to 10 mcmc chains.
janfb Oct 6, 2021
3916929
notebook for paper figures.
janfb Oct 6, 2021
726c13b
remove 100 100-trial refs.
janfb Oct 12, 2021
5942f27
10 references, 1, 10, 100, 1000 trials each.
janfb Jul 5, 2021
a078720
update notebook for paper figures.
janfb Oct 18, 2021
336324a
remove old refs and replace by 100x 1 and 10 trial refs.
janfb Oct 19, 2021
4265c7b
97 refs for 1 and 10 trials.
janfb Oct 21, 2021
af90af3
replace previous refs with 100x 1,10,100,1000 trial refs.
janfb Oct 21, 2021
dd0fba3
add new refs, 2 missing.
janfb Oct 22, 2021
61a21c6
adapt task to produce 1,10,100,1000 x 100 refs.
janfb Oct 22, 2021
d506ac0
change mean error to normalized by std.
janfb Oct 22, 2021
01c69b2
adapt npe script to model index 315_2.
janfb Oct 22, 2021
14c55c4
update figures nb and pretrained script.
janfb Nov 5, 2021
be38e0a
update figures
janfb Nov 11, 2021
207d3c2
update figure nb.
janfb Dec 9, 2021
b963949
add collapsing bound DDM variants.
janfb Jan 19, 2022
c735193
remove dependency on sbi-ddm branch.
janfb Jan 23, 2022
0cdbe40
remove python ddm code.
janfb Jan 23, 2022
5ee3830
update figures.
janfb Feb 22, 2022
4a3c276
fix sir import
janfb Nov 29, 2022
600e619
add ddm test
janfb Feb 14, 2023
1900d61
add inference test/
janfb Feb 14, 2023
ef60b91
update ddm inference test.
janfb Feb 14, 2023
065ff8a
refactor to new sbi api.
janfb Oct 25, 2022
aadf4a2
refactor sl run script.
janfb Oct 25, 2022
9f09b08
refactor sbi tests.
janfb Nov 8, 2022
d1fce9b
fix for snle and snre.
janfb Nov 10, 2022
8a36fd9
feedback
janfb Dec 2, 2022
127e132
update run scripts.
janfb Feb 14, 2023
e000a68
fix typing
janfb Feb 14, 2023
d197e51
fix zscore args, update snpe for iid data.
janfb Feb 14, 2023
4899bd1
fix iid snpe script.
janfb Feb 20, 2023
a3ed989
fix julia init
janfb Feb 21, 2023
94e3aeb
add mnle run fun.
janfb Feb 21, 2023
611f19e
update ddm 2d mapping.
janfb Feb 24, 2023
ccf9e76
update run script.
janfb Feb 24, 2023
acddafc
update setup.py
janfb Feb 24, 2023
3,090 changes: 3,090 additions & 0 deletions lan_nle_comparison/LAN-NLE-Figures.ipynb

Large diffs are not rendered by default.

591 changes: 591 additions & 0 deletions lan_nle_comparison/LAN-NLE-Likelihood-Comparison.ipynb

Large diffs are not rendered by default.

72 changes: 72 additions & 0 deletions lan_nle_comparison/README.md
@@ -0,0 +1,72 @@
# Comparison between LANs and NLE on the simple drift-diffusion model

This is a short demonstration of how to reproduce the comparison between LANs (likelihood approximation networks) and NLE (neural likelihood estimation).

We perform the comparison in (log) likelihood space, comparing the synthetic likelihoods
of LAN and NLE against the analytic likelihoods obtained from https://github.com/DrugowitschLab/DiffModels.jl.
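
A minimal sketch of that comparison, using the DDM task from this branch. The network log-likelihoods (`ll_lan`, `ll_nle`) are computed in the notebook from the loaded models and are treated as given here; `posterior=False` is an assumption for obtaining the likelihood without the prior term (the surrounding signature mirrors the call in `sbibm/algorithms/lan/julia.py` below):

```python
import sbibm

task = sbibm.get_task("ddm")  # DDM task added on the ddm-task branch
prior = task.get_prior_dist()
observation = task.get_observation(1)

# Analytic reference log-likelihood via DiffModels.jl.
log_prob_ref = task._get_log_prob_fn(
    None,
    observation,
    "experimental",
    posterior=False,  # assumption: drop the prior term, likelihood only
    automatic_transforms_enabled=False,
    l_lower_bound=1e-7,
)

theta = prior.sample((1000,))
ll_ref = log_prob_ref(theta)

# With ll_lan / ll_nle evaluated on the same theta, the comparison reduces
# to differences in log-space, e.g.:
# mae_lan = (ll_lan - ll_ref).abs().mean()
# mae_nle = (ll_nle - ll_ref).abs().mean()
```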

Additionally, we use MCMC via slice sampling to compare the two approaches in posterior space.
For this comparison we have added the DDM as a task to `sbibm`, a framework developed for
benchmarking simulation-based inference algorithms.
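
Once the branch is installed, the task and its reference posteriors are available through the standard `sbibm` interface; a sketch (observation index and trial count are illustrative):

```python
import sbibm

task = sbibm.get_task("ddm")
prior = task.get_prior_dist()
simulator = task.get_simulator(num_trials=10)  # trial count is configurable

theta_o = prior.sample((1,))
x_o = simulator(theta_o)

# Reference posterior samples shipped with the task, computed via MCMC
# on the analytic DiffModels.jl likelihood.
reference_samples = task.get_reference_posterior_samples(num_observation=1)
```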

In general, the code relies on three repositories: [`sbi`](https://github.com/mackelab/sbi) for using NLE,
[`sbibm`](https://github.com/sbi-benchmark/sbibm) for simulating the data and loading the LAN Keras weights,
and [`benchmarking-results`](https://github.com/sbi-benchmark/results/tree/main/benchmarking_sbi) for running the benchmark.

## Comparison in likelihood space
For a demo of the likelihood comparison, see the Jupyter notebook in this folder. To
execute the notebook locally, perform the steps outlined below.

```bash
# clone repo
git clone https://github.com/mackelab/sbibm.git
# switch to branch
cd sbibm
git checkout ddm-task
# install locally (e.g., in a new conda env)
pip install -e .
# install the missing nflows dependency
pip install UMNN
# open the notebook at /lan_nle_comparison
```

## Comparison in posterior space using `sbibm`

For a general overview of the benchmarking suite, see https://sbi-benchmark.github.io.

To run the benchmark on your local machine, please follow the steps below.

- **optional**: create and activate a new conda environment
```bash
conda create -n ddmtest python=3.8
conda activate ddmtest
```

- clone and install the `benchmarking-results` repo from https://github.com/sbi-benchmark/results/tree/main/benchmarking_sbi

```bash
git clone https://github.com/mackelab/results.git
cd results/benchmarking_sbi
git checkout ddm
pip install -r requirements.txt
cd ../..
```

- clone and install the `sbibm` repo on the `ddm-task` branch

```bash
git clone https://github.com/mackelab/sbibm.git
cd sbibm
git checkout ddm-task
pip install -e .
cd ..
```

- run the benchmark

```bash
cd results/benchmarking_sbi
python run.py task=ddm task.num_observation=1 algorithm=lan
```

More details about how to run the benchmark can be found at https://github.com/sbi-benchmark/results/tree/main/benchmarking_sbi.
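
The algorithm entry points added in this PR (see `sbibm/algorithms/lan/` below) can also be called directly from Python; a sketch, with illustrative arguments:

```python
import sbibm
from sbibm.algorithms.lan.lan import run as run_lan
from sbibm.algorithms.lan.julia import run as run_julia

task = sbibm.get_task("ddm")

# Posterior samples from the pretrained LAN via slice-sampling MCMC.
# num_simulations is not consumed by these pretrained runners.
lan_samples, lan_budget, _ = run_lan(
    task=task, num_samples=10_000, num_simulations=0, num_observation=1
)

# Reference samples from the analytic Julia likelihood, same MCMC machinery.
ref_samples, _, _ = run_julia(
    task=task, num_samples=10_000, num_simulations=0, num_observation=1
)
```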
Binary file added lan_nle_comparison/ddm_transforms.p
Binary file not shown.
106 changes: 106 additions & 0 deletions lan_nle_comparison/reproduce_figure_5.py
@@ -0,0 +1,106 @@
import keras
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import pickle
import sbibm
import torch
import time
from joblib import Parallel, delayed

from sbibm.tasks.ddm.utils import run_mcmc, LANPotentialFunctionProvider
from sbibm.algorithms.sbi.utils import wrap_prior_dist


# Networks trained on KDE and analytic likelihood targets for the 4-param DDM.
lan_kde_path = "../sbibm/algorithms/lan/lan_pretrained/model_final_ddm.h5"
lan_ana_path = "../sbibm/algorithms/lan/lan_pretrained/model_final_ddm_analytic.h5"
lan_kde = keras.models.load_model(lan_kde_path, compile=False)
lan_ana = keras.models.load_model(lan_ana_path, compile=False)

# Load pretrained NLE model.
with open("../sbibm/algorithms/lan/nle_pretrained/mm_688_4.p", "rb") as fh:
    nle = pickle.load(fh)

num_workers = 80
m = num_workers
n = 1024
l_lower_bound = 1e-7
num_samples = 10000


task = sbibm.get_task("ddm")
prior = task.get_prior_dist()
simulator = task.get_simulator(num_trials=n)  # The wrapper passes the seed to Julia.

thos = prior.sample((m,))
xos = simulator(thos)  # m observations with n trials each.

mcmc_parameters = {
    "num_chains": 100,
    "thin": 10,
    "warmup_steps": 100,
    "init_strategy": "sir",
    "sir_batch_size": 100,
    "sir_num_batches": 1000,
}

with open("ddm_transforms.p", "rb") as fh:
    transforms = pickle.load(fh)["transforms"]
prior_transformed = wrap_prior_dist(prior, transforms)


def local_run(xi):
    tic = time.time()
    # Get potential function for the mixed model (NLE).
    potential_fn_mm = nle.get_potential_fn(
        xi.reshape(-1, 1),
        transforms,
        # Pass untransformed prior; corrected internally with ladj.
        prior=prior,
        ll_lower_bound=np.log(l_lower_bound),
    )

    # Run MCMC in transformed space.
    transformed_samples = run_mcmc(
        prior=prior_transformed,
        potential_fn=potential_fn_mm,
        mcmc_parameters=mcmc_parameters,
        num_samples=num_samples,
    )

    nle_samples = transforms.inv(transformed_samples)
    nle_time = time.time() - tic

    tic = time.time()
    # Use the LAN potential function provider refactored from the sbi toolbox.
    potential_fn_lan = LANPotentialFunctionProvider(transforms, lan_kde, l_lower_bound)

    lan_transformed_samples = run_mcmc(
        prior=prior_transformed,
        # Pass the original prior to the provider; the potential is corrected with ladj.
        potential_fn=potential_fn_lan(
            prior=prior,
            sbi_net=None,
            x=xi.reshape(-1, 1),
            mcmc_method="slice_np_vectorized",
        ),
        mcmc_parameters=mcmc_parameters,
        num_samples=num_samples,
    )

    lan_samples = transforms.inv(lan_transformed_samples)
    lan_time = time.time() - tic

    return nle_samples, lan_samples, nle_time, lan_time


# Run the per-observation MCMC in parallel.
results = Parallel(n_jobs=num_workers)(delayed(local_run)(_) for _ in xos)

with open("figure_5_results.p", "wb") as fh:
    pickle.dump(dict(thos=thos, xos=xos, results=results), fh)

print("Done")
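
The saved pickle can then be consumed by the figure notebook; unpacking follows directly from the structure written above:

```python
import pickle

with open("figure_5_results.p", "rb") as fh:
    d = pickle.load(fh)

thos, xos = d["thos"], d["xos"]
for nle_samples, lan_samples, nle_time, lan_time in d["results"]:
    print(nle_samples.shape, lan_samples.shape, nle_time, lan_time)
```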
89 changes: 89 additions & 0 deletions sbibm/algorithms/lan/julia.py
@@ -0,0 +1,89 @@
import logging
from typing import Any, Dict, Optional, Tuple

import torch

from sbibm.tasks.task import Task

from sbibm.algorithms.sbi.utils import wrap_prior_dist
from sbibm.tasks.ddm.utils import run_mcmc


def run(
    task: Task,
    num_samples: int,
    num_simulations: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    automatic_transforms_enabled: bool = True,
    mcmc_method: str = "slice_np_vectorized",
    mcmc_parameters: Dict[str, Any] = {
        "num_chains": 100,
        "thin": 10,
        "warmup_steps": 100,
        "init_strategy": "sir",
        "sir_batch_size": 1000,
        "sir_num_batches": 100,
    },
    l_lower_bound: float = 1e-7,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
    """Runs MCMC with the analytical DDM likelihood.

    Args:
        task: Task instance, here DDM.
        num_samples: Number of samples to generate from the posterior.
        num_simulations: Simulation budget.
        num_observation: Observation number to load, alternative to `observation`.
        observation: Observation, alternative to `num_observation`.
        automatic_transforms_enabled: Whether to enable automatic transforms.
        mcmc_method: MCMC method.
        mcmc_parameters: MCMC parameters.
        l_lower_bound: Lower bound for single-trial likelihood evaluations.

    Returns:
        Samples from the posterior, number of simulator calls, log probability
        of true parameters if computable.
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)
    assert (
        task.name == "ddm"
    ), "This algorithm works only for the DDM task as it uses its analytical likelihood."

    log = logging.getLogger(__name__)
    log.info("Running MCMC with analytical likelihoods from the Julia package.")

    prior = task.get_prior_dist()
    if observation is None:
        observation = task.get_observation(num_observation)

    # Maybe transform to unconstrained parameter space for MCMC.
    transforms = task._get_transforms(automatic_transforms_enabled)["parameters"]
    if automatic_transforms_enabled:
        prior_transformed = wrap_prior_dist(prior, transforms)
    else:
        prior_transformed = prior

    # Log-prob function wrapping the analytical likelihood from DiffModels.jl.
    llj = task._get_log_prob_fn(
        None,
        observation,
        "experimental",
        posterior=True,
        automatic_transforms_enabled=automatic_transforms_enabled,
        l_lower_bound=l_lower_bound,
    )

    def potential_fn_julia(theta):
        theta = torch.as_tensor(theta, dtype=torch.float32)
        return llj(theta)

    # Run MCMC in transformed space.
    samples = run_mcmc(
        prior=prior_transformed,
        potential_fn=potential_fn_julia,
        mcmc_parameters=mcmc_parameters,
        num_samples=num_samples,
    )

    # Return untransformed samples.
    return transforms.inv(samples), num_simulations, None
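
The ladj corrections mentioned in the comments follow one pattern: any log-likelihood callable can be turned into a transformed-space potential by adding the prior log-prob and the log-abs-det-Jacobian of the inverse transform. A minimal sketch, independent of the sbibm internals (`make_potential_fn` is a hypothetical helper, not part of this PR):

```python
import torch
from torch.distributions import Distribution
from torch.distributions.transforms import Transform


def make_potential_fn(log_likelihood, prior: Distribution, transforms: Transform):
    """Potential over transformed parameters: likelihood + prior + ladj correction."""

    def potential_fn(theta_transformed: torch.Tensor) -> torch.Tensor:
        # Map back to the original (constrained) parameter space.
        theta = transforms.inv(theta_transformed)
        # log |det d(theta) / d(theta_transformed)| for the change of variables.
        ladj = transforms.inv.log_abs_det_jacobian(theta_transformed, theta)
        return log_likelihood(theta) + prior.log_prob(theta) + ladj

    return potential_fn
```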
98 changes: 98 additions & 0 deletions sbibm/algorithms/lan/lan.py
@@ -0,0 +1,98 @@
import logging
import pathlib
from typing import Any, Dict, Optional, Tuple

import keras
import torch
from sbibm.tasks.task import Task
from sbibm.tasks.ddm.utils import LANPotentialFunctionProvider, run_mcmc

from sbibm.algorithms.sbi.utils import wrap_prior_dist


def run(
    task: Task,
    num_samples: int,
    num_simulations: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    automatic_transforms_enabled: bool = True,
    mcmc_method: str = "slice_np_vectorized",
    mcmc_parameters: Dict[str, Any] = {
        "num_chains": 10,
        "thin": 10,
        "warmup_steps": 100,
        "init_strategy": "sir",
        "sir_batch_size": 1000,
        "sir_num_batches": 100,
    },
    l_lower_bound: float = 1e-7,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
    """Runs a pretrained LAN (trained on KDE likelihood targets).

    Args:
        task: Task instance.
        num_samples: Number of samples to generate from the posterior.
        num_simulations: Simulation budget.
        num_observation: Observation number to load, alternative to `observation`.
        observation: Observation, alternative to `num_observation`.
        automatic_transforms_enabled: Whether to enable automatic transforms.
        mcmc_method: MCMC method.
        mcmc_parameters: MCMC parameters.
        l_lower_bound: Lower bound for single-trial likelihood evaluations.

    Returns:
        Samples from the posterior, number of simulator calls, log probability
        of true parameters if computable.
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)
    assert (
        task.name == "ddm"
    ), "This algorithm works only for the DDM task as it uses its analytical likelihood."

    log = logging.getLogger(__name__)
    log.info("Running LAN pretrained with KDE targets.")
    # LAN simulation budget from the paper: 1.5e6 parameter settings x 1e5 simulations each.
    lan_budget = int(1e5 * 1.5e6)

    prior = task.get_prior_dist()
    if observation is None:
        observation = task.get_observation(num_observation)

    # Maybe transform to unconstrained parameter space for MCMC.
    transforms = task._get_transforms(automatic_transforms_enabled)["parameters"]
    if automatic_transforms_enabled:
        prior_transformed = wrap_prior_dist(prior, transforms)
    else:
        prior_transformed = prior

    num_trials = observation.shape[1]
    # sbi needs the trials in the first dimension.
    observation_sbi = observation.reshape(num_trials, 1)

    # Network trained on KDE likelihood for the 4-param DDM.
    lan_kde_path = (
        f"{pathlib.Path(__file__).parent.resolve()}/lan_pretrained/model_final_ddm.h5"
    )
    # Load the weights as a Keras model.
    lan_kde = keras.models.load_model(lan_kde_path, compile=False)

    # Use the LAN potential function provider refactored from the sbi toolbox.
    potential_fn_lan = LANPotentialFunctionProvider(transforms, lan_kde, l_lower_bound)

    samples = run_mcmc(
        prior=prior_transformed,
        # Pass the original prior to the provider; the potential is corrected with ladj.
        potential_fn=potential_fn_lan(
            prior=prior,
            sbi_net=None,
            x=observation_sbi,
            mcmc_method=mcmc_method,
        ),
        mcmc_parameters=mcmc_parameters,
        num_samples=num_samples,
    )

    # Return untransformed samples.
    return transforms.inv(samples), lan_budget, None