Skip to content

Commit

Permalink
decouple db pipeline from data grabber
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Mar 12, 2024
1 parent 4764f3d commit a0f97a5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
9 changes: 2 additions & 7 deletions cities/modeling/model_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,12 @@
from typing import Optional

import dill

import torch

import pyro
import pyro.distributions as dist
import torch

from cities.modeling.modeling_utils import (
prep_wide_data_for_inference,
train_interactions_model,
)
from cities.modeling.modeling_utils import (prep_wide_data_for_inference,
train_interactions_model)
from cities.utils.data_grabber import DataGrabber, find_repo_root


Expand Down
24 changes: 23 additions & 1 deletion cities/utils/csv_to_db_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
import logging
import os
import time
from pathlib import Path

import pandas as pd
from sqlalchemy import create_engine

from cities.utils.data_grabber import find_repo_root, list_csvs

logging.disable(logging.WARNING)


# defining util functions locally to avoid circular imports and major refactor
# TODO refactor the two functions out of data_grabber.py


def list_csvs(csv_dir):
csv_names = []
for filename in os.listdir(csv_dir):
if filename.endswith(".csv"):
csv_names.append(filename)

assert (
len(csv_names) > 10
), f"Expected to find more than 10 csv files in {csv_dir}, but found {len(csv_names)}"

return csv_names


def find_repo_root() -> Path:
return Path(__file__).parent.parent.parent



def create_database(data_dir, database_path):
engine = create_engine(f"sqlite:///{database_path}", echo=True)
csv_list = list_csvs(data_dir)
Expand Down

0 comments on commit a0f97a5

Please sign in to comment.