Pipeline setup (#83)
Note: The scripts here are not yet functional; some of the Python code is even
malformed. This PR is intended to be fixed up by follow-up PRs right away.

- Read model names and score function names from config
- Do computation for every combination of data set, model, forecast date, and score function (see the sketch below)
- Separate into scripts
- Use a Makefile
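
As a sketch, the combination loop the second bullet describes looks like this (illustrative only, not code from this commit; the stand-in values mirror config_template.yaml below, and the data set name is hypothetical):

import itertools

# Stand-ins for values the real scripts read from the config file:
datasets = ["nis"]
models = ["LinearIncidentUptakeModel"]
forecast_dates = ["2024-01-06", "2024-01-13"]
score_funs = ["mspe", "mean_bias", "eos_abe"]

# One unit of work per (data set, model, forecast date, score function):
for dataset, model, forecast_date, score_fun in itertools.product(
    datasets, models, forecast_dates, score_funs
):
    print(dataset, model, forecast_date, score_fun)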

---------

Co-authored-by: Fuhan Yang <ab32@cdc.gov>
swo and Fuhan-Yang authored Jan 10, 2025
1 parent 7e7b264 commit 5575cd3
Showing 7 changed files with 144 additions and 161 deletions.
14 changes: 11 additions & 3 deletions Makefile
@@ -2,13 +2,21 @@ NIS_CACHE = .cache/nisapi
TOKEN_PATH = scripts/socrata_app_token.txt
TOKEN = $(shell cat $(TOKEN_PATH))
CONFIG = scripts/config.yaml
RAW_DATA = data/nis_raw.parquet
FORECASTS = data/forecasts.parquet
SCORES = data/scores.parquet

.PHONY: cache

run: $(CONFIG) cache
	python scripts/main.py --config=$(CONFIG) --cache=$(NIS_CACHE)/clean
all: $(SCORES)

data/nis_raw.parquet: scripts/preprocess.py cache
$(SCORES): scripts/eval.py $(FORECASTS)
	python $< --input=$(FORECASTS) --output=$@

$(FORECASTS): scripts/forecast.py $(RAW_DATA)
	python $< --input=$(RAW_DATA) --output=$@

$(RAW_DATA): scripts/preprocess.py cache
	python $< --cache=$(NIS_CACHE)/clean --output=$@

cache: $(NIS_CACHE)/status.txt
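Usage note: with these rules, make all builds data/scores.parquet by chaining the three stages: preprocess.py turns the NIS cache into nis_raw.parquet, forecast.py turns that into forecasts.parquet, and eval.py turns that into scores.parquet. In each recipe, $< expands to the first prerequisite (the script) and $@ to the target file.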
6 changes: 3 additions & 3 deletions iup/eval.py
@@ -30,9 +30,9 @@ def check_date_match(data: IncidentUptakeData, pred: PointForecast):
    (data["time_end"] == pred["time_end"]).all()

    # 2. There should not be any duplicated date in either data or prediction.
    assert not (
        any(data["time_end"].is_duplicated())
    ), "Duplicated dates are found in data and prediction."
    assert not (any(data["time_end"].is_duplicated())), (
        "Duplicated dates are found in data and prediction."
    )


def score(
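For reference, a standalone polars example of the duplicate-date check this assert performs (not code from this commit):

import polars as pl

s = pl.Series("time_end", ["2024-01-06", "2024-01-13", "2024-01-13"])
print(s.is_duplicated().any())  # True: both rows of 2024-01-13 are flagged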
8 changes: 3 additions & 5 deletions scripts/config_template.yaml
@@ -22,9 +22,7 @@ timeframe:
  end: 2024-04-30
  interval: 7d

models: [LinearIncidentUptakeModel]

# Whether to return projections or evaluation metrics; can be "projection" or "evaluation"
option: projection

# This metric is only used when option == "evaluation"; it can be "mspe", "mean_bias", "end_of_season_error", or "all"
metrics: mspe
# score metrics
score_funs: [mspe, mean_bias, eos_abe]
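
A script could sanity-check these names against the functions in iup/eval.py, for example (hypothetical sketch; it assumes mspe, mean_bias, and eos_abe are defined in iup.eval, which this diff does not show):

import iup.eval

score_funs = ["mspe", "mean_bias", "eos_abe"]  # as configured above
for name in score_funs:
    assert hasattr(iup.eval, name), f"unknown score function: {name}"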
6 changes: 6 additions & 0 deletions scripts/eval.py
@@ -0,0 +1,6 @@
for score_fun in score_funs:
    score = eval.score(incident_test_data, incident_projections, score_fun)
    print(f"{model=} {forecast_date=} {score_fun=} {score=}")
# save these scores somewhere
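
Per the Makefile rule above, this stub will eventually be invoked as python scripts/eval.py --input=data/forecasts.parquet --output=data/scores.parquet, so a fleshed-out version might begin like this (a sketch under assumptions; only eval.score and the printed names come from the stub itself):

import argparse

import polars as pl

from iup import eval

if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--input", help="forecasts parquet file")
    p.add_argument("--output", help="scores parquet file")
    args = p.parse_args()

    forecasts = pl.read_parquet(args.input)
    # score each (model, forecast_date, score_fun) combination found in the
    # forecasts frame, then write the resulting scores to args.output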
70 changes: 70 additions & 0 deletions scripts/forecast.py
@@ -0,0 +1,70 @@
import argparse

import polars as pl
import yaml

import iup
import iup.models


def run_all_forecasts(config, input_data) -> pl.DataFrame:
    """Run all forecasts

    Returns:
        pl.DataFrame: data frame of forecasts, organized by model and forecast date
    """
    forecast_dates = pl.date_range(
        config["timeframe"]["start"],
        config["timeframe"]["end"],
        config["timeframe"]["interval"],
        eager=True,
    )
    models = [getattr(iup.models, model_name) for model_name in config["models"]]
    assert all(issubclass(model, iup.models.UptakeModel) for model in models)

    for model in models:
        for forecast_date in forecast_dates:
            # Get data available as of the forecast date
            pass  # placeholder: the per-date forecast is wired up in later PRs


def run_forecast(model, incident_data, grouping_factors, config) -> pl.DataFrame:
    """Run a single model for a single forecast date"""
    incident_train_data = iup.IncidentUptakeData(
        iup.IncidentUptakeData.split_train_test(
            incident_data, config["timeframe"]["start"], "train"
        )
    )

    # Fit models using the training data and make projections
    fit_model = model().fit(incident_train_data, grouping_factors)

    cumulative_projections = fit_model.predict(
        config["timeframe"]["start"],
        config["timeframe"]["end"],
        config["timeframe"]["interval"],
        grouping_factors,
    )
    # save these projections somewhere

    incident_projections = cumulative_projections.to_incident(grouping_factors)
    # save these projections somewhere

    # Evaluation / Post-processing --------------------------------------------

    incident_test_data = iup.IncidentUptakeData(
        iup.IncidentUptakeData.split_train_test(
            incident_data, config["timeframe"]["start"], "test"
        )
    ).filter(pl.col("date") <= config["timeframe"]["end"])
    # (the test data is computed for eval.py's scoring loop, to come in later PRs)

    return incident_projections


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--config", help="config file", default="scripts/config.yaml")
    p.add_argument("--input", help="input data")
    args = p.parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    input_data = pl.scan_parquet(args.input)

    run_all_forecasts(config=config, input_data=input_data)
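
A sketch of how the inside of run_all_forecasts might assemble the return value its docstring promises (hypothetical; this commit leaves the loop body as a placeholder, and grouping_factors is an assumed name):

grouping_factors = config["groups"]  # assumed; matches preprocess.py's groups
frames = []
for model in models:
    for forecast_date in forecast_dates:
        pred = run_forecast(model, input_data, grouping_factors, config)
        frames.append(
            pred.with_columns(
                model=pl.lit(model.__name__),
                forecast_date=pl.lit(forecast_date),
            )
        )
forecasts = pl.concat(frames)  # returned by run_all_forecasts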
138 changes: 0 additions & 138 deletions scripts/main.py

This file was deleted.

63 changes: 51 additions & 12 deletions scripts/preprocess.py
@@ -1,29 +1,68 @@
import argparse
import datetime
from typing import List

import nisapi
import polars as pl
import yaml

import iup


# (old single-argument version, removed by this commit)
def preprocess(raw_data: pl.LazyFrame) -> pl.DataFrame:
    return (
        raw_data.filter(
            pl.col("geography_type").is_in(["nation", "admin1"]),
            pl.col("domain_type") == pl.lit("age"),
            pl.col("domain") == pl.lit("18+ years"),
            pl.col("indicator") == pl.lit("received a vaccination"),
        )
        .drop(["indicator_type", "indicator"])
        .head()
        .collect()
    )


def preprocess(
    raw_data: pl.LazyFrame,
    filters: dict,
    keep: List[str],
    groups: List[str],
    rollout_dates: List[datetime.date],
) -> pl.DataFrame:
    # Prune data to correct rows and columns
    cumulative_data = iup.CumulativeUptakeData(
        raw_data.filter(filters).select(keep).sort("time_end").collect()
    )

    # Ensure that the desired grouping factors are found in all data sets
    assert set(cumulative_data.columns).issuperset(groups)

    # Insert rollout dates into the data
    cumulative_data = iup.CumulativeUptakeData(
        cumulative_data.insert_rollout(rollout_dates, groups)
    )

    # Convert to incident data
    incident_data = cumulative_data.to_incident(groups)

    return pl.concat(
        [
            cumulative_data.with_columns(estimate_type="cumulative"),
            incident_data.with_columns(estimate_type="incident"),
        ]
    )


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--config", help="config file", default="scripts/config.yaml")
    p.add_argument(
        "--cache", help="NIS cache directory", default=".cache/nisapi/clean/"
    )
    p.add_argument("--cache", help="clean cache directory")  # duplicates the flag above
    p.add_argument("--output", help="output parquet file")
    args = p.parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    assert len(config["data"]) == 1, "Don't know how to preprocess multiple data sets"

    raw_data = nisapi.get_nis(path=args.cache)
    clean_data = preprocess(raw_data)  # stale single-argument call, superseded below

    clean_data = preprocess(
        raw_data,
        filters=config["data"][0]["filters"],
        keep=config["data"][0]["keep"],
        groups=config["groups"],
        rollout_dates=config["data"][0]["rollout"],
    )

    clean_data.write_parquet(args.output)
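
Because preprocess() now stacks cumulative and incident rows into a single frame tagged by estimate_type, downstream readers can split them back apart (illustrative sketch; the path comes from the Makefile):

import polars as pl

clean = pl.scan_parquet("data/nis_raw.parquet")
incident = clean.filter(pl.col("estimate_type") == "incident").collect()
cumulative = clean.filter(pl.col("estimate_type") == "cumulative").collect()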
