diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b275f07d..1e9dff20 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,12 +16,12 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] os: ["ubuntu-latest"] include: - - python-version: "3.8" + - python-version: "3.9" os: "macos-latest" - - python-version: "3.8" + - python-version: "3.9" os: "windows-latest" steps: diff --git a/.gitignore b/.gitignore index a92a8e56..b4f553f2 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ emcee_version.py .tox env .eggs +.coverage.* diff --git a/docs/conf.py b/docs/conf.py index 742d2ea2..b52d97fc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,5 +48,5 @@ "use_repository_button": True, "use_download_button": True, } -jupyter_execute_notebooks = "off" -execution_timeout = -1 +nb_execution_mode = "off" +nb_execution_timeout = -1 diff --git a/src/emcee/state.py b/src/emcee/state.py index e6555f3c..9d3413f6 100644 --- a/src/emcee/state.py +++ b/src/emcee/state.py @@ -44,6 +44,11 @@ def __init__( self.blobs = dc(blobs) self.random_state = dc(random_state) + def __len__(self): + if self.blobs is None: + return 3 + return 4 + def __repr__(self): return "State({0}, log_prob={1}, blobs={2}, random_state={3})".format( self.coords, self.log_prob, self.blobs, self.random_state @@ -55,3 +60,16 @@ def __iter__(self): return iter( (self.coords, self.log_prob, self.random_state, self.blobs) ) + + def __getitem__(self, index): + if index < 0: + return self[len(self) + index] + if index == 0: + return self.coords + elif index == 1: + return self.log_prob + elif index == 2: + return self.random_state + elif index == 3 and self.blobs is not None: + return self.blobs + raise IndexError("Invalid index '{0}'".format(index)) diff --git a/src/emcee/tests/unit/test_state.py b/src/emcee/tests/unit/test_state.py index 9ecaa838..0150eb46 100644 --- a/src/emcee/tests/unit/test_state.py +++ b/src/emcee/tests/unit/test_state.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- import numpy as np +import pytest from emcee import EnsembleSampler from emcee.state import State +def check_rstate(a, b): + assert all(np.allclose(a_, b_) for a_, b_ in zip(a[1:], b[1:])) + + def test_back_compat(seed=1234): np.random.seed(seed) coords = np.random.randn(16, 3) @@ -18,13 +23,13 @@ def test_back_compat(seed=1234): assert np.allclose(coords, c) assert np.allclose(log_prob, l) assert np.allclose(blobs, b) - assert all(np.allclose(a, b) for a, b in zip(rstate[1:], r[1:])) + check_rstate(rstate, r) state = State(coords, log_prob, None, rstate) c, l, r = state assert np.allclose(coords, c) assert np.allclose(log_prob, l) - assert all(np.allclose(a, b) for a, b in zip(rstate[1:], r[1:])) + check_rstate(rstate, r) def test_overwrite(seed=1234): @@ -40,3 +45,28 @@ def ll(x): sampler = EnsembleSampler(nwalkers, 1, ll) sampler.run_mcmc(p0, 10) assert np.allclose(init, p0) + + +def test_indexing(seed=1234): + np.random.seed(seed) + coords = np.random.randn(16, 3) + log_prob = np.random.randn(len(coords)) + blobs = np.random.randn(len(coords)) + rstate = np.random.get_state() + + state = State(coords, log_prob, blobs, rstate) + np.testing.assert_allclose(state[0], state.coords) + np.testing.assert_allclose(state[1], state.log_prob) + check_rstate(state[2], state.random_state) + np.testing.assert_allclose(state[3], state.blobs) + np.testing.assert_allclose(state[-1], state.blobs) + with pytest.raises(IndexError): + state[4] + + state = State(coords, log_prob, random_state=rstate) + np.testing.assert_allclose(state[0], state.coords) + np.testing.assert_allclose(state[1], state.log_prob) + check_rstate(state[2], state.random_state) + check_rstate(state[-1], state.random_state) + with pytest.raises(IndexError): + state[3] diff --git a/tox.ini b/tox.ini index 68e259a2..d759fa63 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,12 @@ [tox] -envlist = py{37,38,39}{,-extras},lint +envlist = py{37,38,39,310}{,-extras},lint [gh-actions] python = 3.7: py37 3.8: py38 3.9: py39-extras + 3.10: py310 [testenv] deps = coverage[toml]