-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaction_blocking_helping_functions.py
90 lines (78 loc) · 2.65 KB
/
action_blocking_helping_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from envs.flatland.observations.segment_graph import Graph
def get_coords(direction):
    """Return the two-component grid offset for a direction code.

    The codes 0, 1, 2 and 3 map to ``(-1, 0)``, ``(0, 1)``, ``(1, 0)`` and
    ``(0, -1)`` respectively.  Any other value yields ``None``.

    NOTE(review): presumably these are Flatland headings (0=N, 1=E, 2=S,
    3=W) expressed as (row, column) steps -- confirm against the environment.
    """
    offsets = ((-1, 0), (0, 1), (1, 0), (0, -1))
    for code, offset in enumerate(offsets):
        if direction == code:
            return offset
    # Unrecognised codes have no offset.
    return None
def stop_deadlock_when_unavoidable(timestamp_segment_dict, to_reset, handle, direction, action, action_mask, old_pos):
    """Claim the segment an agent is about to enter, or pick another action.

    The target cell is ``old_pos`` shifted by the offset that
    ``get_new_pos_dx_dy(direction, action)`` returns.  The function then looks
    for that cell on the segments leading from the agent's current graph node
    to each of its next nodes:

    * If the cell is on none of them, nothing changes.
    * If the matching segment is not currently flagged in
      ``timestamp_segment_dict``, it is flagged ``True`` and recorded in
      ``to_reset``; ``action`` is kept.
    * If the segment is already flagged, ``action`` is replaced by
      ``pick_new_action(action, action_mask)`` (which edits ``action_mask``
      in place).

    Args:
        timestamp_segment_dict: mapping from a frozenset of ``(x, y)`` cells
            to a truthy/falsy "claimed" flag; updated in place.
        to_reset: list that receives every segment key claimed here; updated
            in place.
        handle: key into ``Graph.agents`` identifying the agent.
        direction: direction code passed to ``get_new_pos_dx_dy``.
        action: action code the agent intends to take.
        action_mask: per-action availability flags, forwarded to
            ``pick_new_action``.
        old_pos: the agent's current position; indexable with 0 and 1.

    Returns:
        The tuple ``(timestamp_segment_dict, to_reset, action)`` where
        ``action`` may have been replaced.
    """
    move = get_new_pos_dx_dy(direction, action)
    if move is None:
        # No offset is defined for this direction/action pair (for example an
        # action outside 1..3), so there is no target cell to claim.  Bail out
        # instead of failing on tuple unpacking.
        return timestamp_segment_dict, to_reset, action
    dx, dy = move
    new_pos = (old_pos[0] + dx, old_pos[1] + dy)

    agent = Graph.agents[handle]
    fr = agent.CurrentNode
    segments = [Graph.graph_global[fr][node]['segment'] for node in agent.NextNodes]

    # Segment entries are 3-tuples; only the first two components are compared
    # with the target cell.
    curr_segment = None
    for segment in segments:
        if any(new_pos == (x, y) for x, y, _ in segment):
            # NOTE(review): the scan intentionally keeps going, so if several
            # segments contain the cell the last one wins (existing
            # behaviour) -- confirm whether first-match was intended.
            curr_segment = segment
    if curr_segment is None:
        return timestamp_segment_dict, to_reset, action

    # A frozenset of cells is hashable and can therefore serve as a dict key.
    segment_key = frozenset((x, y) for x, y, _ in curr_segment)
    if timestamp_segment_dict.get(segment_key):
        # Segment already claimed: try a different action.
        action = pick_new_action(action, action_mask)
    else:
        timestamp_segment_dict[segment_key] = True
        to_reset.append(segment_key)
    return timestamp_segment_dict, to_reset, action
def reset_timestamp_dict(timestamp_segment_dict, to_reset):
    """Set the flag of every segment listed in ``to_reset`` to ``False``.

    ``timestamp_segment_dict`` is updated in place (segments not yet present
    are added with the value ``False``) and the same object is returned.
    ``to_reset`` itself is left untouched.
    """
    cleared_flags = dict.fromkeys(to_reset, False)
    timestamp_segment_dict.update(cleared_flags)
    return timestamp_segment_dict
def pick_new_action(old_action, action_mask):
    """Pick a replacement for ``old_action`` from the actions still allowed.

    ``action_mask`` is laid out so that index ``i`` describes action
    ``i + 1``.  It is modified in place: the entry for ``old_action`` and the
    entry at index 3 (i.e. action 4) are set to 0 before searching.

    Returns the lowest-numbered action whose mask entry equals 1, or
    ``old_action`` itself when no entry qualifies.
    """
    # NOTE(review): for old_action == 0 the index below is -1 and therefore
    # addresses the last mask entry -- confirm callers never pass 0.
    action_mask[old_action - 1] = 0
    action_mask[3] = 0
    candidates = (
        number
        for number, flag in enumerate(action_mask, start=1)
        if flag == 1
    )
    return next(candidates, old_action)
def get_new_pos_dx_dy(direc, action):
    """Return the grid offset for taking ``action`` while heading ``direc``.

    Only direction codes 0..3 combined with action codes 1..3 are defined;
    every other combination yields ``None``.

    For the defined combinations, action 2 gives the base offset of ``direc``
    itself -- ``(-1, 0)``, ``(0, 1)``, ``(1, 0)``, ``(0, -1)`` for 0, 1, 2, 3
    -- while action 1 gives the offset of the previous direction code and
    action 3 that of the next one, wrapping around modulo 4.

    NOTE(review): this matches Flatland's convention of headings
    0=N, 1=E, 2=S, 3=W and actions 1=left, 2=forward, 3=right -- confirm.
    """
    if direc not in (0, 1, 2, 3) or action not in (1, 2, 3):
        return None
    base_offsets = ((-1, 0), (0, 1), (1, 0), (0, -1))
    # Action 1 / 2 / 3 shifts the direction code by -1 / 0 / +1.
    # int() keeps numerically equal non-int codes (e.g. 1.0) usable as an
    # index, just as plain equality comparisons would accept them.
    resulting_heading = int((direc + action - 2) % 4)
    return base_offsets[resulting_heading]