diff --git a/src/scripts/analysis_data_generation/analysis_extract_data.py b/src/scripts/analysis_data_generation/analysis_extract_data.py new file mode 100644 index 0000000000..2cfba5315b --- /dev/null +++ b/src/scripts/analysis_data_generation/analysis_extract_data.py @@ -0,0 +1,370 @@ +"""Produce plots to show the health impact (deaths, dalys) each the healthcare system (overall health impact) when +running under different MODES and POLICIES (scenario_impact_of_actual_vs_funded.py)""" + +# short tclose -> ideal case +# long tclose -> status quo +import argparse +from pathlib import Path +from typing import Tuple + +import pandas as pd + +from tlo import Date +from tlo.analysis.utils import extract_results +from datetime import datetime + +# Range of years considered +min_year = 2010 +max_year = 2040 + + +def all_columns(_df): + return pd.Series(_df.all()) + +def apply(results_folder: Path, output_folder: Path, resourcefilepath: Path = None, ): + """Produce standard set of plots describing the effect of each TREATMENT_ID. + - We estimate the epidemiological impact as the EXTRA deaths that would occur if that treatment did not occur. + - We estimate the draw on healthcare system resources as the FEWER appointments when that treatment does not occur. 
+ """ + pd.set_option('display.max_rows', None) + pd.set_option('display.max_colwidth', None) + event_chains = extract_results( + results_folder, + module='tlo.simulation', + key='event_chains', + column='0', + #column = str(i), + #custom_generate_series=get_num_dalys_by_year, + do_scaling=False + ) + # print(event_chains.loc[0,(0, 0)]) + + eval_env = { + 'datetime': datetime, # Add the datetime class to the eval environment + 'pd': pd, # Add pandas to handle Timestamp + 'Timestamp': pd.Timestamp, # Specifically add Timestamp for eval + 'NaT': pd.NaT, + 'nan': float('nan'), # Include NaN for eval (can also use pd.NA if preferred) + } + + for item,row in event_chains.iterrows(): + value = event_chains.loc[item,(0, 0)] + if value !='': + print('') + print(value) + exit(-1) + #dict = {} + #for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: + # dict[i] = [] + + #for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: + # event_chains = extract_results( + # results_folder, + # module='tlo.simulation'#, + # key='event_chains', + # column = str(i), + # #custom_generate_series=get_num_dalys_by_year, + # do_scaling=False + # ) + # print(event_chains) + # print(event_chains.index) + # print(event_chains.columns.levels) + + # for index, row in event_chains.iterrows(): + # if event_chains.iloc[index,0] is not None: + # if(event_chains.iloc[index,0]['person_ID']==i): #and 'event' in event_chains.iloc[index,0].keys()): + # dict[i].append(event_chains.iloc[index,0]) + #elif (event_chains.iloc[index,0]['person_ID']==i and 'event' not in event_chains.iloc[index,0].keys()): + #print(event_chains.iloc[index,0]['de_depr']) + # exit(-1) + #for item in dict[0]: + # print(item) + + #exit(-1) + + TARGET_PERIOD = (Date(min_year, 1, 1), Date(max_year, 1, 1)) + + # Definitions of general helper functions + lambda stub: output_folder / f"{stub.replace('*', '_star_')}.png" # noqa: E731 + + def target_period() -> str: + """Returns the target period as a string of the form YYYY-YYYY""" + return "-".join(str(t.year) for 
t in TARGET_PERIOD) + + def get_parameter_names_from_scenario_file() -> Tuple[str]: + """Get the tuple of names of the scenarios from `Scenario` class used to create the results.""" + from scripts.healthsystem.impact_of_actual_vs_funded.scenario_impact_of_actual_vs_funded import ( + ImpactOfHealthSystemMode, + ) + e = ImpactOfHealthSystemMode() + return tuple(e._scenarios.keys()) + + def get_num_deaths(_df): + """Return total number of Deaths (total within the TARGET_PERIOD) + """ + return pd.Series(data=len(_df.loc[pd.to_datetime(_df.date).between(*TARGET_PERIOD)])) + + def get_num_dalys(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year.between(*[i.year for i in TARGET_PERIOD])] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum().sum() + ) + + def get_num_dalys_by_cause(_df): + """Return number of DALYs by cause by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year.between(*[i.year for i in TARGET_PERIOD])] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum() + ) + + def set_param_names_as_column_index_level_0(_df): + """Set the columns index (level 0) as the param_names.""" + ordered_param_names_no_prefix = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ordered_param_names_no_prefix.get(col) for col in _df.columns.levels[0]] + assert len(names_of_cols_level0) == len(_df.columns.levels[0]) + _df.columns = _df.columns.set_levels(names_of_cols_level0, level=0) + return _df + + def find_difference_relative_to_comparison(_ser: pd.Series, + comparison: str, + scaled: bool = False, + drop_comparison: bool = True, + ): + """Find the difference in the values in a pd.Series with a multi-index, between the draws (level 0) + within the runs (level 1), relative to where draw = `comparison`. 
+ The comparison is `X - COMPARISON`.""" + return _ser \ + .unstack(level=0) \ + .apply(lambda x: (x - x[comparison]) / (x[comparison] if scaled else 1.0), axis=1) \ + .drop(columns=([comparison] if drop_comparison else [])) \ + .stack() + + + def get_counts_of_hsi_by_treatment_id(_df): + """Get the counts of the short TREATMENT_IDs occurring""" + _counts_by_treatment_id = _df \ + .loc[pd.to_datetime(_df['date']).between(*TARGET_PERIOD), 'TREATMENT_ID'] \ + .apply(pd.Series) \ + .sum() \ + .astype(int) + return _counts_by_treatment_id.groupby(level=0).sum() + + year_target = 2023 + def get_counts_of_hsi_by_treatment_id_by_year(_df): + """Get the counts of the short TREATMENT_IDs occurring""" + _counts_by_treatment_id = _df \ + .loc[pd.to_datetime(_df['date']).dt.year ==year_target, 'TREATMENT_ID'] \ + .apply(pd.Series) \ + .sum() \ + .astype(int) + return _counts_by_treatment_id.groupby(level=0).sum() + + def get_counts_of_hsi_by_short_treatment_id(_df): + """Get the counts of the short TREATMENT_IDs occurring (shortened, up to first underscore)""" + _counts_by_treatment_id = get_counts_of_hsi_by_treatment_id(_df) + _short_treatment_id = _counts_by_treatment_id.index.map(lambda x: x.split('_')[0] + "*") + return _counts_by_treatment_id.groupby(by=_short_treatment_id).sum() + + def get_counts_of_hsi_by_short_treatment_id_by_year(_df): + """Get the counts of the short TREATMENT_IDs occurring (shortened, up to first underscore)""" + _counts_by_treatment_id = get_counts_of_hsi_by_treatment_id_by_year(_df) + _short_treatment_id = _counts_by_treatment_id.index.map(lambda x: x.split('_')[0] + "*") + return _counts_by_treatment_id.groupby(by=_short_treatment_id).sum() + + + # Obtain parameter names for this scenario file + param_names = get_parameter_names_from_scenario_file() + print(param_names) + + # ================================================================================================ + # TIME EVOLUTION OF TOTAL DALYs + # Plot DALYs averted compared to the 
``No Policy'' policy + + year_target = 2023 # This global variable will be passed to custom function + def get_num_dalys_by_year(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year == year_target] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum().sum() + ) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + this_min_year = 2010 + for year in range(this_min_year, max_year+1): + year_target = year + num_dalys_by_year = extract_results( + results_folder, + module='tlo.methods.healthburden', + key='dalys_stacked', + custom_generate_series=get_num_dalys_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = num_dalys_by_year + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'index_original']) + concatenated_df = concatenated_df.reset_index(level='index_original',drop=True) + dalys_by_year = concatenated_df + print(dalys_by_year) + dalys_by_year.to_csv('ConvertedOutputs/Total_DALYs_with_time.csv', index=True) + + # ================================================================================================ + # Print population under each scenario + pop_model = extract_results(results_folder, + module="tlo.methods.demography", + key="population", + column="total", + index="date", + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + + pop_model.index = pop_model.index.year + pop_model = pop_model[(pop_model.index >= this_min_year) & (pop_model.index <= max_year)] + print(pop_model) + assert dalys_by_year.index.equals(pop_model.index) + assert all(dalys_by_year.columns == pop_model.columns) + pop_model.to_csv('ConvertedOutputs/Population_with_time.csv', index=True) + + 
# ================================================================================================ + # DALYs BROKEN DOWN BY CAUSES AND YEAR + # DALYs by cause per year + # %% Quantify the health losses associated with all interventions combined. + + year_target = 2023 # This global variable will be passed to custom function + def get_num_dalys_by_year_and_cause(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year == year_target] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum() + ) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + this_min_year = 2010 + for year in range(this_min_year, max_year+1): + year_target = year + num_dalys_by_year = extract_results( + results_folder, + module='tlo.methods.healthburden', + key='dalys_stacked', + custom_generate_series=get_num_dalys_by_year_and_cause, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = num_dalys_by_year #summarize(num_dalys_by_year) + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + + df_total = concatenated_df + df_total.to_csv('ConvertedOutputs/DALYS_by_cause_with_time.csv', index=True) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + for year in range(min_year, max_year+1): + year_target = year + + hsi_delivered_by_year = extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='HSI_Event', + custom_generate_series=get_counts_of_hsi_by_short_treatment_id_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = 
hsi_delivered_by_year + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + HSI_ran_by_year = concatenated_df + + del ALL + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + for year in range(min_year, max_year+1): + year_target = year + + hsi_not_delivered_by_year = extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='Never_ran_HSI_Event', + custom_generate_series=get_counts_of_hsi_by_short_treatment_id_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = hsi_not_delivered_by_year + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + HSI_never_ran_by_year = concatenated_df + + HSI_never_ran_by_year = HSI_never_ran_by_year.fillna(0) #clean_df( + HSI_ran_by_year = HSI_ran_by_year.fillna(0) + HSI_total_by_year = HSI_ran_by_year.add(HSI_never_ran_by_year, fill_value=0) + HSI_ran_by_year.to_csv('ConvertedOutputs/HSIs_ran_by_area_with_time.csv', index=True) + HSI_never_ran_by_year.to_csv('ConvertedOutputs/HSIs_never_ran_by_area_with_time.csv', index=True) + print(HSI_ran_by_year) + print(HSI_never_ran_by_year) + print(HSI_total_by_year) + +if __name__ == "__main__": + rfp = Path('resources') + + parser = argparse.ArgumentParser( + description="Produce plots to show the impact each set of treatments", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--output-path", + help=( + "Directory to write outputs to. If not specified (set to None) outputs " + "will be written to value of --results-path argument." 
+ ), + type=Path, + default=None, + required=False, + ) + parser.add_argument( + "--resources-path", + help="Directory containing resource files", + type=Path, + default=Path('resources'), + required=False, + ) + parser.add_argument( + "--results-path", + type=Path, + help=( + "Directory containing results from running " + "src/scripts/analysis_data_generation/scenario_generate_chains.py " + ), + default=None, + required=False + ) + args = parser.parse_args() + assert args.results_path is not None + results_path = args.results_path + + output_path = results_path if args.output_path is None else args.output_path + + apply( + results_folder=results_path, + output_folder=output_path, + resourcefilepath=args.resources_path + ) diff --git a/src/scripts/analysis_data_generation/postprocess_events_chain.py b/src/scripts/analysis_data_generation/postprocess_events_chain.py new file mode 100644 index 0000000000..96c27a04b1 --- /dev/null +++ b/src/scripts/analysis_data_generation/postprocess_events_chain.py @@ -0,0 +1,156 @@ +import pandas as pd +from dateutil.relativedelta import relativedelta + +# Remove from every individual's event chain all events that were fired after death +def cut_off_events_after_death(df): + + events_chain = df.groupby('person_ID') + + filtered_data = pd.DataFrame() + + for name, group in events_chain: + + # Find the first non-NaN 'date_of_death' and its index + first_non_nan_index = group['date_of_death'].first_valid_index() + + if first_non_nan_index is not None: + # Filter out all rows after the first non-NaN index + filtered_group = group.loc[:first_non_nan_index] # Keep rows up to and including the first valid index + filtered_data = pd.concat([filtered_data, filtered_group]) + else: + # If there are no non-NaN values, keep the original group + filtered_data = pd.concat([filtered_data, group]) + + return filtered_data + +# Load into DataFrame +def load_csv_to_dataframe(file_path): + try: + # Load raw chains into df + df = 
pd.read_csv(file_path) + print("Raw event chains loaded successfully!") + return df + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + except Exception as e: + print(f"An error occurred: {e}") + +file_path = 'output.csv' # Replace with the path to your CSV file + +output = load_csv_to_dataframe(file_path) + +# Some of the dates appeared not to be in datetime format. Correct here. +output['date_of_death'] = pd.to_datetime(output['date_of_death'], errors='coerce') +output['date_of_birth'] = pd.to_datetime(output['date_of_birth'], errors='coerce') +if 'hv_date_inf' in output.columns: + output['hv_date_inf'] = pd.to_datetime(output['hv_date_inf'], errors='coerce') + + +date_start = pd.to_datetime('2010-01-01') +if 'Other' in output['cause_of_death'].values: + print("ERROR: 'Other' was included in sim as possible cause of death") + exit(-1) + +# Choose which columns in individual properties to visualise +columns_to_print =['event','is_alive','hv_inf', 'hv_art','tb_inf', 'tb_date_active', 'event_date', 'when'] +#columns_to_print =['person_ID', 'date_of_birth', 'date_of_death', 'cause_of_death','hv_date_inf', 'hv_art','tb_inf', 'tb_date_active', 'event date', 'event'] + +# When checking which individuals led to *any* changes in individual properties, exclude these columns from comparison +columns_to_exclude_in_comparison = ['when', 'event', 'event_date', 'age_exact_years', 'age_years', 'age_days', 'age_range', 'level', 'appt_footprint'] + +# If considering epidemiology consistent with sim, add check here. 
+check_ages_of_those_HIV_inf = False +if check_ages_of_those_HIV_inf: + for index, row in output.iterrows(): + if pd.isna(row['hv_date_inf']): + continue # Skip this iteration + diff = relativedelta(output.loc[index, 'hv_date_inf'],output.loc[index, 'date_of_birth']) + if diff.years > 1 and diff.years<15: + print("Person contracted HIV infection at age younger than 15", diff) + +# Remove events after death +filtered_data = cut_off_events_after_death(output) + +print_raw_events = True # Print raw chain of events for each individual +print_selected_changes = False +print_all_changes = True +person_ID_of_interest = 494 + +pd.set_option('display.max_rows', None) + +for name, group in filtered_data.groupby('person_ID'): + list_of_dob = group['date_of_birth'] + + # Select individuals based on when they were born + if list_of_dob.iloc[0].year<2010: + + # Check that immutable properties are fixed for this individual, i.e. that events were collated properly: + all_identical_dob = group['date_of_birth'].nunique() == 1 + all_identical_sex = group['sex'].nunique() == 1 + if all_identical_dob is False or all_identical_sex is False: + print("Immutable properties are changing! 
This is not chain for single individual") + print(group) + exit(-1) + + print("----------------------------------------------------------------------") + print("person_ID ", group['person_ID'].iloc[0], "d.o.b ", group['date_of_birth'].iloc[0]) + print("Number of events for this individual ", group['person_ID'].iloc[0], "is :", len(group)/2) # Divide by 2 before printing Before/After for each event + number_of_events =len(group)/2 + number_of_changes=0 + if print_raw_events: + print(group) + + if print_all_changes: + # Check each row + comparison = group.drop(columns=columns_to_exclude_in_comparison).fillna(-99999).ne(group.drop(columns=columns_to_exclude_in_comparison).shift().fillna(-99999)) + + # Iterate over rows where any column has changed + for idx, row_changed in comparison.iloc[1:].iterrows(): + if row_changed.any(): # Check if any column changed in this row + number_of_changes+=1 + changed_columns = row_changed[row_changed].index.tolist() # Get the columns where changes occurred + print(f"Row {idx} - Changes detected in columns: {changed_columns}") + columns_output = ['event', 'event_date', 'appt_footprint', 'level'] + changed_columns + print(group.loc[idx, columns_output]) # Print only the changed columns + if group.loc[idx, 'when'] == 'Before': + print('-----> THIS CHANGE OCCURRED BEFORE EVENT!') + #print(group.loc[idx,columns_to_print]) + print() # For better readability + print("Number of changes is ", number_of_changes, "out of ", number_of_events, " events") + + if print_selected_changes: + tb_inf_condition = ( + ((group['tb_inf'].shift(1) == 'uninfected') & (group['tb_inf'] == 'active')) | + ((group['tb_inf'].shift(1) == 'latent') & (group['tb_inf'] == 'active')) | + ((group['tb_inf'].shift(1) == 'active') & (group['tb_inf'] == 'latent')) | + ((group['hv_inf'].shift(1).eq(False)) & (group['hv_inf'].eq(True))) | # element-wise test; `is` would compare the Series object itself + ((group['hv_art'].shift(1) == 'not') & (group['hv_art'] == 'on_not_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'not') & 
(group['hv_art'] == 'on_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'on_VL_suppressed') & (group['hv_art'] == 'on_not_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'on_VL_suppressed') & (group['hv_art'] == 'not')) | + ((group['hv_art'].shift(1) == 'on_not_VL_suppressed') & (group['hv_art'] == 'on_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'on_not_VL_suppressed') & (group['hv_art'] == 'not')) + ) + + alive_condition = ( + (group['is_alive'].shift(1).eq(True)) & (group['is_alive'].eq(False)) # element-wise: row where the person goes from alive to dead + ) + # Combine conditions for rows of interest + transition_condition = tb_inf_condition | alive_condition + + if list_of_dob.iloc[0].year >= 2010: + print("DETECTED OF INTEREST") + print(group[group['event'] == 'Birth'][columns_to_print]) + + # Filter the DataFrame based on the condition + filtered_transitions = group[transition_condition] + if not filtered_transitions.empty: + if list_of_dob.iloc[0].year < 2010: + print("DETECTED OF INTEREST") + print(filtered_transitions[columns_to_print]) + + +print("Number of individuals simulated ", filtered_data.groupby('person_ID').ngroups) + + + diff --git a/src/scripts/analysis_data_generation/scenario_generate_chains.py b/src/scripts/analysis_data_generation/scenario_generate_chains.py new file mode 100644 index 0000000000..6bdcd02d90 --- /dev/null +++ b/src/scripts/analysis_data_generation/scenario_generate_chains.py @@ -0,0 +1,115 @@ +"""This Scenario file runs the model to generate event chains + +Run on the batch system using: +``` +tlo batch-submit + src/scripts/analysis_data_generation/scenario_generate_chains.py +``` + +or locally using: +``` + tlo scenario-run src/scripts/analysis_data_generation/scenario_generate_chains.py +``` + +""" +from pathlib import Path +from typing import Dict + +import pandas as pd + +from tlo import Date, logging +from tlo.analysis.utils import get_parameters_for_status_quo, mix_scenarios +from tlo.methods.fullmodel import fullmodel +from tlo.methods.scenario_switcher import 
ImprovedHealthSystemAndCareSeekingScenarioSwitcher +from tlo.scenario import BaseScenario + + +class GenerateDataChains(BaseScenario): + def __init__(self): + super().__init__() + self.seed = 0 + self.start_date = Date(2010, 1, 1) + self.end_date = self.start_date + pd.DateOffset(months=1) + self.pop_size = 120 + self._scenarios = self._get_scenarios() + self.number_of_draws = len(self._scenarios) + self.runs_per_draw = 1 + self.generate_event_chains = True + + def log_configuration(self): + return { + 'filename': 'generate_event_chains', + 'directory': Path('./outputs'), # <- (specified only for local running) + 'custom_levels': { + '*': logging.WARNING, + 'tlo.methods.demography': logging.INFO, + 'tlo.methods.events': logging.INFO, + 'tlo.methods.demography.detail': logging.WARNING, + 'tlo.methods.healthburden': logging.INFO, + 'tlo.methods.healthsystem.summary': logging.INFO, + } + } + + def modules(self): + return ( + fullmodel(resourcefilepath=self.resources) + + [ImprovedHealthSystemAndCareSeekingScenarioSwitcher(resourcefilepath=self.resources)] + ) + + def draw_parameters(self, draw_number, rng): + if draw_number < self.number_of_draws: + return list(self._scenarios.values())[draw_number] + else: + return + + # case 1: gfHE = -0.030, factor = 1.01074 + # case 2: gfHE = -0.020, factor = 1.02116 + # case 3: gfHE = -0.015, factor = 1.02637 + # case 4: gfHE = 0.015, factor = 1.05763 + # case 5: gfHE = 0.020, factor = 1.06284 + # case 6: gfHE = 0.030, factor = 1.07326 + + def _get_scenarios(self) -> Dict[str, Dict]: + """Return the Dict with values for the parameters that are changed, keyed by a name for the scenario. 
+ """ + + self.YEAR_OF_CHANGE = 2019 + + return { + + # =========== STATUS QUO ============ + "Baseline": + mix_scenarios( + self._baseline(), + { + "HealthSystem": { + "yearly_HR_scaling_mode": "no_scaling", + }, + } + ), + + } + + def _baseline(self) -> Dict: + """Return the Dict with values for the parameter changes that define the baseline scenario. """ + return mix_scenarios( + get_parameters_for_status_quo(), + { + "HealthSystem": { + "mode_appt_constraints": 1, # <-- Mode 1 prior to change to preserve calibration + "mode_appt_constraints_postSwitch": 2, # <-- Mode 2 post-change to show effects of HRH + "year_mode_switch": self.YEAR_OF_CHANGE, + "scale_to_effective_capabilities": True, + "policy_name": "Naive", + "tclose_overwrite": 1, + "tclose_days_offset_overwrite": 7, + "use_funded_or_actual_staffing": "actual", + "cons_availability": "default", + } + }, + ) + +if __name__ == '__main__': + from tlo.cli import scenario_run + + scenario_run([__file__]) diff --git a/src/tlo/events.py b/src/tlo/events.py index 9dd34c9448..00a6fe4e7d 100644 --- a/src/tlo/events.py +++ b/src/tlo/events.py @@ -4,11 +4,26 @@ from enum import Enum from typing import TYPE_CHECKING -from tlo import DateOffset +from tlo import DateOffset, logging if TYPE_CHECKING: from tlo import Simulation +import pandas as pd + +FACTOR_POP_DICT = 5000 + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +logger_chain = logging.getLogger('tlo.simulation') +logger_chain.setLevel(logging.INFO) + +logger_summary = logging.getLogger(f"{__name__}.summary") +logger_summary.setLevel(logging.INFO) + +debug_chains = True class Priority(Enum): """Enumeration for the Priority, which is used in sorting the events in the simulation queue.""" @@ -60,11 +75,174 @@ def apply(self, target): :param target: the target of the event """ raise NotImplementedError + + def compare_population_dataframe(self,df_before, df_after): + """ This function compares the population dataframe before/after a 
population-wide event has occurred. + It allows us to identify the individuals for which this event led to a significant (i.e. property) change, and to store the properties which have changed as a result of it. """ + + # Create a mask of where values are different + diff_mask = (df_before != df_after) & ~(df_before.isna() & df_after.isna()) + + # Create an empty list to store changes for each of the individuals + chain_links = {} + len_of_diff = len(diff_mask) + + # Loop through each row of the mask + + for idx, row in diff_mask.iterrows(): + changed_cols = row.index[row].tolist() + + if changed_cols: # Proceed only if there are changes in the row + # Create a dictionary for this person + # First add event info + link_info = { + 'person_ID': idx, + 'event': str(self), + 'event_date': self.sim.date, + } + + # Store the new values from df_after for the changed columns + for col in changed_cols: + link_info[col] = df_after.at[idx, col] + + # Append the event and changes to the individual key + chain_links[idx] = str(link_info) + + return chain_links + + def store_chains_to_do_before_event(self) -> tuple[bool, pd.Series, pd.DataFrame]: + """ This function checks whether this event should be logged as part of the event chains, and if so stored required information before the event has occurred. """ + + # Initialise these variables + print_chains = False + df_before = [] + row_before = pd.Series() + + # Only print event if it belongs to modules of interest and if it is not in the list of events to ignore + #if (self.module in self.sim.generate_event_chains_modules_of_interest) and .. 
+ if all(sub not in str(self) for sub in self.sim.generate_event_chains_ignore_events): + + # Will eventually use this once I can actually GET THE NAME OF THE SELF + #if not set(self.sim.generate_event_chains_ignore_events).intersection(str(self)): + + print_chains = True + + # Target is single individual + if self.target != self.sim.population: + # Save row for comparison after event has occurred + row_before = self.sim.population.props.loc[abs(self.target)].copy().fillna(-99999) + + if debug_chains: + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR. + row = self.sim.population.props.loc[[abs(self.target)]] + row['person_ID'] = self.target + row['event'] = str(self) + row['event_date'] = self.sim.date + row['when'] = 'Before' + self.sim.event_chains = pd.concat([self.sim.event_chains, row], ignore_index=True) + + else: + # This will be a population-wide event. In order to find individuals for which this led to + # a meaningful change, make a copy of the pop dataframe before the event has occurred. + df_before = self.sim.population.props.copy() + + return print_chains, row_before, df_before + + def store_chains_to_do_after_event(self, print_chains, row_before, df_before) -> dict: + """ If print_chains=True, this function logs the event and identifies and logs the any property changes that have occured to one or multiple individuals as a result of the event taking place. 
""" + + chain_links = {} + + if print_chains: + + # Target is single individual + if self.target != self.sim.population: + row_after = self.sim.population.props.loc[abs(self.target)].fillna(-99999) + + # Create and store event for this individual, regardless of whether any property change occurred + link_info = { + #'person_ID' : self.target, + 'person_ID' : self.target, + 'event' : str(self), + 'event_date' : self.sim.date, + } + # Store (if any) property changes as a result of the event for this individual + for key in row_before.index: + if row_before[key] != row_after[key]: # Note: used fillna previously + link_info[key] = row_after[key] + + chain_links[self.target] = str(link_info) + + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR. + if debug_chains: + # Print entire row + row = self.sim.population.props.loc[[abs(self.target)]] # Use abs to avoid potentil issue with direct births + row['person_ID'] = self.target + row['event'] = str(self) + row['event_date'] = self.sim.date + row['when'] = 'After' + self.sim.event_chains = pd.concat([self.sim.event_chains, row], ignore_index=True) + + else: + # Target is entire population. Identify individuals for which properties have changed + # and store their changes. + + # Population frame after event + df_after = self.sim.population.props + + # Create and store the event and dictionary of changes for affected individuals + chain_links = self.compare_population_dataframe(df_before, df_after) + + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR. 
+ if debug_chains: + # Or print entire rows + change = df_before.compare(df_after) + if not change.empty: + indices = change.index + new_rows_before = df_before.loc[indices].copy() # copy so the columns added below never write into a slice + new_rows_before['person_ID'] = new_rows_before.index + new_rows_before['event'] = str(self) # store the name, as done for single-individual events + new_rows_before['event_date'] = self.sim.date + new_rows_before['when'] = 'Before' + + new_rows_after = df_after.loc[indices].copy() # copy so the live population frame is never touched + new_rows_after['person_ID'] = new_rows_after.index + new_rows_after['event'] = str(self) # store the name, as done for single-individual events + new_rows_after['event_date'] = self.sim.date + new_rows_after['when'] = 'After' + + self.sim.event_chains = pd.concat([self.sim.event_chains,new_rows_before], ignore_index=True) + self.sim.event_chains = pd.concat([self.sim.event_chains,new_rows_after], ignore_index=True) + + return chain_links def run(self): """Make the event happen.""" + + # Collect relevant information before event takes place + if self.sim.generate_event_chains: + print_chains, row_before, df_before = self.store_chains_to_do_before_event() + self.apply(self.target) self.post_apply_hook() + + # Collect event info + meaningful property changes of individuals. Combined, these will constitute a 'link' + # in the individual's event chain. 
+ if self.sim.generate_event_chains: + chain_links = self.store_chains_to_do_after_event(print_chains, row_before, df_before) + + # Create empty logger for entire pop + pop_dict = {i: '' for i in range(FACTOR_POP_DICT)} # Always include all possible individuals + + pop_dict.update(chain_links) + + # Log chain_links here + if len(chain_links)>0: + logger_chain.info(key='event_chains', + data= pop_dict, + description='Links forming chains of events for simulated individuals') + + #print("Chain events ", chain_links) class RegularEvent(Event): diff --git a/src/tlo/methods/demography.py b/src/tlo/methods/demography.py index e58f3895f4..4f19af6d55 100644 --- a/src/tlo/methods/demography.py +++ b/src/tlo/methods/demography.py @@ -315,9 +315,10 @@ def initialise_simulation(self, sim): # Launch the repeating event that will store statistics about the population structure sim.schedule_event(DemographyLoggingEvent(self), sim.date) - # Create (and store pointer to) the OtherDeathPoll and schedule first occurrence immediately - self.other_death_poll = OtherDeathPoll(self) - sim.schedule_event(self.other_death_poll, sim.date) + if sim.generate_event_chains is False: + # Create (and store pointer to) the OtherDeathPoll and schedule first occurrence immediately + self.other_death_poll = OtherDeathPoll(self) + sim.schedule_event(self.other_death_poll, sim.date) # Log the initial population scaling-factor (to the logger of this module and that of `tlo.methods.population`) for _logger in (logger, logger_scale_factor): diff --git a/src/tlo/methods/healthsystem.py b/src/tlo/methods/healthsystem.py index 5c6b2022e1..57e050de7b 100644 --- a/src/tlo/methods/healthsystem.py +++ b/src/tlo/methods/healthsystem.py @@ -2457,7 +2457,8 @@ def process_events_mode_2(self, hold_over: List[HSIEventQueueItem]) -> None: # Expected appt footprint before running event _appt_footprint_before_running = event.EXPECTED_APPT_FOOTPRINT - # Run event & get actual footprint + + # Run the HSI event (allowing it 
to return an updated appt_footprint) actual_appt_footprint = event.run(squeeze_factor=squeeze_factor) # Check if the HSI event returned updated_appt_footprint, and if so adjust original_call diff --git a/src/tlo/methods/hiv.py b/src/tlo/methods/hiv.py index d6455cc861..391cf587a8 100644 --- a/src/tlo/methods/hiv.py +++ b/src/tlo/methods/hiv.py @@ -631,11 +631,12 @@ def initialise_population(self, population): df.loc[df.is_alive, "hv_date_treated"] = pd.NaT df.loc[df.is_alive, "hv_date_last_ART"] = pd.NaT - # Launch sub-routines for allocating the right number of people into each category - self.initialise_baseline_prevalence(population) # allocate baseline prevalence + if self.sim.generate_event_chains is False or self.sim.generate_event_chains is None or self.sim.generate_event_chains_overwrite_epi is False: + # Launch sub-routines for allocating the right number of people into each category + self.initialise_baseline_prevalence(population) # allocate baseline prevalence - self.initialise_baseline_art(population) # allocate baseline art coverage - self.initialise_baseline_tested(population) # allocate baseline testing coverage + self.initialise_baseline_art(population) # allocate baseline art coverage + self.initialise_baseline_tested(population) # allocate baseline testing coverage def initialise_baseline_prevalence(self, population): """ @@ -905,10 +906,16 @@ def initialise_simulation(self, sim): df = sim.population.props p = self.parameters - # 1) Schedule the Main HIV Regular Polling Event - sim.schedule_event( - HivRegularPollingEvent(self), sim.date + DateOffset(days=0) - ) + if self.sim.generate_event_chains is True and self.sim.generate_event_chains_overwrite_epi: + print("Should be generating data") + sim.schedule_event( + HivPollingEventForDataGeneration(self), sim.date + DateOffset(days=0) + ) + else: + # 1) Schedule the Main HIV Regular Polling Event + sim.schedule_event( + HivRegularPollingEvent(self), sim.date + DateOffset(days=0) + ) # 2) Schedule 
the Logging Event sim.schedule_event(HivLoggingEvent(self), sim.date + DateOffset(years=1)) @@ -1662,6 +1669,37 @@ def do_at_generic_first_appt( # Main Polling Event # --------------------------------------------------------------------------- +class HivPollingEventForDataGeneration(RegularEvent, PopulationScopeEventMixin): + """ The HIV Polling Events for Data Generation + * Ensures that every alive, uninfected individual is scheduled an HivInfectionEvent at a random date + """ + + def __init__(self, module): + super().__init__( + module, frequency=DateOffset(years=120) + ) # frequency is 120 years, not 12 months; NOTE(review): presumably intended to fire only once per simulation - confirm + + def apply(self, population): + + df = population.props + + # Make everyone who is alive and not infected (no-one should be) susceptible + susc_idx = df.loc[ + df.is_alive + & ~df.hv_inf + ].index + + n_susceptible = len(susc_idx) + print("Number of individuals susceptible", n_susceptible) + # Schedule the date of infection for each new infection: + for i in susc_idx: + date_of_infection = self.sim.date + pd.DateOffset( + # Ensure that individual will be infected before end of sim + days=self.module.rng.randint(0, 365*(int(self.sim.end_date.year - self.sim.date.year)+1)) + ) + self.sim.schedule_event( + HivInfectionEvent(self.module, i), date_of_infection + ) class HivRegularPollingEvent(RegularEvent, PopulationScopeEventMixin): """ The HIV Regular Polling Events @@ -1683,6 +1721,7 @@ def apply(self, population): fraction_of_year_between_polls = self.frequency.months / 12 beta = p["beta"] * fraction_of_year_between_polls + # ----------------------------------- HORIZONTAL TRANSMISSION ----------------------------------- def horizontal_transmission(to_sex, from_sex): # Count current number of alive 15-80 year-olds at risk of transmission @@ -1758,6 +1797,7 @@ def horizontal_transmission(to_sex, from_sex): HivInfectionEvent(self.module, idx), date_of_infection ) + # ----------------------------------- SPONTANEOUS TESTING ----------------------------------- def spontaneous_testing(current_year): @@ -1861,11 +1901,12 @@ def
vmmc_for_child(): priority=0, ) - # Horizontal transmission: Male --> Female - horizontal_transmission(from_sex="M", to_sex="F") + if self.sim.generate_event_chains is False or self.sim.generate_event_chains is None or self.sim.generate_event_chains_overwrite_epi is False: + # Horizontal transmission: Male --> Female + horizontal_transmission(from_sex="M", to_sex="F") - # Horizontal transmission: Female --> Male - horizontal_transmission(from_sex="F", to_sex="M") + # Horizontal transmission: Female --> Male + horizontal_transmission(from_sex="F", to_sex="M") # testing # if year later than 2020, set testing rates to those reported in 2020 @@ -1882,6 +1923,8 @@ def vmmc_for_child(): vmmc_for_child() + + # --------------------------------------------------------------------------- # Natural History Events # --------------------------------------------------------------------------- diff --git a/src/tlo/methods/hsi_event.py b/src/tlo/methods/hsi_event.py index b76a865d2d..d657e9d3a0 100644 --- a/src/tlo/methods/hsi_event.py +++ b/src/tlo/methods/hsi_event.py @@ -9,16 +9,28 @@ from tlo.events import Event from tlo.population import Population +import pandas as pd + +FACTOR_POP_DICT = 5000 + + if TYPE_CHECKING: from tlo import Module, Simulation from tlo.methods.healthsystem import HealthSystem +# Pointing to the logger in events +logger_chains = logging.getLogger("tlo.simulation") +logger_chains.setLevel(logging.INFO) + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger_summary = logging.getLogger(f"{__name__}.summary") logger_summary.setLevel(logging.INFO) +debug_chains = True + + # Declare the level which will be used to represent the merging of levels '1b' and '2' LABEL_FOR_MERGED_FACILITY_LEVELS_1B_AND_2 = "2" @@ -184,13 +196,138 @@ def _run_after_hsi_event(self) -> None: item_codes=self._EQUIPMENT, facility_id=self.facility_info.id ) + + def store_chains_to_do_before_event(self) -> tuple[bool, pd.Series]: + """ This function checks whether 
this event should be logged as part of the event chains, and if so stores the required information before the event has occurred. """ + + # Initialise these variables + print_chains = False + row_before = pd.Series() + + # Only log the event if it is not in the list of events to ignore (the modules-of-interest check below is currently disabled) + # if (self.module in self.sim.generate_event_chains_modules_of_interest) and + if all(sub not in str(self) for sub in self.sim.generate_event_chains_ignore_events): + + # TODO: switch to the check below once the name of the event itself can be retrieved directly + # if not set(self.sim.generate_event_chains_ignore_events).intersection(str(self)): + + if self.target != self.sim.population: + + # In the case of HSI events, only individual events should exist and therefore be logged + print_chains = True + + # Save row for comparison after event has occurred + row_before = self.sim.population.props.loc[abs(self.target)].copy().fillna(-99999) + + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR.
+ row = self.sim.population.props.loc[[abs(self.target)]] + row['person_ID'] = self.target + row['event'] = str(self) + row['event_date'] = self.sim.date + row['when'] = 'Before' + try: + row['appt_footprint'] = str(self.EXPECTED_APPT_FOOTPRINT) + row['level'] = self.facility_info.level + except: + row['appt_footprint'] = 'N/A' + row['level'] = 'N/A' + self.sim.event_chains = pd.concat([self.sim.event_chains, row], ignore_index=True) + + else: + # Once this has been removed from Chronic Syndrome mock module, make this a RuntimeError + # raise RuntimeError("Cannot have population-wide HSI events") + logger.debug( + key="message", + data=( + "Cannot have population-wide HSI events" + ), + ) + + + return print_chains, row_before + + def store_chains_to_do_after_event(self, print_chains, row_before, footprint) -> dict: + """ If print_chains=True, this function logs the event and identifies and logs any property changes that have occurred to one or multiple individuals as a result of the event taking place. """ + if print_chains: + # For HSI event, this will only ever occur for individual events + + row_after = self.sim.population.props.loc[abs(self.target)].fillna(-99999) + + # Create and store dictionary of changes. Note that person_ID, event, event_date, appt_footprint, and level + # will be stored regardless of whether individual experienced property changes.
+ + # Add event details + + try: + record_footprint = str(footprint) + record_level = self.facility_info.level + except: + record_footprint = 'N/A' + record_level = 'N/A' + + link_info = { + 'person_ID': self.target, + 'event' : str(self), + 'event_date' : self.sim.date, + 'appt_footprint' : record_footprint, + 'level' : record_level, + } + + # Add changes to properties + for key in row_before.index: + if row_before[key] != row_after[key]: # Note: used fillna previously + link_info[key] = row_after[key] + + chain_links = {self.target : str(link_info)} + + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR. + row = self.sim.population.props.loc[[abs(self.target)]] + row['person_ID'] = self.target + row['event'] = str(self) + row['event_date'] = self.sim.date + row['when'] = 'After' + row['appt_footprint'] = record_footprint + row['level'] = record_level + self.sim.event_chains = pd.concat([self.sim.event_chains, row], ignore_index=True) + + return chain_links + def run(self, squeeze_factor): """Make the event happen.""" + + + if self.sim.generate_event_chains and self.target != self.sim.population: + print_chains, row_before = self.store_chains_to_do_before_event() + + footprint = self.EXPECTED_APPT_FOOTPRINT + updated_appt_footprint = self.apply(self.target, squeeze_factor) self.post_apply_hook() self._run_after_hsi_event() + + + if self.sim.generate_event_chains and self.target != self.sim.population: + + # If the footprint has been updated when the event ran, change it here + if updated_appt_footprint is not None: + footprint = updated_appt_footprint + + chain_links = self.store_chains_to_do_after_event(print_chains, row_before, str(footprint)) + + if len(chain_links)>0: + + pop_dict = {i: '' for i in range(FACTOR_POP_DICT)} + # pop_dict = {i: '' for i in range(1000)} # Always include all possible individuals + + pop_dict.update(chain_links) + + logger_chains.info(key='event_chains', + data = pop_dict, + 
description='Links forming chains of events for simulated individuals') + return updated_appt_footprint + def get_consumables( self, diff --git a/src/tlo/methods/rti.py b/src/tlo/methods/rti.py index 18c1987483..3642365976 100644 --- a/src/tlo/methods/rti.py +++ b/src/tlo/methods/rti.py @@ -2776,7 +2776,7 @@ class RTIPollingEvent(RegularEvent, PopulationScopeEventMixin): def __init__(self, module): """Schedule to take place every month """ - super().__init__(module, frequency=DateOffset(months=1)) + super().__init__(module, frequency=DateOffset(months=1000)) p = module.parameters # Parameters which transition the model between states self.base_1m_prob_rti = (p['base_rate_injrti'] / 12) @@ -2864,9 +2864,13 @@ def apply(self, population): .when('.between(70,79)', self.rr_injrti_age7079), Predictor('li_ex_alc').when(True, self.rr_injrti_excessalcohol) ) - pred = eq.predict(df.loc[rt_current_non_ind]) + if self.sim.generate_event_chains is True and self.sim.generate_event_chains_overwrite_epi is True: + pred = 1.0 + else: + pred = eq.predict(df.loc[rt_current_non_ind]) random_draw_in_rti = self.module.rng.random_sample(size=len(rt_current_non_ind)) selected_for_rti = rt_current_non_ind[pred > random_draw_in_rti] + # Update to say they have been involved in a rti df.loc[selected_for_rti, 'rt_road_traffic_inc'] = True # Set the date that people were injured to now diff --git a/src/tlo/methods/tb.py b/src/tlo/methods/tb.py index 623ee2e483..9dc05ff301 100644 --- a/src/tlo/methods/tb.py +++ b/src/tlo/methods/tb.py @@ -832,29 +832,31 @@ def initialise_population(self, population): df["tb_on_ipt"] = False df["tb_date_ipt"] = pd.NaT - # # ------------------ infection status ------------------ # - # WHO estimates of active TB for 2010 to get infected initial population - # don't need to scale or include treated proportion as no-one on treatment yet - inc_estimates = p["who_incidence_estimates"] - incidence_year = (inc_estimates.loc[ - (inc_estimates.year == 
self.sim.date.year), "incidence_per_100k" - ].values[0]) / 100_000 - - incidence_year = incidence_year * p["scaling_factor_WHO"] - self.assign_active_tb( - population, - strain="ds", - incidence=incidence_year) - - self.assign_active_tb( - population, - strain="mdr", - incidence=incidence_year * p['prop_mdr2010']) - - self.send_for_screening_general( - population - ) # send some baseline population for screening + # # ------------------ infection status ------------------ # + if self.sim.generate_event_chains is False or self.sim.generate_event_chains is None: + # WHO estimates of active TB for 2010 to get infected initial population + # don't need to scale or include treated proportion as no-one on treatment yet + inc_estimates = p["who_incidence_estimates"] + incidence_year = (inc_estimates.loc[ + (inc_estimates.year == self.sim.date.year), "incidence_per_100k" + ].values[0]) / 100_000 + + incidence_year = incidence_year * p["scaling_factor_WHO"] + + self.assign_active_tb( + population, + strain="ds", + incidence=incidence_year) + + self.assign_active_tb( + population, + strain="mdr", + incidence=incidence_year * p['prop_mdr2010']) + + self.send_for_screening_general( + population + ) # send some baseline population for screening def initialise_simulation(self, sim): """ @@ -867,7 +869,12 @@ def initialise_simulation(self, sim): sim.schedule_event(TbActiveEvent(self), sim.date) sim.schedule_event(TbRegularEvents(self), sim.date) sim.schedule_event(TbSelfCureEvent(self), sim.date) - sim.schedule_event(TbActiveCasePoll(self), sim.date + DateOffset(years=1)) + + if sim.generate_event_chains is True and sim.generate_event_chains_overwrite_epi is True: + sim.schedule_event(TbActiveCasePollGenerateData(self), sim.date + DateOffset(days=0)) + else: + sim.schedule_event(TbActiveCasePoll(self), sim.date + DateOffset(years=1)) + # 2) log at the end of the year # Optional: Schedule the scale-up of programs @@ -1366,6 +1373,53 @@ def is_subset(col_for_set, col_for_subset): # 
# TB infection event # # --------------------------------------------------------------------------- +class TbActiveCasePollGenerateData(RegularEvent, PopulationScopeEventMixin): + """The Tb Regular Poll Event for Data Generation for assigning active infections + * selects everyone to develop an active infection and schedules onset of active tb + sometime during the simulation + """ + + def __init__(self, module): + super().__init__(module, frequency=DateOffset(years=120)) + + def apply(self, population): + + df = population.props + now = self.sim.date + rng = self.module.rng + # Make everyone who is alive and not infected (no-one should be) susceptible + susc_idx = df.loc[ + df.is_alive + & (df.tb_inf != "active") + ].index + + len(susc_idx) + + middle_index = len(susc_idx) // 2 + + # Will equally split two strains among the population + list_ds = susc_idx[:middle_index] + list_mdr = susc_idx[middle_index:] + + # schedule onset of active tb. This will be equivalent to the "Onset", so it + # doesn't matter how long after we have decided which infection this is. 
+ for person_id in list_ds: + date_progression = now + pd.DateOffset( + # At some point during their lifetime, this person will develop TB + days=self.module.rng.randint(0, 365*(int(self.sim.end_date.year - self.sim.date.year)+1)) + ) + # set date of active tb - properties will be updated at TbActiveEvent poll daily + df.at[person_id, "tb_scheduled_date_active"] = date_progression + df.at[person_id, "tb_strain"] = "ds" + + for person_id in list_mdr: + date_progression = now + pd.DateOffset( + days=rng.randint(0, 365*int(self.sim.end_date.year - self.sim.start_date.year + 1)) + ) + # set date of active tb - properties will be updated at TbActiveEvent poll daily + df.at[person_id, "tb_scheduled_date_active"] = date_progression + df.at[person_id, "tb_strain"] = "mdr" + class TbActiveCasePoll(RegularEvent, PopulationScopeEventMixin): """The Tb Regular Poll Event for assigning active infections @@ -1439,7 +1493,6 @@ def apply(self, population): self.module.update_parameters_for_program_scaleup() - class TbActiveEvent(RegularEvent, PopulationScopeEventMixin): """ * check for those with dates of active tb onset within last time-period diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py index 547edf1d23..d9ba62c43a 100644 --- a/src/tlo/simulation.py +++ b/src/tlo/simulation.py @@ -8,7 +8,9 @@ import time from collections import OrderedDict from pathlib import Path +from typing import Optional from typing import TYPE_CHECKING, Optional +import pandas as pd import numpy as np @@ -35,6 +37,11 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logger_chains = logging.getLogger("tlo.methods.event") +logger_chains.setLevel(logging.INFO) + +FACTOR_POP_DICT = 5000 + class SimulationPreviouslyInitialisedError(Exception): """Exception raised when trying to initialise an already initialised simulation.""" @@ -102,9 +109,16 @@ def __init__( self.date = self.start_date = start_date self.modules = OrderedDict() self.event_queue = EventQueue() + 
self.generate_event_chains = True + self.generate_event_chains_overwrite_epi = None + self.generate_event_chains_modules_of_interest = [] + self.generate_event_chains_ignore_events = [] self.end_date = None self.output_file = None self.population: Optional[Population] = None + + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR. + self.event_chains: Optinoal[Population] = None self.show_progress_bar = show_progress_bar self.resourcefilepath = resourcefilepath @@ -274,12 +288,31 @@ def make_initial_population(self, *, n: int) -> None: data=f"{module.name}.initialise_population() {time.time() - start1} s", ) + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR. + self.event_chains = pd.DataFrame(columns= list(self.population.props.columns)+['person_ID'] + ['event'] + ['event_date'] + ['when'] + ['appt_footprint'] + ['level']) + + # When logging events for each individual to reconstruct chains, only the changes in individual properties will be logged. + # At the start of the simulation + when a new individual is born, we therefore want to store all of their properties at the start. + if self.generate_event_chains: + + pop_dict = self.population.props.to_dict(orient='index') + for key in pop_dict.keys(): + pop_dict[key]['person_ID'] = key + pop_dict[key] = str(pop_dict[key]) # Log as string to avoid issues around length of properties stored later + + pop_dict_full = {i: '' for i in range(FACTOR_POP_DICT)} + pop_dict_full.update(pop_dict) + + print("Size for full sim", len(pop_dict_full)) + + logger.info(key='event_chains', + data = pop_dict_full, + description='Links forming chains of events for simulated individuals') end = time.time() logger.info(key="info", data=f"make_initial_population() {end - start} s") def initialise(self, *, end_date: Date) -> None: """Initialise all modules in simulation. 
- :param end_date: Date to end simulation on - accessible to modules to allow initialising data structures which may depend (in size for example) on the date range being simulated. @@ -289,6 +322,22 @@ def initialise(self, *, end_date: Date) -> None: raise SimulationPreviouslyInitialisedError(msg) self.date = self.start_date self.end_date = end_date # store the end_date so that others can reference it + + #self.generate_event_chains = generate_event_chains + if self.generate_event_chains: + # Eventually this can be made an option + self.generate_event_chains_overwrite_epi = True + # For now keep these fixed, eventually they will be input from user + self.generate_event_chains_modules_of_interest = [self.modules] + self.generate_event_chains_ignore_events = ['AgeUpdateEvent','HealthSystemScheduler', 'SimplifiedBirthsPoll','DirectBirth'] #['TbActiveCasePollGenerateData','HivPollingEventForDataGeneration','SimplifiedBirthsPoll', 'AgeUpdateEvent', 'HealthSystemScheduler'] + else: + # If not using to print chains, cannot ignore epi + self.generate_event_chains_overwrite_epi = False + + + # Reorder columns to place the new columns at the front + pd.set_option('display.max_columns', None) + for module in self.modules.values(): module.initialise_simulation(self) self._initialised = True @@ -350,6 +399,8 @@ def run_simulation_to(self, *, to_date: Date) -> None: :param to_date: Date to simulate up to but not including - must be before or equal to simulation end date specified in call to :py:meth:`initialise`. """ + open('output.txt', mode='a') + if not self._initialised: msg = "Simulation must be initialised before calling run_simulation_to" raise SimulationNotInitialisedError(msg) @@ -366,6 +417,10 @@ def run_simulation_to(self, *, to_date: Date) -> None: self._update_progress_bar(progress_bar, date) self.fire_single_event(event, date) self.date = to_date + + # TO BE REMOVED: this is currently only used for debugging, will be removed from final PR. 
+ self.event_chains.to_csv('output.csv', index=False) + if self.show_progress_bar: progress_bar.stop() @@ -407,6 +462,7 @@ def fire_single_event(self, event: Event, date: Date) -> None: """ self.date = date event.run() + def do_birth(self, mother_id: int) -> int: """Create a new child person. @@ -420,6 +476,30 @@ def do_birth(self, mother_id: int) -> int: child_id = self.population.do_birth() for module in self.modules.values(): module.on_birth(mother_id, child_id) + + if self.generate_event_chains: + # When individual is born, store their initial properties to provide a starting point to the chain of property + # changes that this individual will undergo as a result of events taking place. + prop_dict = self.population.props.loc[child_id].to_dict() + prop_dict['event'] = 'Birth' + prop_dict['event_date'] = self.date + + pop_dict = {i: '' for i in range(FACTOR_POP_DICT)} # Always include all possible individuals + pop_dict[child_id] = str(prop_dict) # Convert to string to avoid issue of length + + print("Length at birth", len(pop_dict)) + logger.info(key='event_chains', + data = pop_dict, + description='Links forming chains of events for simulated individuals') + + # TO BE REMOVED This is currently just used for debugging. Will be removed from final version of PR. 
+ row = self.population.props.iloc[[child_id]] + row['person_ID'] = child_id + row['event'] = 'Birth' + row['event_date'] = self.date + row['when'] = 'After' + self.event_chains = pd.concat([self.event_chains, row], ignore_index=True) + return child_id def find_events_for_person(self, person_id: int) -> list[tuple[Date, Event]]: diff --git a/tests/test_rti.py b/tests/test_rti.py index 3075d5f70b..35e5bb0f2f 100644 --- a/tests/test_rti.py +++ b/tests/test_rti.py @@ -25,6 +25,17 @@ end_date = Date(2012, 1, 1) popsize = 1000 +@pytest.mark.slow +def test_data_harvesting(seed): + """ + This test runs a simulation with a functioning health system with full service availability and no set + constraints + """ + # create sim object + sim = create_basic_rti_sim(popsize, seed) + # run simulation + sim.simulate(end_date=end_date) + exit(-1) def check_dtypes(simulation): # check types of columns in dataframe, check they are the same, list those that aren't @@ -65,6 +76,7 @@ def test_run(seed): check_dtypes(sim) + @pytest.mark.slow def test_all_injuries_run(seed): """