Skip to content

Commit

Permalink
Add option to add node_labels to graph (#250)
Browse files Browse the repository at this point in the history
# Labelled graphs

## Description
Allows node labels to be passed (as a CSV file) to the plotting script. Example:
```
python plot_lm_graph.py /mnt/ssd-apollo/stefan/modular_arithmetic_interaction_graph.pt --nodes_per_layer 10 --labels_file test.csv  --out_file test.png
```
Also added other parameters as command line arguments.

## Motivation
We can now generate labels (interp repo) with
```
python generate_labels_mod_add.py /mnt/ssd-apollo/stefan/modular_arithmetic_interaction_graph.pt --out_file=test.csv
```
and want to add them to our graphs.

Labels are currently only available for mod add, but this setup works for any model.

## Tested
Made labelled and unlabelled plots.

## Breaking changes
No
  • Loading branch information
stefan-apollo authored Dec 7, 2023
1 parent 5779b6a commit 7ac5fec
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
27 changes: 23 additions & 4 deletions experiments/lm_rib_build/plot_lm_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
The results_pt_file should be the output of the run_lm_rib_build.py script.
"""
import csv
from pathlib import Path
from typing import Optional, Union

import fire
import torch
Expand All @@ -14,29 +16,46 @@
from rib.utils import check_outfile_overwrite


def main(
    results_file: str,
    nodes_per_layer: int = 40,
    labels_file: Optional[str] = None,
    out_file: Optional[Union[str, Path]] = None,
    force: bool = False,
) -> None:
    """Plot an interaction graph given a results file containing the graph edges.

    Args:
        results_file: Path to the results file, loaded with `torch.load`. The loaded object
            must provide the "edges", "exp_name" and "config" -> "node_layers" entries.
        nodes_per_layer: Number of nodes to plot in each layer (the same for all layers).
        labels_file: Optional path to a CSV file of node labels. Every CSV row becomes one
            list of strings in the `node_labels` passed to `plot_interaction_graph`.
        out_file: Where to save the plot. If None, defaults to
            `out/<exp_name>_rib_graph.png` next to this script.
        force: Forwarded to `check_outfile_overwrite` (presumably allows overwriting an
            existing output file -- confirm against `rib.utils`).
    """
    # NOTE: torch.load unpickles the file, so only pass results files from a trusted source.
    results = torch.load(results_file)

    # Only fall back to the default location when the caller gave no explicit path.
    if out_file is None:
        out_dir = Path(__file__).parent / "out"
        out_file = out_dir / f"{results['exp_name']}_rib_graph.png"
    else:
        out_file = Path(out_file)

    if not check_outfile_overwrite(out_file, force):
        return

    # Ensure that we have edges
    assert results["edges"], "The results file does not contain any edges."

    # Add labels if provided
    node_labels: Optional[list[list[str]]] = None
    if labels_file is not None:
        # newline="" is what the csv module asks for so quoted newlines are parsed correctly.
        with open(labels_file, "r", newline="") as file:
            node_labels = list(csv.reader(file))

    plot_interaction_graph(
        raw_edges=results["edges"],
        layer_names=results["config"]["node_layers"],
        exp_name=results["exp_name"],
        nodes_per_layer=nodes_per_layer,
        out_file=out_file,
        node_labels=node_labels,
    )

    print(f"Saved plot to {out_file}")


# Expose `main` as a command-line interface when this file is run as a script.
if __name__ == "__main__":
    fire.Fire(main)
15 changes: 14 additions & 1 deletion rib/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,15 @@ def plot_interaction_graph(
exp_name: str,
nodes_per_layer: Union[int, list[int]],
out_file: Path,
node_labels: Optional[list[list[str]]] = None,
) -> None:
"""Plot the interaction graph for the given edges.
Args:
raw_edges (list[tuple[str, torch.Tensor]]): List of edges which are tuples of
(module, edge_weights), each edge with shape (n_nodes_in_l+1, n_nodes_in_l)
layer_names (list[str]): The names of the layers. These should correspond to the first
element of each tuple in raw_edges, but also include a name for the output layer.
element of each tuple in raw_edges, but also include a name for the final node_layer.
exp_name (str): The name of the experiment.
nodes_per_layer (Union[int, list[int]]): The number of nodes in each layer. If int, then
all layers have the same number of nodes. If list, then the number of nodes in each
Expand All @@ -124,6 +125,10 @@ def plot_interaction_graph(

# Verify that the layer names match the edge names
edge_names = [edge_info[0] for edge_info in raw_edges]
if len(edge_names) != len(layer_names) - 1:
print(
f"WARNING: len(edge_names) != len(layer_names) - 1. edge_names={edge_names}, layer_names={layer_names}. This will probably cause the last layer in the plot to have no nodes. Are you using an old file?"
)
for edge_name, layer_name in zip(edge_names, layer_names[:-1]):
assert edge_name == layer_name, "The layer names must match the edge names."

Expand Down Expand Up @@ -157,6 +162,14 @@ def plot_interaction_graph(
# Add layer label above the nodes
plt.text(i, max_layer_height, layer_name, ha="center", va="center", fontsize=12)

# Label nodes if node_labels is provided
if node_labels is not None:
node_label_dict = {}
for i, layer in enumerate(layers):
for j, node in enumerate(layer):
node_label_dict[node] = node_labels[i][j].replace("|", "\n")
nx.draw_networkx_labels(graph, pos, node_label_dict, font_size=8)

# Draw edges
width_factor = 15
# for edge in graph.edges(data=True):
Expand Down

0 comments on commit 7ac5fec

Please sign in to comment.