Learning representations for RL in Healthcare under a POMDP assumption, honoring the sequential nature that data wherewith data is generate. This paper, accepted to the proceedings of the 2020 Machine Learning for Healthcare Workshop at NeurIPS, empirically evaluates several recurrent autoencoding architectures to assess the quality of the internal representations each learn.
The motivation for this systematic analysis is that most prior work developing RL solutions for healthcare neglect to rigorously define state representations that respect the partial or sequential nature of the data generating process. We evaluate several recurrent autoencoding architectures, trained by predicting the subsequent physiological observation, by investigating the quality of the internal learned representations of patient state as well as develop treatment policies from them.
If you use this code in your research, please cite the following publication link to PMLR version:
@inproceedings{killian2020empirical,
title={An Empirical Study of Representation Learning for Reinforcement Learning in Healthcare},
author={Killian, Taylor W and Zhang, Haoran and Subramanian, Jayakumar and Fatemi, Mehdi and Ghassemi, Marzyeh},
booktitle={Machine Learning for Health},
pages={139--160},
year={2020},
organization={PMLR}
}
This paper can also be found on arxiv: https://arxiv.org/abs/2011.11235
Run the following commands to clone this repo and create the Conda environment
git clone https://github.com/MLforHealth/rl_representations.git
cd rl_representations/
conda env create -f environment.yml
conda activate rl4h_rep
The data used to develop, run and evaluate our experiments is extracted from the MIMIC-III database, based on the Sepsis cohort used by Komorowski, et al (2018). We replicate and refine the code to extract this cohort at the following repository https://github.com/microsoft/mimic_sepsis. For replication purposes, save both the z-normalized and "RAW" features.
This extracted cohort is for general purpose use, for the usage in this paper it required additional preprocessing which we outline here.
- Run
scripts/compute_acuity_scores.py
to compute additional acuity scores with the raw patient features. - Run
scripts/split_sepsis_cohort.py
to compose a training, validation and testing split as well as remove binary or demographic information from the temporal patient features. This script also organizes the patient data into convenient trajectory formats for easier use with sequential or recurrent models. - Use Behavior Cloning on the data provided by the previous step to develop a baseline "expert" policy for use in training and evaluating RL policies from the learned patient representations.
- This is done by running
scripts/train_behavCloning.py
. An example of how this can be done is provided inslurm_scripts/slurm_build_BC.py
-->slurm_scripts/slurm_bc_exp
. - This will generate either
behav_policy_file
orbehav_policy_file_wDemo
which is used internal to Steps 2 and 3 below.
- This is done by running
- In
configs/common.yaml
, update the paths at the bottom of the file. - Select a
model
- one of [AE
,AIS
,CDE
,DDM
,DST
,ODERNN
,RNN
] - Run
python slurm_build_{model}.py
to output a text file containing an array of launch arguments toscripts/train_model.py
that were used to generate the results from the paper. - If you are using a Slurm cluster, you can run
sbatch slurm_{model}_exp
after altering the Slurm header to run a Slurm task array where each job corresponds to training a single model with one line of arguments from the generated file. - Otherwise, models can also be trained on an ad-hoc basis by manually calling
python scripts/train_model.py
with the arguments shown in the generated text file. - The configuration files in
config_sepsis_{model}.yaml
contain the hyperparameters used in the paper. These can be changed if desired.
RL policies are learned using the discretized form of Batch Constrained Q-learning from Fujimoto, et al (2019). These policies are learned as the final part of the previous Step 2 inside of scripts/train_model.py
. The policies are evaluated intermittently using a form of Weighted Importance Sampling (procedure found in scripts/dBCQ_utils.py
, line 284).
We compare various evaluations of the learned policies using the iPython notebook: notebooks/Evaluate\ dBCQ\ policies.ipynb
.
We aggregate the results of all model training and the analysis of the learned representations in the notebook notebooks/Aggregate\ Next\ Step\ Results.ipynb
.
See requirements.txt
- Taylor W. Killian @twkillian
- Haoran Zhang @hzhang0
- Jaykumar Subramanian
- Mehdi Fatemi
- Marzyeh Ghassemi
The source code and documentation are licensed under the terms of the MIT License
Please don't hesitate to log an issue and we'll be as prompt as possible when responding.