Skip to content

Commit

Permalink
Merge pull request #254 from stratosphereips/trajectory_graph_object
Browse files Browse the repository at this point in the history
Trajectory graph object
  • Loading branch information
ondrej-lukas authored Nov 13, 2024
2 parents 79f7e26 + bf755dd commit fd9ead4
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 26 deletions.
2 changes: 1 addition & 1 deletion env/netsecenv_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ env:
random_seed: 'random'
# Or you can fix the seed
# random_seed: 42
scenario: 'three_nets'
scenario: 'scenario1'
use_global_defender: False
max_steps: 50
use_dynamic_addresses: False
Expand Down
233 changes: 208 additions & 25 deletions utils/gamaplay_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,169 @@
import os
import utils
import argparse
import matplotlib.pyplot as plt

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) )))
from env.game_components import GameState, Action

class TrajectoryGraph:
    """
    Per-checkpoint transition graphs built from recorded gameplay trajectories.

    Every call to `add_checkpoint` stores one snapshot consisting of:
      * the full edge multiset keyed by (state_id, next_state_id, action_id),
      * the simplified edge multiset keyed by (state_id, next_state_id),
      * a 0/1 array marking which processed plays ended with "goal_reached".
    State and action ids are shared across all checkpoints of one instance,
    so edges from different checkpoints are directly comparable.
    """

    def __init__(self)->None:
        # checkpoint id -> raw trajectory list passed to add_checkpoint
        self._checkpoints = {}
        # checkpoint id -> {(src_id, dst_id, action_id): usage count}
        self._checkpoint_edges = {}
        # checkpoint id -> {(src_id, dst_id): usage count}
        self._checkpoint_simple_edges = {}
        # checkpoint id -> np.array of 0/1 win flags (one per processed play)
        self._wins_per_checkpoint = {}
        # bidirectional registries of states and actions seen so far
        self._state_to_id = {}
        self._id_to_state = {}
        self._action_to_id = {}
        self._id_to_action = {}

    @property
    def num_checkpoints(self)->int:
        """Number of checkpoints added so far."""
        return len(self._checkpoints)

    def get_state_id(self, state:"GameState")->int:
        """
        Returns state id or creates new one if the state was not registered before
        """
        # States are keyed by their ordered string form produced by utils.
        state_str = utils.state_as_ordered_string(state)
        if state_str not in self._state_to_id:
            new_id = len(self._state_to_id)
            self._state_to_id[state_str] = new_id
            self._id_to_state[new_id] = state
        return self._state_to_id[state_str]

    def get_state(self, id:int)->"GameState":
        """Returns the state registered under `id` (KeyError if unknown)."""
        return self._id_to_state[id]

    def get_action_id(self, action:"Action")->int:
        """
        Returns action id or creates new one if the action was not registered before
        """
        if action not in self._action_to_id:
            new_id = len(self._action_to_id)
            self._action_to_id[action] = new_id
            self._id_to_action[new_id] = action
        return self._action_to_id[action]

    def get_action(self, id:int)->"Action":
        """Returns the action registered under `id` (KeyError if unknown)."""
        return self._id_to_action[id]

    def add_checkpoint(self, trajectories:list, end_reason=None)->None:
        """
        Builds the edge statistics for one list of plays and stores them as a
        new checkpoint (its id is the number of checkpoints stored before).

        Each play is read through play["end_reason"],
        play["trajectory"]["states"] and play["trajectory"]["actions"].
        Plays whose end reason is not contained in `end_reason` (when given)
        and plays without any action are ignored.
        """
        wins = []
        edges = {}
        simple_edges = {}
        for play in trajectories:
            if end_reason and play["end_reason"] not in end_reason:
                continue
            trajectory = play["trajectory"]
            num_actions = len(trajectory["actions"])
            if num_actions == 0:
                continue
            wins.append(1 if play["end_reason"] == "goal_reached" else 0)
            state_id = self.get_state_id(GameState.from_dict(trajectory["states"][0]))
            # NOTE(review): the edge states[i-1] -> states[i] is labelled with
            # actions[i]; presumably states[i] is the state reached by
            # actions[i] - verify against the trajectory writer.
            for i in range(1, num_actions):
                next_state_id = self.get_state_id(GameState.from_dict(trajectory["states"][i]))
                action_id = self.get_action_id(Action.from_dict(trajectory["actions"][i]))
                # full graph: parallel edges are distinguished by action
                edge = (state_id, next_state_id, action_id)
                edges[edge] = edges.get(edge, 0) + 1
                # simplified graph: the action is ignored
                simple_edge = (state_id, next_state_id)
                simple_edges[simple_edge] = simple_edges.get(simple_edge, 0) + 1
                state_id = next_state_id
        # Resolve the id once; it changes as soon as _checkpoints grows.
        checkpoint_id = self.num_checkpoints
        self._checkpoint_simple_edges[checkpoint_id] = simple_edges
        self._checkpoint_edges[checkpoint_id] = edges
        self._wins_per_checkpoint[checkpoint_id] = np.array(wins)
        self._checkpoints[checkpoint_id] = trajectories

    def get_checkpoint_wr(self, checkpoint_id:int)->tuple:
        """
        Returns (mean, std) of the win flags of the given checkpoint.

        Raises IndexError when the checkpoint id is unknown.
        """
        if checkpoint_id not in self._wins_per_checkpoint:
            raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!")
        wins = self._wins_per_checkpoint[checkpoint_id]
        return np.mean(wins), np.std(wins)

    def get_wr_progress(self)->dict:
        """
        Prints and returns {checkpoint_id: {"wr": mean, "std": std}} for all
        checkpoints.
        """
        ret = {}
        for i in self._wins_per_checkpoint:
            wr, std = self.get_checkpoint_wr(i)
            ret[i] = {"wr":wr, "std":std}
            print(f"Checkpoint {i}: WR={wr}±{std}")
        return ret

    def get_graph_stats_progress(self):
        """
        Prints a table of graph statistics and returns
        {checkpoint_id: stats dict as produced by get_checkpoint_stats}.
        """
        ret = {}
        print("Checkpoint,\tWR,\tEdges,\tSimpleEdges,\tNodes,\tLoops,\tSimpleLoops")
        for i in self._wins_per_checkpoint:
            data = self.get_checkpoint_stats(i)
            ret[i] = data
            print(f'{i},\t{data["winrate"]},\t{data["num_edges"]},\t{data["num_simplified_edges"]},\t{data["num_nodes"]},\t{data["num_loops"]},\t{data["num_simplified_loops"]}')
        return ret

    def plot_graph_stats_progress(self, filedir="figures", filename="trajectory_graph_stats.png"):
        """
        Plots node/edge/loop counts per checkpoint (log-scaled y axis) and
        saves the figure to os.path.join(filedir, filename).

        The plot is drawn on its own figure, which is closed afterwards, so
        repeated calls do not draw on top of each other. The output directory
        is created when it does not exist yet.
        """
        data = self.get_graph_stats_progress()
        # Checkpoint ids are assigned as 0..n-1 by add_checkpoint.
        checkpoints = sorted(data)
        series = (
            ("num_nodes", 'Number of nodes'),
            ("num_edges", 'Number of edges'),
            ("num_simplified_edges", 'Number of simplified edges'),
            ("num_loops", 'Number of loops'),
            ("num_simplified_loops", 'Number of simplified loops'),
        )
        if filedir:
            os.makedirs(filedir, exist_ok=True)
        fig, ax = plt.subplots()
        try:
            for key, label in series:
                ax.plot(checkpoints, [data[c][key] for c in checkpoints], label=label)
            ax.set_title("Graph statistics per checkpoint")
            ax.set_yscale('log')
            ax.set_xlabel("Checkpoints")
            ax.legend()
            fig.savefig(os.path.join(filedir, filename))
        finally:
            # Release the figure even when saving fails.
            plt.close(fig)

    def get_checkpoint_stats(self, checkpoint_id:int)->dict:
        """
        Returns win rate and graph-size statistics of one checkpoint.

        Keys: "winrate", "winrate_std", "num_edges", "num_simplified_edges",
        "num_loops", "num_simplified_loops", "num_nodes". A loop is an edge
        whose source and destination state ids are equal; nodes are the state
        ids that appear as an endpoint of at least one edge.

        Raises IndexError when the checkpoint id is unknown.
        """
        if checkpoint_id not in self._wins_per_checkpoint:
            raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!")
        wins = self._wins_per_checkpoint[checkpoint_id]
        edges = self._checkpoint_edges[checkpoint_id]
        simple_edges = self._checkpoint_simple_edges[checkpoint_id]
        node_set = set()
        for src_node, dst_node, _ in edges:
            node_set.add(src_node)
            node_set.add(dst_node)
        data = {}
        data["winrate"] = np.mean(wins)
        data["winrate_std"] = np.std(wins)
        data["num_edges"] = len(edges)
        data["num_simplified_edges"] = len(simple_edges)
        data["num_loops"] = sum(1 for edge in edges if edge[0] == edge[1])
        data["num_simplified_loops"] = sum(1 for edge in simple_edges if edge[0] == edge[1])
        data["num_nodes"] = len(node_set)
        return data

    def get_graph_structure_progress(self)->dict:
        """
        Returns {edge: array of length num_checkpoints} where entry i is 1 if
        the edge was used in checkpoint i and 0 otherwise.
        """
        all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
        super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges}
        for i, edge_list in self._checkpoint_edges.items():
            for edge in edge_list:
                super_graph[edge][i] = 1
        return super_graph

    def get_graph_structure_probabilistic_progress(self)->dict:
        """
        Returns {edge: array of length num_checkpoints} where entry i is the
        usage count of the edge in checkpoint i divided by the total usage of
        all edges leaving the same source state in that checkpoint (0 when
        the edge is absent).
        """
        all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
        super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges}
        for i, edge_list in self._checkpoint_edges.items():
            # Total outgoing usage per source node within this checkpoint.
            total_out_edges_use = {}
            for (src, _, _), frequency in edge_list.items():
                total_out_edges_use[src] = total_out_edges_use.get(src, 0) + frequency
            for edge, value in edge_list.items():
                super_graph[edge][i] = value/total_out_edges_use[edge[0]]
        return super_graph

