IxDRL: Interestingness analysis for explainable Deep RL toolkit
- Citing
- Installation
- 1. Agent Training (Optional)
- 2. Interaction Data Collection
- 3. Interestingness Analysis
- 4. Clustering Traces (Optional)
- 5. Highlights
- 6. Feature Importance
- 7. GUI Tool
- License Notice
- Support
- Acknowledgements
This code corresponds to the implementation of the interestingness toolkit described in this paper (arXiv version). The paper explains the goal behind this toolkit, the interestingness dimensions' rationale, and contains experiments using the code on different simulation platforms and the corresponding analysis.
Please cite the paper if you use this code in your research:
@InProceedings{10.1007/978-3-031-44064-9_20,
author="Sequeira, Pedro and Gervasio, Melinda",
editor="Longo, Luca",
title="IxDRL: A Novel Explainable Deep Reinforcement Learning Toolkit Based on Analyses of Interestingness",
booktitle="Explainable Artificial Intelligence",
year="2023",
publisher="Springer Nature Switzerland",
pages="373--396",
isbn="978-3-031-44064-9"
}
```shell
git clone https://github.com/SRI-AIC/IxDRL.git
cd IxDRL
pip install -e .
```
Python installation options:
- [rllib]: will install ray and necessary dependencies. Select this option to use IxDRL to analyze policies trained with the RLlib toolkit.
- [gui]: will install streamlit and necessary dependencies. Select this option to run the interestingness and agent behavior visualization GUI tool.
Options are selected at install time, e.g., pip install -e ".[rllib,gui]" installs both.
Note: the code has been tested on Python 3.10 running on Linux and Mac.
The input to the IxDRL framework is a policy trained via deep reinforcement learning (RL). The toolkit supports the analysis of policies trained using RLlib out of the box; for any other RL toolkit, the necessary data collectors need to be implemented, as explained in the next section.
For RLlib, see its documentation to train an agent using one of the implemented RL algorithms. In scripts/01_train.sh we provide an example of training a policy using the distributional DQN algorithm on the Atari Breakout environment.
Note: you will need to install either torch or tensorflow to train the agent.
Running:
```shell
source scripts/01_train.sh
```
will train the agent and store the corresponding results and checkpoint files under output/breakout-dist-dqn.
Note: this code was tested with ray==2.9.2. See requirements.txt for the specific versions of the Python packages used during testing.
After training an RL policy, we test it by deploying it in the environment a number of times under different initial conditions, resulting in a set of rollouts. As the agent interacts with the environment, we probe the learned models and collect various information about the agent's behavior and internal state, e.g., the value attributed to a state or the probability of selecting each action; we refer to all this information as the interaction data.
In the IxDRL toolkit, the ixdrl.InteractionData class is used to create objects storing the interaction data collected for a rollout, in the format supported by our interestingness extractors (detailed in the next section). The ixdrl.Rollout class packages additional information about a rollout besides the collected interaction data, such as details about the environment's action and observation spaces, the path to the video file of the sampled rollout, etc., which are used by different components of the IxDRL framework.
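To collect interaction data, we subclass ixdrl.data_collection.InteractionDataCollector and implement its collect_data method, e.g.:
```python
from __future__ import annotations  # Rollouts (rollout id -> Rollout) is only used as an annotation

from ixdrl import InteractionData, Rollout
from ixdrl.data_collection import InteractionDataCollector


class MyDataCollector(InteractionDataCollector):  # hypothetical subclass name
    def collect_data(self, num_rollouts: int) -> Rollouts:
        rollouts = []
        for r in range(num_rollouts):
            data = []
            for t in range(self.rollout_length):
                # probe the agent and environment and store the available interaction data
                datapoint = InteractionData(...)
                data.append(datapoint)
            rollouts.append(Rollout(str(r), data, ...))  # keep every collected rollout
        # index the rollouts by their id
        return {rollout.rollout_id: rollout for rollout in rollouts}
```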
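For readers without an RL toolkit at hand, the following self-contained sketch mirrors the loop above using hypothetical stand-in classes (ToyInteractionData, ToyRollout) instead of the real ixdrl.InteractionData and ixdrl.Rollout, whose constructor arguments are not documented here. A random policy and a toy value estimate stand in for probing a trained model, and the chosen fields (value, action probabilities, etc.) are illustrative assumptions.

```python
import math
import random
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyInteractionData:
    """Hypothetical stand-in for ixdrl.InteractionData: one timestep of data."""
    observation: List[float]
    action: int
    reward: float
    value: float               # the agent's estimate of the state's value
    action_probs: List[float]  # probability of selecting each action


@dataclass
class ToyRollout:
    """Hypothetical stand-in for ixdrl.Rollout: an id plus per-timestep data."""
    rollout_id: str
    data: List[ToyInteractionData] = field(default_factory=list)


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def collect_toy_rollouts(num_rollouts: int, rollout_length: int,
                         num_actions: int = 4, seed: int = 0) -> Dict[str, ToyRollout]:
    rng = random.Random(seed)
    rollouts: Dict[str, ToyRollout] = {}
    for r in range(num_rollouts):
        rollout = ToyRollout(str(r))
        for _ in range(rollout_length):
            obs = [rng.random() for _ in range(3)]
            # random logits and a sum-of-observations "critic" stand in for
            # probing a trained policy network
            probs = softmax([rng.gauss(0.0, 1.0) for _ in range(num_actions)])
            action = rng.choices(range(num_actions), weights=probs)[0]
            reward = 1.0 if action == 0 else 0.0
            rollout.data.append(ToyInteractionData(obs, action, reward, sum(obs), probs))
        rollouts[rollout.rollout_id] = rollout
    return rollouts
```

For example, collect_toy_rollouts(num_rollouts=3, rollout_length=5) returns a dict with keys "0", "1" and "2", each holding five timesteps of toy interaction data.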
As mentioned above, the IxDRL toolkit supports the analysis of policies trained using RLlib. The Python script to collect interaction data given a trained RLlib policy is at ixdrl/bin/collect/rllib.py:
```shell
python -m ixdrl.bin.collect.rllib
    checkpoint
    --run RUN
    --output OUTPUT
    [--render]
    [--episodes EPISODES]
    [--config CONFIG]
    [--stats-only STATS_ONLY]
    [--img-format IMG_FORMAT]
    [--labels-file LABELS_FILE]
    [--fps FPS]
    [--clear CLEAR]
    [--verbosity VERBOSITY]
```
where:
- checkpoint: RL policy checkpoint from which to roll out and collect data.
- output: The path to the directory in which to save the collected interaction data.
- run: The algorithm or model used to train the policy. This may refer to the name of a built-in algorithm (e.g., RLlib's DQN or PPO), or a user-defined trainable function or class registered in the tune registry.
- render: Render the environment while evaluating.
- episodes: Number of complete episodes to roll out.
- config: Algorithm-specific configuration (e.g., env, hyperparams). Gets merged with the configuration loaded from the checkpoint file.
- stats_only: Whether to use previously collected data and print stats only. If True but the data file cannot be found, the data will be collected anyway.
- img_format: The format of produced image files.
- labels_file: The path to a JSON file containing specifications for the action and observation labels of custom environments.
- fps: The frames-per-second rate used to save the episode videos.
- clear: Clear output directories before generating results.
- verbosity: Verbosity level.
This script produces a pickle file named interaction_data.pkl.gz in the output directory containing all interaction data collected given the RL policy. It also produces several plots and data files with statistics about the collected data (see details below).
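Assuming, as the .pkl.gz extension suggests, that the file is a gzip-compressed pickle, it can be loaded back in Python for custom analyses with a sketch like the following. load_gz_pickle and save_gz_pickle are hypothetical helper names, and the structure of the unpickled object is not described here. Unpickling the toolkit's objects requires the ixdrl package to be importable, and pickle files should only be loaded from trusted sources, since unpickling can execute arbitrary code.

```python
import gzip
import pickle
from pathlib import Path
from typing import Any


def load_gz_pickle(path: Path) -> Any:
    """Load a gzip-compressed pickle file, e.g., interaction_data.pkl.gz."""
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


def save_gz_pickle(obj: Any, path: Path) -> None:
    """Write obj as a gzip-compressed pickle (useful for caching derived data)."""
    with gzip.open(path, "wb") as f:
        pickle.dump(obj, f)
```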
We can use the Python function ixdrl.data_collection.print_stats to produce plots and data files with different statistics about the given interaction data, including the mean reward received over the course of an episode, the mean action distribution and frequency, the mean observation values, etc.
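As a rough illustration of two of these statistics, the sketch below computes a per-timestep mean across episodes of different lengths (e.g., of the reward) and the frequency with which each discrete action was selected. These are hypothetical helpers written for illustration; print_stats may compute and aggregate its statistics differently.

```python
from typing import List, Sequence


def mean_over_time(episodes: Sequence[Sequence[float]]) -> List[float]:
    """Mean of a per-timestep quantity (e.g., reward) across episodes.

    Episodes may have different lengths; the mean at timestep t only uses
    the episodes that lasted at least t + 1 steps.
    """
    max_len = max(len(ep) for ep in episodes)
    means = []
    for t in range(max_len):
        vals = [ep[t] for ep in episodes if t < len(ep)]
        means.append(sum(vals) / len(vals))
    return means


def action_frequency(actions: Sequence[int], num_actions: int) -> List[float]:
    """Fraction of timesteps in which each discrete action was selected."""
    counts = [0] * num_actions
    for a in actions:
        counts[a] += 1
    return [c / len(actions) for c in counts]
```

For instance, mean_over_time([[1, 2, 3], [3, 4]]) gives [2.0, 3.0, 3.0]: the last timestep only exists in the first episode.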
We provide a ready-to-use shell script, scripts/02_collect.sh, to extract interaction data from the RL policy trained in the Breakout environment using the scripts/01_train.sh script.
Note: rename the directory created by the previous script, named "DQN_ALE_Breakout-v5...", to "model" before executing the data collection script.
At the core of the IxDRL framework is interestingness analysis, which characterizes an RL agent's competence along various dimensions, each capturing a distinct aspect of the agent's interaction with the environment. The paper details which dimensions are analyzed by the toolkit and which interaction data artifacts are required for each analysis.
The code that extracts interestingness for each supported dimension given interaction data can be found under ixdrl/analysis:
- riskiness.py
- confidence.py
- goal_conduciveness.py
- stochasticity.py
- familiarity.py
- incongruity.py
- value.py
To extend the framework with novel analyses of interestingness (new dimensions), we subclass ixdrl.analysis.AnalysisDimensionBase and implement its analyze method:
```python
from typing import Optional


def analyze(self, rollouts: Rollouts) -> Optional[RolloutsAnalyses]:
    ...
```
which should return a RolloutsAnalyses object containing the extracted interestingness for the new dimension at each timestep of each rollout/episode of collected data.
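As an illustration of the kind of per-timestep quantity a new dimension might compute, the following standalone sketch derives a confidence-like score from the agent's action distribution as one minus its normalized entropy. This is an assumption made for illustration, not necessarily how ixdrl/analysis/confidence.py computes confidence. In a real AnalysisDimensionBase subclass, analyze would compute such values for every rollout and package them into a RolloutsAnalyses object, whose constructor is not shown here.

```python
import math
from typing import List, Sequence


def normalized_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy of a discrete distribution divided by its maximum, log(n)."""
    n = len(probs)
    if n < 2:
        return 0.0
    h = -sum(p * math.log(p) for p in probs if p > 0)
    return h / math.log(n)


def confidence_per_timestep(action_probs: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical dimension: 1 for a deterministic choice, 0 for a uniform one."""
    return [1.0 - normalized_entropy(p) for p in action_probs]
```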
To extract interestingness for all dimensions given interaction data, we run the Python script at ixdrl/bin/analyze.py:
```shell
python -m ixdrl.bin.analyze
    --input INPUT
    --output OUTPUT
    [--img-format IMG_FORMAT]
    [--derivative_accuracy DERIVATIVE_ACCURACY]
    [--processes PROCESSES]
    [--clear CLEAR]
    [--verbosity VERBOSITY]
```
where:
- input: Pickle file containing the interaction data collected using the bin.collect.* scripts.
- output: Directory in which to save results.
- img_format: The format of image files.
- derivative_accuracy: The accuracy used by finite difference methods. Needs to be a positive, even number.
- processes: The number of parallel processes to use for this analysis. A value of -1 or None will use all available CPUs.
- clear: Clear output directories before generating results.
- verbosity: Verbosity level.
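To illustrate what derivative_accuracy controls, the following sketch estimates the first derivative of a per-timestep signal (e.g., the agent's value estimates) using standard central finite-difference stencils of accuracy order 2, 4 or 6. Central stencils are symmetric around the evaluated timestep, which is why their accuracy order is always even. The helper name and the one-sided fallback at the episode edges are assumptions; the toolkit's exact scheme may differ.

```python
from typing import Dict, List, Sequence

# Central-difference coefficients for the first derivative, keyed by accuracy order.
# Symmetric stencils cancel the odd powers of the step in the error, so the order is even.
CENTRAL_COEFFS: Dict[int, List[float]] = {
    2: [-1 / 2, 0.0, 1 / 2],
    4: [1 / 12, -2 / 3, 0.0, 2 / 3, -1 / 12],
    6: [-1 / 60, 3 / 20, -3 / 4, 0.0, 3 / 4, -3 / 20, 1 / 60],
}


def first_derivative(values: Sequence[float], accuracy: int = 2, dt: float = 1.0) -> List[float]:
    """Estimate d(values)/dt at every timestep."""
    if accuracy not in CENTRAL_COEFFS:
        raise ValueError(f"accuracy must be one of {sorted(CENTRAL_COEFFS)}")
    if len(values) < 2:
        raise ValueError("need at least two values")
    coeffs = CENTRAL_COEFFS[accuracy]
    half = len(coeffs) // 2
    n = len(values)
    deriv = []
    for i in range(n):
        if half <= i < n - half:
            # interior timestep: symmetric stencil centered on i
            d = sum(c * values[i + k - half] for k, c in enumerate(coeffs))
        elif i < n - 1:
            # the stencil does not fit near the edges: first-order forward difference
            d = values[i + 1] - values[i]
        else:
            # last timestep: first-order backward difference
            d = values[i] - values[i - 1]
        deriv.append(d / dt)
    return deriv
```

For f(t) = t^3 sampled at integer timesteps, the order-2 stencil returns 76 at t = 5 instead of the exact 75, while the order-4 stencil is exact for cubics and returns 75 (up to floating-point rounding).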
This script produces a pickle file named interestingness.pkl.gz containing all interestingness data computed from the input interaction data.
In addition to saving the interestingness data file, the above script produces plots and data files with different statistics about the interestingness extracted from the collected policy rollouts under the /stats subdirectory. These include correlation matrices between the different interestingness dimensions, the mean interestingness over the course of rollouts, and the interestingness profiles.
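A correlation matrix between dimensions can be sketched as follows, assuming each dimension's interestingness values have been flattened over all timesteps of all rollouts and aligned by index. Pearson correlation is used here as an assumption; the statistic the toolkit actually uses is not specified in this section. A dimension whose values are constant has no defined correlation, so the sketch returns NaN in that case.

```python
import math
from typing import Dict, Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two aligned sequences with at least two values")
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0 or sy == 0:
        return float("nan")  # correlation undefined for a constant series
    return cov / (sx * sy)


def correlation_matrix(dims: Dict[str, Sequence[float]]) -> Dict[str, Dict[str, float]]:
    """Pairwise correlation between interestingness dimensions.

    Index i of every sequence must refer to the same timestep of the same rollout.
    """
    names = list(dims)
    return {a: {b: pearson(dims[a], dims[b]) for b in names} for a in names}
```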
We provide a ready-to-use shell script, scripts/03_interestingness.sh, to compute interestingness given the interaction data of the RL policy in the Breakout environment collected using the scripts/02_collect.sh script.
In our framework, one approach to help identify distinct, meaningful behavior patterns of RL agents (which could indicate that an agent has achieved some level of competence) is to perform trace clustering based solely on interestingness (see the paper for more details).
To cluster traces based on interestingness, we use the Trace Clustering toolkit available at https://github.com/SRI-AIC/trace-clustering, which supports clustering sequences of data with one or more features, in particular numeric data. This matches the output of the interestingness analysis, i.e., an interestingness value for each timestep of each episode, for each dimension (feature). See that toolkit's documentation to install the Python package and for the available clustering options.
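Before clustering, the interestingness data has to be arranged as one multivariate sequence per episode (timesteps × dimensions). The sketch below shows one way to do that from hypothetical long-format rows of (rollout id, timestep, dimension, value); the actual layout of the toolkit's interestingness data and the input format expected by the Trace Clustering toolkit may differ.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Sequence, Tuple

Row = Tuple[str, int, str, float]  # hypothetical layout: (rollout_id, timestep, dimension, value)


def to_sequences(rows: Iterable[Row], dimensions: Sequence[str]) -> Dict[str, List[List[float]]]:
    """Group rows into one (timesteps x dimensions) sequence per rollout.

    Timesteps are sorted; a dimension missing at a recorded timestep becomes NaN.
    """
    table: Dict[str, Dict[int, Dict[str, float]]] = defaultdict(lambda: defaultdict(dict))
    for rollout_id, t, dim, value in rows:
        table[rollout_id][t][dim] = value
    sequences: Dict[str, List[List[float]]] = {}
    for rollout_id, steps in table.items():
        sequences[rollout_id] = [
            [steps[t].get(d, float("nan")) for d in dimensions]
            for t in sorted(steps)
        ]
    return sequences
```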
Once we have clustered the RL agent rollouts using interestingness, we can use the script at ixdrl/bin/copy_videos.py to copy a random subset of videos sampled from each cluster (if available).
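The random per-cluster sampling can be sketched as follows. sample_per_cluster is a hypothetical helper, not the script's actual code, and it works on a mapping from rollout id to cluster id (e.g., read from the clustering results). Seeding a dedicated random.Random instance and iterating in sorted order make the selection reproducible.

```python
import random
from collections import defaultdict
from typing import Dict, List


def sample_per_cluster(clusters: Dict[str, str], n: int, seed: int = 0) -> Dict[str, List[str]]:
    """Pick up to n rollout ids at random from each cluster."""
    rng = random.Random(seed)
    by_cluster: Dict[str, List[str]] = defaultdict(list)
    for rollout_id, cluster in sorted(clusters.items()):
        by_cluster[cluster].append(rollout_id)
    return {
        cluster: sorted(rng.sample(ids, min(n, len(ids))))
        for cluster, ids in sorted(by_cluster.items())
    }
```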
We can also analyze interestingness per cluster to help assess the competency of the RL agent in different regions of the task space, as dictated by the discovered clusters. To do so, we use the script at ixdrl/bin/analyze_by_cluster.py, which produces plots and data files with different statistics about the interestingness extracted for each cluster under the corresponding /cluster-X subdirectory. In addition, plots comparing the interestingness attained in each cluster are generated and stored under the /comparison subdirectory.
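The core aggregation behind per-cluster statistics can be sketched as the mean interestingness of each dimension over all timesteps of the rollouts assigned to each cluster. The row format and helper name are hypothetical. Averaging over timesteps gives longer rollouts more weight, whereas averaging per-rollout means first would weigh each rollout equally.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def mean_by_cluster(rows: Iterable[Tuple[str, str, float]],
                    clusters: Dict[str, str]) -> Dict[str, Dict[str, float]]:
    """Mean interestingness of each dimension within each cluster.

    rows: (rollout_id, dimension, value) triples, one per timestep.
    clusters: rollout id -> cluster id; rollouts without a cluster are skipped.
    """
    sums: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    for rollout_id, dim, value in rows:
        if rollout_id not in clusters:
            continue
        c = clusters[rollout_id]
        sums[c][dim] += value
        counts[c][dim] += 1
    return {c: {d: sums[c][d] / counts[c][d] for d in sums[c]} for c in sums}
```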
We provide a shell script, scripts/04_cluster_traces.sh, to cluster traces (and analyze interestingness per cluster) based on the interestingness data of the RL policy in the Breakout environment computed using the scripts/03_interestingness.sh script.
After interestingness has been extracted from an RL agent's rollouts, IxDRL provides different tools to help interpret the agent's competency in the task. The interpretation code can be found under ixdrl/interpretation.
One interpretation tool consists of capturing highlights, i.e., video summaries of the agent's behavior in different situations, as captured by interestingness analysis (see the paper for details and references).
To extract highlights from collected interaction and interestingness data, we run the Python script at ixdrl/bin/highlights.py:
```shell
python -m ixdrl.bin.highlights
    --output OUTPUT
    --interestingness INTERESTINGNESS
    --metadata METADATA
    [--max-highlights MAX_HIGHLIGHTS]
    [--record-timesteps RECORD_TIMESTEPS]
    [--fade-ratio FADE_RATIO]
    [--iqr-mul IQR_MUL]
    [--processes PROCESSES]
    [--format FORMAT]
    [--clear CLEAR]
    [--verbosity VERBOSITY]
    [--clusters CLUSTERS]
    [--cluster-col CLUSTER_COL]
    [--rollout-col ROLLOUT_COL]
```
where:
- output: The path to the directory in which to save the results.
- interestingness: The path to the directory with the interestingness analyses for all traces and all the plots, typically produced by the analyze.py script.
- metadata: The path to the rollouts metadata file containing the names of the replay video files.
- max_highlights: Maximum number of highlights to be extracted per interestingness dimension.
- record_timesteps: The number of environment timesteps to be recorded in each video.
- fade_ratio: The ratio of frames to which to apply a fade-in/out effect.
- iqr_mul: The IQR (interquartile range) multiplier used to determine outliers.
- processes: Number of processes for parallel processing. A value < 1 uses all available CPUs.
- format: The format of image files.
- clear: Clear output directories before generating results.
- verbosity: Verbosity level.
- clusters: The path to the clusters CSV file to be used, containing a reference to the replay file of each trace and its corresponding cluster. Highlights will be computed for the data in each cluster and saved to separate directories. If None, highlights will be computed for the whole data.
- cluster_col: The name of the column in the clusters file holding the cluster id of each rollout.
- rollout_col: The name of the column in the clusters file holding the id of each rollout.
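The clusters file can be read into a rollout-to-cluster mapping with Python's csv module, using the column names passed via cluster_col and rollout_col. This is a hypothetical helper; the actual column names depend on the clustering output, and cluster ids are kept as strings here to avoid assuming they are integers.

```python
import csv
from typing import Dict, Iterable


def read_clusters(lines: Iterable[str], cluster_col: str, rollout_col: str) -> Dict[str, str]:
    """Map each rollout id to its cluster id using the named CSV columns."""
    return {row[rollout_col]: row[cluster_col] for row in csv.DictReader(lines)}
```

Since csv.DictReader accepts any iterable of lines, a file object opened with newline="" can be passed directly.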
The script identifies timesteps, across all rollouts collected for the RL agent, where the interestingness value for some dimension is considered a statistical outlier (i.e., a very high or low value compared to the mean value). If the clusters parameter is specified, the script creates a subdirectory named /cluster-X where the highlight data for each cluster are stored; otherwise, the data are stored under /cluster-overall.
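Given the iqr_mul parameter, one plausible reading is that outliers are detected with Tukey's fences: values below Q1 - k*IQR or above Q3 + k*IQR, where Q1 and Q3 are the first and third quartiles and k is the IQR multiplier. The sketch below implements that rule for the values of one dimension; it is an assumption about the approach, not the toolkit's code, and quartile conventions vary (statistics.quantiles uses the "exclusive" method by default).

```python
import statistics
from typing import List, Sequence, Tuple


def iqr_outliers(values: Sequence[float], iqr_mul: float = 1.5) -> List[Tuple[int, str]]:
    """Indices of values outside [Q1 - k*IQR, Q3 + k*IQR], labeled "high" or "low"."""
    q1, _, q3 = statistics.quantiles(values, n=4)
    iqr = q3 - q1
    low, high = q1 - iqr_mul * iqr, q3 + iqr_mul * iqr
    outliers = []
    for i, v in enumerate(values):
        if v > high:
            outliers.append((i, "high"))
        elif v < low:
            outliers.append((i, "low"))
    return outliers
```

With the default k = 1.5, for [0.5, 0.4, 0.6, 0.5, 0.45, 0.55, 0.5, 5.0, -4.0] the fences are about 0.2 and 0.8, and the function returns [(7, "high"), (8, "low")].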
In each highlights subdirectory, a data file named highlights.csv is created containing the list of highlights identified for the data, specifying for each highlight the corresponding interestingness dimension and value, the label ("high" or "low"), and the id of the agent rollout and the timestep at which it was identified. A subdirectory is also created for each analyzed interestingness dimension, containing a plot for each highlight that shows interestingness over the course of the rollout and the timestep at which the highlight was identified.
If video replays exist for each sampled agent rollout, the script also generates the video summaries corresponding to each highlight.
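The clip extraction can be sketched as choosing a window of record_timesteps timesteps around the highlight and computing per-frame opacities for the fade effect. Both helpers are hypothetical: centering the window on the highlight and splitting fade_ratio evenly between the fade-in and the fade-out are assumptions about how these parameters are used.

```python
from typing import List, Tuple


def highlight_window(t: int, episode_length: int, record_timesteps: int) -> Tuple[int, int]:
    """[start, end) range of timesteps to record, centered on timestep t when possible."""
    start = max(0, t - record_timesteps // 2)
    end = min(episode_length, start + record_timesteps)
    start = max(0, end - record_timesteps)  # shift back if clipped at the episode end
    return start, end


def fade_alphas(num_frames: int, fade_ratio: float) -> List[float]:
    """Per-frame opacity: fade in over the first frames and out over the last ones.

    Assumes fade_ratio is the total fraction of faded frames, split evenly
    between the fade-in and the fade-out.
    """
    n_fade = int(round(num_frames * fade_ratio / 2))
    alphas = [1.0] * num_frames
    for i in range(n_fade):
        a = (i + 1) / (n_fade + 1)
        alphas[i] = a                   # fade in
        alphas[num_frames - 1 - i] = a  # fade out
    return alphas
```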
We provide a shell script, scripts/05_highlights.sh, to identify highlights of the RL policy in the Breakout environment based on the extracted interestingness data.
Feature importance analysis allows us to gain deeper insight into which task elements most affect an agent's competence, how they affect the agent's behavior as measured by interestingness, and where (in which situations) this occurs (see the paper for more details).
In IxDRL, to perform feature importance analysis over the collected interestingness data, we use the script at ixdrl/bin/feature_importance.py:
```shell
python -m ixdrl.bin.feature_importance
    --interaction-data INTERACTION_DATA
    --interestingness INTERESTINGNESS
    --output OUTPUT
    [--highlights HIGHLIGHTS]
    [--seed SEED]
    [--processes PROCESSES]
    [--format FORMAT]
    [--clear CLEAR]
    [--verbosity VERBOSITY]
    [--clusters CLUSTERS]
    [--cluster-col CLUSTER_COL]
    [--rollout-col ROLLOUT_COL]
```
where:
- interaction_data: Pickle file containing the interaction data collected using the bin.collect.* scripts.
- interestingness: Path to the pandas dataframe pickle file containing the interestingness data, produced by the bin.analyze script.
- output: The path to the directory in which to save the results.
- highlights: Path to the CSV file containing the highlights/outliers info for which to produce local/individual explanations.
- seed: Seed used for random number generation.
- processes: Number of processes for parallel processing. A value < 1 uses all available CPUs.
- format: The format of image files.
- clear: Clear output directories before generating results.
- verbosity: Verbosity level.
- clusters: The path to the clusters CSV file to be used, containing a reference to the replay file of each trace and its corresponding cluster. Feature importance will be computed for the data in each cluster and saved to separate directories. If None, feature importance will be computed for the whole data.
- cluster_col: The name of the column in the clusters file holding the cluster id of each rollout.
- rollout_col: The name of the column in the clusters file holding the id of each rollout.
Note: it might take a while to train the surrogate models and perform global and local feature importance analysis, depending on the number of features and generated rollouts.
After executing the script, the following subdirectories are created under output:
- data: contains the input (features) and output (targets) datasets created to learn the surrogate models via gradient boosting for each interestingness dimension.
- models: contains the trained surrogate models for each interestingness dimension.
- regression-eval: contains plots with the evaluation of the trained surrogate models for each interestingness dimension under different metrics.
- hyperparam-optimization: contains plots showing the impact of the hyperparameter optimization procedure while training the surrogate models for each parameter.
- model-feature-importance: contains plots showing feature importance analysis performed using the XGBoost toolkit (i.e., not using SHAP analysis).
- global-feature-explanations: contains the global analysis results (see details below).
- local-feature-explanations: contains the local analysis results (see details below).
The goal of the global interpretation of interestingness is to understand how distinct aspects of the task influence interestingness in general. The global-feature-explanations subdirectory generated by the feature_importance script contains, for each interestingness dimension, several plots illustrating the SHAP values obtained for each observation feature, as well as the SHAP interaction values and effects between features.
We use local interpretation of interestingness to understand particular key moments of the agent's interaction with the environment. The local-feature-explanations subdirectory generated by the feature_importance script contains, for each outlier identified using the ixdrl/bin/highlights.py script, the SHAP values (contribution relative to the mean) impacting the prediction of the different interestingness dimensions.
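To make "contribution relative to the mean" concrete, the sketch below computes exact Shapley values by brute force for a small model, replacing absent features with a single reference point such as the feature means. This is a simplification for illustration: it enumerates all feature subsets (exponential cost), whereas the toolkit explains gradient-boosted surrogate models with SHAP, whose tree-based explainers compute such attributions efficiently against background data. By construction, the values sum to f(x) - f(baseline).

```python
import itertools
import math
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], float]


def shapley_values(f: Model, x: Sequence[float], baseline: Sequence[float]) -> List[float]:
    """Exact Shapley values of f at x; absent features take their baseline values."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight of a coalition of `size` features not containing i
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                present = set(subset)
                without_i = [x[j] if j in present else baseline[j] for j in range(n)]
                with_i = list(without_i)
                with_i[i] = x[i]
                phi[i] += weight * (f(with_i) - f(without_i))
    return phi
```

For a linear model f(z) = 2*z0 + 3*z1 - z2, each value reduces to the weight times the feature's deviation from the baseline, e.g., x = [1, 2, 3] and baseline = [0, 1, 1] give [2, 3, -2].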
We provide a shell script, scripts/06_feature_importance.sh, to perform feature importance analysis of the RL policy in the Breakout environment based on the collected interaction data, extracted interestingness data and identified highlights.
The IxDRL toolkit provides an interactive graphical tool (webapp) built using streamlit that allows visualizing videos of traces (replays), selecting a dimension and sorting traces by mean interestingness value, and viewing the various plots automatically produced by our framework during interaction data extraction and interestingness analysis (see the installation requirements above).
To load the tool, run:
```shell
streamlit run ixdrl/gui/inspector.py \
    --theme.base dark \
    -- \
    --interaction INTERACTION_DATA \
    --interestingness INTERESTINGNESS_DATA \
    [--training TRAINING_INFO] \
    [--video VIDEO_DIR]
```
where:
- interaction: The directory containing the interaction data for a series of rollouts.
- interestingness: The directory containing the interestingness analysis results for a series of rollouts.
- training: The path to a CSV file containing information about the agent's training progress.
- video: The path to the root directory of the video files defined in the rollouts data. If not provided, the video path for a rollout as specified in the interaction data will be used.
This will open a web browser displaying the interactive application. Use the sidebar to select the agent rollout to inspect, and the main window to interact with the video replay and the plots of interaction and interestingness data.
The code is provided under the GPLv3 license (see full license file). To use the code under a different set of licensing terms, please contact SRI International's licensing department at licensee-ops@sri.com.
If you have questions about using this package or find errors in the code you can post an issue or contact Pedro Sequeira or Melinda Gervasio.
This material is based upon work supported by the Defense Advanced Research Projects Agency (DARPA) under Contract No. HR001119C0112.