diff --git a/include/graaflib/algorithm/shortest_path.h b/include/graaflib/algorithm/shortest_path.h index b2ed9e83..96c709fb 100644 --- a/include/graaflib/algorithm/shortest_path.h +++ b/include/graaflib/algorithm/shortest_path.h @@ -12,10 +12,10 @@ namespace graaf::algorithm { // TODO(bluppes): I would expose the names of the underlying algorithms here. enum class edge_strategy { WEIGHTED, UNWEIGHTED }; -template +template struct GraphPath { std::list vertices; - E total_weight; + WEIGHT_T total_weight; bool operator==(const GraphPath& other) const { return vertices == other.vertices && total_weight == other.total_weight; diff --git a/include/graaflib/algorithm/shortest_path.tpp b/include/graaflib/algorithm/shortest_path.tpp index 4fc815f9..a5524325 100644 --- a/include/graaflib/algorithm/shortest_path.tpp +++ b/include/graaflib/algorithm/shortest_path.tpp @@ -10,18 +10,48 @@ namespace graaf::algorithm { namespace detail { +template +struct PathVertex { + vertex_id_t id; + WEIGHT_T dist_from_start; + vertex_id_t prev_id; + + [[nodiscard]] bool operator>(const PathVertex& other) { + return dist_from_start > other.dist_from_start; + } +}; + +template +std::optional> reconstruct_path( + vertex_id_t start_vertex, vertex_id_t end_vertex, + std::unordered_map>& vertex_info) { + if (!vertex_info.contains(end_vertex)) { + return std::nullopt; + } + + GraphPath path; + auto current = end_vertex; + + while (current != start_vertex) { + path.vertices.push_front(current); + current = vertex_info[current].prev_id; + } + + path.vertices.push_front(start_vertex); + path.total_weight = vertex_info[end_vertex].dist_from_start; + return path; +} + template std::optional> get_unweighted_shortest_path( const graph& graph, vertex_id_t start_vertex, vertex_id_t end_vertex) { - std::unordered_set seen_vertices{}; - std::unordered_map prev_vertex{}; + std::unordered_map> vertex_info; std::queue to_explore{}; + vertex_info[start_vertex] = {start_vertex, 1, start_vertex}; to_explore.push(start_vertex); - seen_vertices.insert(start_vertex); - // TODO: align/merge with implementation of do_bfs in graph_traversal.tpp while (!to_explore.empty()) { auto current{to_explore.front()}; to_explore.pop(); @@ -31,53 +61,27 @@ std::optional> get_unweighted_shortest_path( } for (const auto& neighbor : graph.get_neighbors(current)) { - if (!seen_vertices.contains(neighbor)) { - seen_vertices.insert(neighbor); - prev_vertex[neighbor] = current; + if (!vertex_info.contains(neighbor)) { + vertex_info[neighbor] = { + neighbor, vertex_info[current].dist_from_start + 1, current}; to_explore.push(neighbor); } } } - const auto reconstruct_path = [&start_vertex, &end_vertex, &prev_vertex]() { - GraphPath path; - auto current{end_vertex}; - - while (current != start_vertex) { - path.vertices.push_front(current); - current = prev_vertex[current]; - } - - path.vertices.push_front(start_vertex); - path.total_weight = path.vertices.size(); - - return path; - }; - - if (seen_vertices.contains(end_vertex)) { - return reconstruct_path(); - } else { - return std::nullopt; - } + return reconstruct_path(start_vertex, end_vertex, vertex_info); } template std::optional> get_weighted_shortest_path( const graph& graph, vertex_id_t start_vertex, vertex_id_t end_vertex) { - struct DijkstraVertex { - vertex_id_t id; - WEIGHT_T distance; - vertex_id_t previous; - }; - - std::unordered_map vertex_info; - std::priority_queue< - DijkstraVertex, std::vector, - std::function> - to_explore([](const DijkstraVertex& v1, const DijkstraVertex& v2) { - return v1.distance > v2.distance; - }); + std::unordered_map> vertex_info; + + using weighted_path_item = PathVertex; + std::priority_queue, + std::greater<>> + to_explore{}; vertex_info[start_vertex] = {start_vertex, 0, start_vertex}; to_explore.push(vertex_info[start_vertex]); @@ -91,39 +95,18 @@ std::optional> get_weighted_shortest_path( } for (const auto& neighbor : graph.get_neighbors(current.id)) { - WEIGHT_T distance = - current.distance + graph.get_edge(current.id, neighbor)->get_weight(); + WEIGHT_T distance = current.dist_from_start + + graph.get_edge(current.id, neighbor)->get_weight(); if (!vertex_info.contains(neighbor) || - distance < vertex_info[neighbor].distance) { + distance < vertex_info[neighbor].dist_from_start) { vertex_info[neighbor] = {neighbor, distance, current.id}; to_explore.push(vertex_info[neighbor]); } } } - const auto reconstruct_path = [&start_vertex, &end_vertex, &vertex_info]() { - GraphPath path; - auto current = end_vertex; - - while (current != start_vertex) { - path.vertices.push_back(current); - current = vertex_info[current].previous; - } - - path.vertices.push_back(start_vertex); - path.total_weight = vertex_info[end_vertex].distance; - - std::reverse(path.vertices.begin(), path.vertices.end()); - - return path; - }; - - if (vertex_info.contains(end_vertex)) { - return reconstruct_path(); - } else { - return std::nullopt; - } + return reconstruct_path(start_vertex, end_vertex, vertex_info); } } // namespace detail