def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple:
edges = {}
Expand Down Expand Up @@ -94,33 +252,58 @@ def get_graph_modificiation(edge_list1, edge_list2):


# Command-line interface of the trajectory-graph analysis script.
# NOTE(review): this listing looks like a rendered diff in which the previous
# and the new version of a statement appear back to back (see the live and
# commented-out `--t1`/`--t2` lines and the two `--n_trajectories` lines);
# confirm against the actual file which of them are really present.
parser = argparse.ArgumentParser()
parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
# parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
# parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
parser.add_argument("--end_reason", help="Filter options for trajectories", default=None, type=str, action='store', required=False)
# NOTE(review): no `type=` is given for `--n_trajectories`, so a value passed
# on the command line arrives as a string while the default is an int -
# confirm that read_json(max_lines=...) copes with both.
parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=1000, required=False)
parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=10000, required=False)

args = parser.parse_args()
# Pairwise comparison of two trajectory files: both are loaded, turned into
# gameplay graphs that share the `states`/`actions` registries, and then their
# win rates, graph statistics and edge/node differences are printed.
# Each statement appears here both live and commented out (see the review
# note on the argument parser about the diff-like listing).
trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
states = {}
actions = {}
# trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
# trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
# states = {}
# actions = {}

graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)
# graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
# graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)

