From 8cdc9ee79f7ca0201764208fb85a9ab6871c344c Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Mon, 12 Aug 2024 05:36:46 -0700 Subject: [PATCH] Blackjax sampler fix for breaking change / enable progress bar under parallel chain_method (#7453) * remove blackjax pmap warning * use gen_scan_fn * remove labels * retrigger checks * retrigger checks --- pymc/sampling/jax.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index c4d9099b90..c530af8d9a 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -278,15 +278,10 @@ def _one_step(state, xs): return state, (position, stats) progress_bar = adaptation_kwargs.pop("progress_bar", False) - if progress_bar: - from blackjax.progress_bar import progress_bar_scan - - one_step = jax.jit(progress_bar_scan(draws)(_one_step)) - else: - one_step = jax.jit(_one_step) keys = jax.random.split(seed, draws) - _, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys)) + scan_fn = blackjax.progress_bar.gen_scan_fn(draws, progress_bar) + _, (samples, stats) = scan_fn(_one_step, last_state, (jnp.arange(draws), keys)) return samples, stats @@ -365,14 +360,6 @@ def _sample_blackjax_nuts( # Adapted from numpyro if chain_method == "parallel": map_fn = jax.pmap - if progressbar: - import warnings - - warnings.warn( - "BlackJax currently only display progress bar correctly under " - "`chain_method == 'vectorized'`. Setting `progressbar=False`." - ) - progressbar = False elif chain_method == "vectorized": map_fn = jax.vmap else: