From 0e175bcf1564342a4fa63fc6873947f4f69e95b7 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 12 Nov 2024 13:41:08 +0100 Subject: [PATCH 1/7] Create class Trajectory graphs and move functionality to it --- utils/gamaplay_graphs.py | 145 ++++++++++++++++++++++++++++++++------- 1 file changed, 120 insertions(+), 25 deletions(-) diff --git a/utils/gamaplay_graphs.py b/utils/gamaplay_graphs.py index 7ea99dd..ee27464 100644 --- a/utils/gamaplay_graphs.py +++ b/utils/gamaplay_graphs.py @@ -8,6 +8,91 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) ))) from env.game_components import GameState, Action +class TrajectoryGraph: + def __init__(self)->None: + self._checkpoints = {} + self._checkpoint_edges = {} + self._checkpoint_simple_edges = {} + self._wins_per_checkpoint = {} + self._state_to_id = {} + self._id_to_state = {} + self._action_to_id = {} + self._id_to_action = {} + + @property + def num_checkpoints(self)->int: + return len(self._checkpoints) + + def get_state_id(self, state:GameState)->int: + """ + Returns state id or creates new one if the state was not registered before + """ + state_str = utils.state_as_ordered_string(state) + if state_str not in self._state_to_id.keys(): + self._state_to_id[state_str] = len(self._state_to_id) + self._id_to_state[self._state_to_id[state_str]] = state + return self._state_to_id[state_str] + + def get_state(self, id:int)->GameState: + return self._id_to_state[id] + + def get_action_id(self, action:Action)->int: + """ + Returns action id or creates new one if the state was not registered before + """ + if action not in self._action_to_id.keys(): + self._action_to_id[action] = len(self._action_to_id) + self._id_to_action[self._action_to_id[action]] = action + return self._action_to_id[action] + + def add_checkpoint(self, trajectories:list, end_reason=None)->None: + # Add complete trajectory list + wins = [] + edges = {} + simple_edges = {} + for play in trajectories: + if end_reason and play["end_reason"] not in end_reason: + continue + if len(play["trajectory"]["actions"]) == 0: + continue + if play["end_reason"] == "goal_reached": + wins.append(1) + else: + wins.append(0) + state_id = self.get_state_id(GameState.from_dict(play["trajectory"]["states"][0])) + #print(f'Trajectory len: {len(play["trajectory"]["actions"])}') + for i in range(1, len(play["trajectory"]["actions"])): + next_state_id = self.get_state_id(GameState.from_dict(play["trajectory"]["states"][i])) + action_id = self.get_action_id(Action.from_dict((play["trajectory"]["actions"][i]))) + # fullgraph + if (state_id, next_state_id, action_id) not in edges: + edges[state_id, next_state_id, action_id] = 0 + edges[state_id, next_state_id, action_id] += 1 + + #simplified graph + if (state_id, next_state_id)not in simple_edges: + simple_edges[state_id, next_state_id] = 0 + simple_edges[state_id, next_state_id] += 1 + state_id = next_state_id + self._checkpoint_simple_edges[self.num_checkpoints] = simple_edges + self._checkpoint_edges[self.num_checkpoints] = edges + self._wins_per_checkpoint[self.num_checkpoints] = np.array(wins) + self._checkpoints[self.num_checkpoints] = trajectories + + def get_checkpoint_wr(self, checkpoint_id:int)->tuple: + if checkpoint_id not in self._wins_per_checkpoint: + raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!") + else: + return np.mean(self._wins_per_checkpoint[checkpoint_id]), np.std(self._wins_per_checkpoint[checkpoint_id]) + + def get_wr_progress(self)->dict: + ret = {} + for i in self._wins_per_checkpoint.keys(): + wr, std = self.get_checkpoint_wr(i) + ret[i] = {"wr":wr, "std":std} + print(f"Checkpoint {i}: WR={wr}±{std}") + return ret + def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple: @@ -94,33 +179,43 @@ def get_graph_modificiation(edge_list1, edge_list2): parser = argparse.ArgumentParser() - parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True) - parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True) + # parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True) + # parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True) parser.add_argument("--end_reason", help="Filter options for trajectories", default=None, type=str, action='store', required=False) - parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=1000, required=False) + parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=10000, required=False) args = parser.parse_args() - trajectories1 = read_json(args.t1, max_lines=args.n_trajectories) - trajectories2 = read_json(args.t2, max_lines=args.n_trajectories) - states = {} - actions = {} + # trajectories1 = read_json(args.t1, max_lines=args.n_trajectories) + # trajectories2 = read_json(args.t2, max_lines=args.n_trajectories) + # states = {} + # actions = {} - graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason) - graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason) + # graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason) + # graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason) + + # state_to_id = {v:k for k,v in states.items()} + # action_to_id = {v:k for k,v in states.items()} + + # print(f"Trajectory 1: {args.t1}") + # print(f"WR={t1_wr_mean}±{t1_wr_std}") + # get_graph_stats(graph_t1, state_to_id, action_to_id) + # print(f"Trajectory 2: {args.t2}") + # print(f"WR={t2_wr_mean}±{t2_wr_std}") + # get_graph_stats(graph_t2, state_to_id, action_to_id) + + # a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2) + # print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}") + # # print("positions of same states:") + # # for node in node_set(graph_t1).intersection(node_set(graph_t2)): + # # print(g1_timestaps[node], g2_timestaps[node]) + # # print("-----------------------") + tg = TrajectoryGraph() - state_to_id = {v:k for k,v in states.items()} - action_to_id = {v:k for k,v in states.items()} - - print(f"Trajectory 1: {args.t1}") - print(f"WR={t1_wr_mean}±{t1_wr_std}") - get_graph_stats(graph_t1, state_to_id, action_to_id) - print(f"Trajectory 2: {args.t2}") - print(f"WR={t2_wr_mean}±{t2_wr_std}") - get_graph_stats(graph_t2, state_to_id, action_to_id) - - a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2) - print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}") - # print("positions of same states:") - # for node in node_set(graph_t1).intersection(node_set(graph_t2)): - # print(g1_timestaps[node], g2_timestaps[node]) - # print("-----------------------") \ No newline at end of file + tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories)) + tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories)) + tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories)) + tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories)) + tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories)) + tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories)) + print(tg.num_checkpoints) + tg.get_wr_progress() From 8fda3f2052c9aec4416d7edd036433cf1d78dfdd Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 12 Nov 2024 15:30:15 +0100 Subject: [PATCH 2/7] add progress in time --- utils/gamaplay_graphs.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/utils/gamaplay_graphs.py b/utils/gamaplay_graphs.py index ee27464..65236d2 100644 --- a/utils/gamaplay_graphs.py +++ b/utils/gamaplay_graphs.py @@ -93,6 +93,38 @@ def get_wr_progress(self)->dict: print(f"Checkpoint {i}: WR={wr}±{std}") return ret + def get_graph_stats_progress(self): + ret = {} + print("Checkpoint,\tWR,\tEdges,\tSimpleEdges,\tNodes,\tLoops,\tSimpleLoops") + for i in self._wins_per_checkpoint.keys(): + data = self.get_checkpoint_stats(i) + ret[i] = data + print(f'{i},\t{data["winrate"]},\t{data["num_edges"]},\t{data["num_simplified_edges"]},\t{data["num_nodes"]},\t{data["num_loops"]},\t{data["num_simpligied_loops"]}') + return ret + + def get_checkpoint_stats(self, checkpoint_id:int)->dict: + if checkpoint_id not in self._wins_per_checkpoint: + raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!") + else: + data = {} + data["winrate"] = np.mean(self._wins_per_checkpoint[checkpoint_id]) + data["winrate_std"] = np.std(self._wins_per_checkpoint[checkpoint_id]) + data["num_edges"] = len(self._checkpoint_edges[checkpoint_id]) + data["num_simplified_edges"] = len(self._checkpoint_simple_edges[checkpoint_id]) + data["num_loops"] = len([edge for edge in self._checkpoint_edges[checkpoint_id].keys() if edge[0]==edge[1]]) + data["num_simpligied_loops"] = len([edge for edge in self._checkpoint_simple_edges[checkpoint_id].keys() if edge[0]==edge[1]]) + node_set = set([src_node for src_node,_,_ in self._checkpoint_edges[checkpoint_id].keys()]) | set([dst_node for _,dst_node,_ in self._checkpoint_edges[checkpoint_id].keys()]) + data["num_nodes"] = len(node_set) + return data + + def get_graph_structure_progress(self)->dict: + + all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values())) + super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges} + for i, edge_list in self._checkpoint_edges.items(): + for edge in edge_list: + super_graph[edge][i] = 1 + return super_graph def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple: @@ -218,4 +250,8 @@ def get_graph_modificiation(edge_list1, edge_list2): tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories)) tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories)) print(tg.num_checkpoints) - tg.get_wr_progress() + tg.get_graph_stats_progress() + super_graph = tg.get_graph_structure_progress() + for k,v in super_graph.items(): + if np.sum(v) > 3: + print(k, v) From d93db8cd13990d80da7048bc81653297faebe700 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 12 Nov 2024 15:54:33 +0100 Subject: [PATCH 3/7] Add plots --- utils/gamaplay_graphs.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/utils/gamaplay_graphs.py b/utils/gamaplay_graphs.py index 65236d2..e048e39 100644 --- a/utils/gamaplay_graphs.py +++ b/utils/gamaplay_graphs.py @@ -4,6 +4,7 @@ import os import utils import argparse +import matplotlib.pyplot as plt sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) ))) from env.game_components import GameState, Action @@ -99,9 +100,33 @@ def get_graph_stats_progress(self): for i in self._wins_per_checkpoint.keys(): data = self.get_checkpoint_stats(i) ret[i] = data - print(f'{i},\t{data["winrate"]},\t{data["num_edges"]},\t{data["num_simplified_edges"]},\t{data["num_nodes"]},\t{data["num_loops"]},\t{data["num_simpligied_loops"]}') + print(f'{i},\t{data["winrate"]},\t{data["num_edges"]},\t{data["num_simplified_edges"]},\t{data["num_nodes"]},\t{data["num_loops"]},\t{data["num_simplified_loops"]}') return ret + def plot_graph_stats_progress(self): + data = self.get_graph_stats_progress() + wr = [data[i]["winrate"] for i in range(len(data))] + num_nodes = [data[i]["num_nodes"] for i in range(len(data))] + num_edges = [data[i]["num_edges"] for i in range(len(data))] + num_simle_edges = [data[i]["num_simplified_edges"] for i in range(len(data))] + num_loops = [data[i]["num_loops"] for i in range(len(data))] + num_simplified_loops = [data[i]["num_simplified_loops"] for i in range(len(data))] + checkpoints = range(len(wr)) + plt.plot(checkpoints, num_nodes, label='Number of nodes') + plt.plot(checkpoints, num_edges, label='Number of edges') + plt.plot(checkpoints, num_simle_edges, label='Number of simplified edges') + plt.plot(checkpoints, num_loops, label='Number of loops') + plt.plot(checkpoints, num_simplified_loops, label='Number of simplified loops') + + plt.title("Line Graph with Multiple Lines") + plt.yscale('log') + plt.xlabel("Checkpoints") + # Show legend + plt.legend() + + # Save the figure as an image file + plt.savefig("multi_line_graph.png") + def get_checkpoint_stats(self, checkpoint_id:int)->dict: if checkpoint_id not in self._wins_per_checkpoint: raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!") @@ -112,7 +137,7 @@ def get_checkpoint_stats(self, checkpoint_id:int)->dict: data["num_edges"] = len(self._checkpoint_edges[checkpoint_id]) data["num_simplified_edges"] = len(self._checkpoint_simple_edges[checkpoint_id]) data["num_loops"] = len([edge for edge in self._checkpoint_edges[checkpoint_id].keys() if edge[0]==edge[1]]) - data["num_simpligied_loops"] = len([edge for edge in self._checkpoint_simple_edges[checkpoint_id].keys() if edge[0]==edge[1]]) + data["num_simplified_loops"] = len([edge for edge in self._checkpoint_simple_edges[checkpoint_id].keys() if edge[0]==edge[1]]) node_set = set([src_node for src_node,_,_ in self._checkpoint_edges[checkpoint_id].keys()]) | set([dst_node for _,dst_node,_ in self._checkpoint_edges[checkpoint_id].keys()]) data["num_nodes"] = len(node_set) return data @@ -250,7 +275,7 @@ def get_graph_modificiation(edge_list1, edge_list2): tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories)) tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories)) print(tg.num_checkpoints) - tg.get_graph_stats_progress() + tg.plot_graph_stats_progress() super_graph = tg.get_graph_structure_progress() for k,v in super_graph.items(): if np.sum(v) > 3: From 7c930eb02def71a315bf231bbcb099a9513910b3 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 12 Nov 2024 16:22:01 +0100 Subject: [PATCH 4/7] Add probabilistic graph --- utils/gamaplay_graphs.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/utils/gamaplay_graphs.py b/utils/gamaplay_graphs.py index e048e39..8694287 100644 --- a/utils/gamaplay_graphs.py +++ b/utils/gamaplay_graphs.py @@ -151,6 +151,19 @@ def get_graph_structure_progress(self)->dict: super_graph[edge][i] = 1 return super_graph + def get_graph_structure_probabilistic_progress(self)->dict: + + all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values())) + super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges} + for i, edge_list in self._checkpoint_edges.items(): + total_out_edges_use = {} + for (src, _, _), frequency in edge_list.items(): + if src not in total_out_edges_use: + total_out_edges_use[src] = 0 + total_out_edges_use[src] += frequency + for (src,dst,edge), value in edge_list.items(): + super_graph[(src,dst,edge)][i] = value/total_out_edges_use[src] + return super_graph def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple: edges = {} @@ -276,7 +289,7 @@ def get_graph_modificiation(edge_list1, edge_list2): tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories)) print(tg.num_checkpoints) tg.plot_graph_stats_progress() - super_graph = tg.get_graph_structure_progress() + super_graph = tg.get_graph_structure_probabilistic_progress() for k,v in super_graph.items(): - if np.sum(v) > 3: + if np.mean(v) > 0.25: print(k, v) From 45b1b12dab76a1838840c416ac743f7d5d3dbf91 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Wed, 13 Nov 2024 11:03:41 +0100 Subject: [PATCH 5/7] Add searching for action based on the ID --- utils/gamaplay_graphs.py | 42 ++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/utils/gamaplay_graphs.py b/utils/gamaplay_graphs.py index 8694287..b952330 100644 --- a/utils/gamaplay_graphs.py +++ b/utils/gamaplay_graphs.py @@ -44,7 +44,10 @@ def get_action_id(self, action:Action)->int: if action not in self._action_to_id.keys(): self._action_to_id[action] = len(self._action_to_id) self._id_to_action[self._action_to_id[action]] = action - return self._action_to_id[action] + return self._action_to_id[action] + + def get_action(self, id:int)-> Action: + return self._id_to_action[id] def add_checkpoint(self, trajectories:list, end_reason=None)->None: # Add complete trajectory list @@ -279,17 +282,28 @@ def get_graph_modificiation(edge_list1, edge_list2): # # for node in node_set(graph_t1).intersection(node_set(graph_t2)): # # print(g1_timestaps[node], g2_timestaps[node]) # # print("-----------------------") - tg = TrajectoryGraph() + # tg_no_blocks = TrajectoryGraph() + + # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories)) + # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories)) + # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories)) + # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories)) + # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories)) + # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories)) - tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories)) - tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories)) - tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories)) - tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories)) - tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories)) - tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories)) - print(tg.num_checkpoints) - tg.plot_graph_stats_progress() - super_graph = tg.get_graph_structure_probabilistic_progress() - for k,v in super_graph.items(): - if np.mean(v) > 0.25: - print(k, v) + # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_no_blocks.jsonl",max_lines=args.n_trajectories)) + # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_no_blocks.jsonl",max_lines=args.n_trajectories)) + # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_no_blocks.jsonl",max_lines=args.n_trajectories)) + # tg_no_blocks.plot_graph_stats_progress() + + tg_blocks = TrajectoryGraph() + tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_blocks.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_blocks.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_blocks.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-20000_blocks.jsonl",max_lines=args.n_trajectories)) + tg_blocks.plot_graph_stats_progress() + + super_graph = tg_blocks.get_graph_structure_probabilistic_progress() + print(len(super_graph)) + edges_present_everycheckpoint = [k for k,v in super_graph.items() if np.min(v) > 0] + print(len(edges_present_everycheckpoint)) \ No newline at end of file From cfe40dc3d53ebe2862ea5a3c879a24e66702a7d5 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Wed, 13 Nov 2024 11:43:04 +0100 Subject: [PATCH 6/7] give the filename in the arguments --- utils/gamaplay_graphs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/gamaplay_graphs.py b/utils/gamaplay_graphs.py index b952330..c42e794 100644 --- a/utils/gamaplay_graphs.py +++ b/utils/gamaplay_graphs.py @@ -106,7 +106,7 @@ def get_graph_stats_progress(self): print(f'{i},\t{data["winrate"]},\t{data["num_edges"]},\t{data["num_simplified_edges"]},\t{data["num_nodes"]},\t{data["num_loops"]},\t{data["num_simplified_loops"]}') return ret - def plot_graph_stats_progress(self): + def plot_graph_stats_progress(self, filedir="figures", filename="trajectory_graph_stats.png"): data = self.get_graph_stats_progress() wr = [data[i]["winrate"] for i in range(len(data))] num_nodes = [data[i]["num_nodes"] for i in range(len(data))] @@ -121,14 +121,14 @@ def plot_graph_stats_progress(self): plt.plot(checkpoints, num_loops, label='Number of loops') plt.plot(checkpoints, num_simplified_loops, label='Number of simplified loops') - plt.title("Line Graph with Multiple Lines") + plt.title("Graph statistics per checkpoint") plt.yscale('log') plt.xlabel("Checkpoints") # Show legend plt.legend() # Save the figure as an image file - plt.savefig("multi_line_graph.png") + plt.savefig(os.path.join(filedir, filename)) def get_checkpoint_stats(self, checkpoint_id:int)->dict: if checkpoint_id not in self._wins_per_checkpoint: From bf755dd28c005d551744890a90df8011102a2750 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Wed, 13 Nov 2024 11:43:26 +0100 Subject: [PATCH 7/7] do not use three nets by default --- env/netsecenv_conf.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/env/netsecenv_conf.yaml b/env/netsecenv_conf.yaml index e74fcbb..45675b2 100644 --- a/env/netsecenv_conf.yaml +++ b/env/netsecenv_conf.yaml @@ -97,7 +97,7 @@ env: random_seed: 'random' # Or you can fix the seed # random_seed: 42 - scenario: 'three_nets' + scenario: 'scenario1' use_global_defender: False max_steps: 50 use_dynamic_addresses: False