diff --git a/aequilibrae/paths/graph.py b/aequilibrae/paths/graph.py index 77676a424..a3ac40f91 100644 --- a/aequilibrae/paths/graph.py +++ b/aequilibrae/paths/graph.py @@ -137,7 +137,7 @@ def default_types(self, tp: str): else: raise ValueError("It must be either a int or a float") - def prepare_graph(self, centroids: Optional[np.ndarray]) -> None: + def prepare_graph(self, centroids: Optional[np.ndarray] = None) -> None: """ Prepares the graph for a computation for a certain set of centroids @@ -341,18 +341,22 @@ def set_graph(self, cost_field) -> None: :Arguments: **cost_field** (:obj:`str`): Field name. Must be numeric """ - if cost_field in self.graph.columns: - self.cost_field = cost_field + if cost_field not in self.graph.columns: + raise ValueError("cost_field not available in the graph:" + str(self.graph.columns)) + + self.cost_field = cost_field + + # We only have a compact graph if we have added centroids, as that's used for skimming and assignment + if not self.compact_graph.empty: self.compact_cost = np.zeros(self.compact_graph.id.max() + 2, self.__float_type) df = self.__graph_groupby.sum(numeric_only=True)[[cost_field]].reset_index() self.compact_cost[df.index.values] = df[cost_field].values - if self.graph[cost_field].dtype == self.__float_type: - self.cost = np.array(self.graph[cost_field].values, copy=True) - else: - self.cost = np.array(self.graph[cost_field].values, dtype=self.__float_type) - self.logger.warning("Cost field with wrong type. Converting to float64") + + if self.graph[cost_field].dtype == self.__float_type: + self.cost = np.array(self.graph[cost_field].values, copy=True) else: - raise ValueError("cost_field not available in the graph:" + str(self.graph.columns)) + self.cost = np.array(self.graph[cost_field].values, dtype=self.__float_type) + self.logger.warning("Cost field with wrong type. Converting to float64") self.__build_derived_properties() diff --git a/tests/aequilibrae/paths/test_graph.py b/tests/aequilibrae/paths/test_graph.py index e8ae4665d..fd6e66c2a 100644 --- a/tests/aequilibrae/paths/test_graph.py +++ b/tests/aequilibrae/paths/test_graph.py @@ -39,6 +39,11 @@ def test_prepare_graph(self): graph = self.project.network.graphs["c"] graph.prepare_graph(np.arange(5) + 1) + def test_prepare_graph_no_centroids(self): + graph = self.project.network.graphs["c"] + graph.prepare_graph() + graph.set_graph("distance") + def test_set_graph(self): self.graph.set_graph(cost_field="distance") self.graph.set_blocked_centroid_flows(block_centroid_flows=True)