diff --git a/src/sid/seasonality.py b/src/sid/seasonality.py index 95c93ad6..6978055a 100644 --- a/src/sid/seasonality.py +++ b/src/sid/seasonality.py @@ -51,13 +51,12 @@ def prepare_seasonality_factor( "with 'dates' as index and seasonality factors as data." ) - # Make sure the highest multiplier is set to one so that random contacts only - # need to be reduced by the infection probability of the contact model. - for col in factor: - factor[col] = factor[col] / factor[col].max() - if not factor[col].between(0, 1).all(): - raise ValueError( - "The seasonality factors need to lie in the interval [0, 1]." - ) + factor = factor.astype(float) + + for col in factor: + if not factor[col].between(0, 1).all(): + raise ValueError( + "The seasonality factors need to lie in the interval [0, 1]." + ) return factor diff --git a/tests/test_seasonality.py b/tests/test_seasonality.py index f2ae8b03..cd052250 100644 --- a/tests/test_seasonality.py +++ b/tests/test_seasonality.py @@ -49,11 +49,11 @@ def test_simulate_a_simple_model(params, initial_states, tmp_path): index=pd.date_range("2020-01-01", periods=2), data=1, columns=["meet_two_people"], - ), + ).astype(float), {"meet_two_people": {}}, ), pytest.param( - lambda params, dates, seed: pd.Series(index=dates, data=[1, 2, 3]), + lambda params, dates, seed: pd.Series(index=dates, data=[1 / 3, 2 / 3, 1]), None, pd.date_range("2020-01-01", periods=3), does_not_raise(),