Skip to content

Commit

Permalink
fix: fix clean_air_grid being wrongly constructed of integers instead…
Browse files Browse the repository at this point in the history
… of np.float32

which caused unexpected behavior when passing a custom default_weight
param
  • Loading branch information
eladyaniv01 committed Oct 1, 2020
1 parent e0c8e19 commit 278241c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
4 changes: 2 additions & 2 deletions MapAnalyzer/Pather.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ def get_climber_grid(self, default_weight: int = 1, include_destructables: bool
def get_clean_air_grid(self, default_weight: int = 1) -> ndarray:
clean_air_grid = np.ones(shape=self.map_data.path_arr.shape).astype(np.float32).T
if default_weight == 1:
return clean_air_grid
return clean_air_grid.copy()
else:
return np.where(clean_air_grid == 1, default_weight, 0)
return np.where(clean_air_grid == 1, default_weight, np.inf).astype(np.float32)

def get_air_vs_ground_grid(self, default_weight: int) -> ndarray:
grid = np.fmax(self.map_data.path_arr, self.map_data.placement_arr).T
Expand Down
22 changes: 15 additions & 7 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from MapAnalyzer.MapData import MapData
from MapAnalyzer.utils import import_bot_instance

import matplotlib.pyplot as plt

def get_random_point(minx, maxx, miny, maxy):
return (random.randint(minx, maxx), random.randint(miny, maxy))
Expand Down Expand Up @@ -51,17 +50,26 @@ def get_map_file_list() -> List[str]:
reg_end = map_data.where_all(map_data.bot.enemy_start_locations[0].position)[0]
p0 = Point2(reg_start.center)
p1 = Point2(reg_end.center)
influence_grid = map_data.get_air_vs_ground_grid(default_weight=50)
influence_grid = map_data.get_clean_air_grid(default_weight=50)
# influence_grid = map_data.get_pyastar_grid()
cost_point = (50, 130)
# cost_point = (50, 130)
cost_point = (87, 76)
influence_grid = map_data.add_cost(position=cost_point, radius=7, grid=influence_grid)
safe_points = map_data.find_lowest_cost_points(from_pos=cost_point, radius=14, grid=influence_grid)
cost_point = (108, 64)
influence_grid = map_data.add_cost(position=cost_point, radius=7, grid=influence_grid)
cost_point = (97, 53)
influence_grid = map_data.add_cost(position=cost_point, radius=7, grid=influence_grid)
# safe_points = map_data.find_lowest_cost_points(from_pos=cost_point, radius=14, grid=influence_grid)

# logger.info(safe_points)

x, y = zip(*safe_points)
plt.scatter(x, y, s=1)
map_data.plot_influenced_path(start=p0, goal=p1, weight_array=influence_grid, allow_diagonal=False)
# x, y = zip(*safe_points)
# plt.scatter(x, y, s=1)
path = map_data.pathfind(start=p0, goal=p1, grid=influence_grid, allow_diagonal=True)
from loguru import logger

logger.info(len(path))
map_data.plot_influenced_path(start=p0, goal=p1, weight_array=influence_grid, allow_diagonal=True)
# map_data.save(filename=f"{mf}")
# plt.close()
map_data.show()
Expand Down

0 comments on commit 278241c

Please sign in to comment.