# Inverted copies of the registries filled by gameplay_graph.
state_to_id = {v:k for k,v in states.items()}
# NOTE(review): this is built from `states`, not `actions` - looks like a
# copy-paste slip; confirm what get_graph_stats expects here.
action_to_id = {v:k for k,v in states.items()}

print(f"Trajectory 1: {args.t1}")
print(f"WR={t1_wr_mean}±{t1_wr_std}")
get_graph_stats(graph_t1, state_to_id, action_to_id)
print(f"Trajectory 2: {args.t2}")
print(f"WR={t2_wr_mean}±{t2_wr_std}")
get_graph_stats(graph_t2, state_to_id, action_to_id)

# Added/deleted edges and nodes between the two graphs.
a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
# print("positions of same states:")
# for node in node_set(graph_t1).intersection(node_set(graph_t2)):
# print(g1_timestaps[node], g2_timestaps[node])
# print("-----------------------")
# state_to_id = {v:k for k,v in states.items()}
# action_to_id = {v:k for k,v in states.items()}

# print(f"Trajectory 1: {args.t1}")
# print(f"WR={t1_wr_mean}±{t1_wr_std}")
# get_graph_stats(graph_t1, state_to_id, action_to_id)
# print(f"Trajectory 2: {args.t2}")
# print(f"WR={t2_wr_mean}±{t2_wr_std}")
# get_graph_stats(graph_t2, state_to_id, action_to_id)

# a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
# print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
# # print("positions of same states:")
# # for node in node_set(graph_t1).intersection(node_set(graph_t2)):
# # print(g1_timestaps[node], g2_timestaps[node])
# # print("-----------------------")
# Checkpoint-progress analysis: one TrajectoryGraph per experiment, one
# checkpoint per trajectory file. The "no_blocks" and experiment0002 runs
# below are kept commented out; only the "blocks" run is active.
# tg_no_blocks = TrajectoryGraph()

# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories))

# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_no_blocks.jsonl",max_lines=args.n_trajectories))
# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_no_blocks.jsonl",max_lines=args.n_trajectories))
# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_no_blocks.jsonl",max_lines=args.n_trajectories))
# tg_no_blocks.plot_graph_stats_progress()

# Load four trajectory files (paths are hard-coded relative to the working
# directory) as consecutive checkpoints and plot the graph statistics.
tg_blocks = TrajectoryGraph()
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-20000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.plot_graph_stats_progress()

# Print the number of distinct edges over all checkpoints, then the number of
# edges whose relative usage is non-zero in every checkpoint.
super_graph = tg_blocks.get_graph_structure_probabilistic_progress()
print(len(super_graph))
edges_present_everycheckpoint = [k for k,v in super_graph.items() if np.min(v) > 0]
print(len(edges_present_everycheckpoint))

0 comments on commit fd9ead4

Please sign in to comment.