Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vehicle tracker and dataframe viewer in GUI #70

Merged
merged 2 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 106 additions & 19 deletions uxsim/ResultGUIViewer/ResultGUIViewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
import sys
import numpy as np
from matplotlib import colormaps
from PyQt5.QtWidgets import QApplication, QMainWindow, QGraphicsView, QGraphicsScene, QGraphicsItem, QMenu, QSlider, QVBoxLayout, QWidget, QHBoxLayout, QLabel, QPushButton
from PyQt5.QtWidgets import QApplication, QMainWindow, QGraphicsView, QGraphicsScene, QGraphicsItem, QMenu, QSlider, QVBoxLayout, QWidget, QHBoxLayout, QLabel, QPushButton, QInputDialog, QMessageBox, QTableView, QDialog, QFileDialog
from PyQt5.QtGui import QPen, QColor, QPainter, QPainterPath
from PyQt5.QtCore import Qt, QPointF, QRectF, QTimer
from PyQt5.QtCore import Qt, QPointF, QRectF, QTimer, QAbstractTableModel


class EdgeItem(QGraphicsItem):
def __init__(self, name, start_node, end_node, density_list, Link):
Expand All @@ -44,7 +45,7 @@ def shape(self):
length = (dx**2 + dy**2)**0.5
offset = length/10
if dx == 0:
offset = QPointF(0, offset if dy > 0 else -offset)
offset = QPointF(-offset*self.curve_direction if dy > 0 else offset*self.curve_direction, 0)
else:
normal = QPointF(-dy, dx)
normal /= (normal.x()**2 + normal.y()**2)**0.5
Expand All @@ -68,7 +69,6 @@ def paint(self, painter, option, widget):
lw = max([density*self.Link.delta*self.Link.lanes])*(maxlw-minlw)+minlw

c = colormaps["viridis"](speed/self.Link.u)
#color = QColor(int(density/self.Link.jam_density * 255), int(density/self.Link.jam_density * 255), 0, 255)
color = QColor(int(c[0]*255), int(c[1]*255), int(c[2]*255), 255)
pen = QPen(color, lw)
painter.setPen(pen)
Expand Down Expand Up @@ -127,6 +127,7 @@ def set_curve_direction(self, direction):
def set_show_name(self, show_name):
self.show_name = show_name


class NodeItem(QGraphicsItem):
def __init__(self, name, x, y, Node):
super().__init__()
Expand All @@ -149,6 +150,7 @@ def paint(self, painter, option, widget):
def set_show_name(self, show_name):
self.show_name = show_name


class VehicleItem(QGraphicsItem):
def __init__(self, x, y):
super().__init__()
Expand All @@ -162,6 +164,7 @@ def paint(self, painter, option, widget):
painter.setBrush(Qt.red)
painter.drawEllipse(-5, -5, 10, 10)


class GraphWidget(QGraphicsView):
def __init__(self, nodes, edges, vehicle_list):
super().__init__()
Expand All @@ -188,9 +191,12 @@ def __init__(self, nodes, edges, vehicle_list):
edge = EdgeItem(name, start_node, end_node, density_list, Node)
self.scene().addItem(edge)
self.edges.append(edge)

self.set_vehice_items()

def set_vehice_items(self):
self.vehicle_items = []
if self.show_vehicles and self.vehicle_list != None:
if self.vehicle_list != None:
for t, edge_name, x in self.vehicle_list:
edge = self.find_edge(edge_name)
if edge:
Expand Down Expand Up @@ -253,15 +259,16 @@ def set_show_vehicles(self, show_vehicles):
self.viewport().update()

class MainWindow(QMainWindow):
def __init__(self, W, nodes, edges, vehicle_list, tmax):
def __init__(self, W, nodes, edges, vehicle_list, tmax, dt):
super().__init__()
self.setWindowTitle("UXsim result viewer")
self.W = W
self.tmax = tmax
self.dt = dt
self.playing = False
self.curve_direction = 1
self.show_names = True
self.show_vehicles = True
self.show_vehicles = False

central_widget = QWidget()
layout = QVBoxLayout()
Expand Down Expand Up @@ -304,6 +311,18 @@ def __init__(self, W, nodes, edges, vehicle_list, tmax):
# menu_file = menu_bar.addMenu("File")
# acrion_save_world = menu_file.addAction("Save World")
# acrion_save_world.triggered.connect(lambda: self.save_world())

menu_data = menu_bar.addMenu("Data")
action_basic_stats = menu_data.addAction("Basic Statistics")
action_basic_stats.triggered.connect(lambda: self.show_dataframe("Basic", self.W.analyzer.basic_to_pandas()))
action_basic_stats = menu_data.addAction("Link Statistics")
action_basic_stats.triggered.connect(lambda: self.show_dataframe("Link", self.W.analyzer.link_to_pandas()))
action_basic_stats = menu_data.addAction("OD Demand Statistics")
action_basic_stats.triggered.connect(lambda: self.show_dataframe("OD Demand", self.W.analyzer.od_to_pandas()))
action_basic_stats = menu_data.addAction("Vehicle Trip Statistics")
action_basic_stats.triggered.connect(lambda: self.show_dataframe("Vehicle Trip", self.W.analyzer.vehicle_trip_to_pandas()))
action_basic_stats = menu_data.addAction("Vehicle Detailed Statistics")
action_basic_stats.triggered.connect(lambda: self.show_dataframe("Vehicle", self.W.analyzer.vehicles_to_pandas()))

menu_settings = menu_bar.addMenu("Settings")
option_curve_direction = menu_settings.addMenu("Link Curve Direction")
Expand All @@ -318,6 +337,14 @@ def __init__(self, W, nodes, edges, vehicle_list, tmax):
show_names_action.setChecked(True)
show_names_action.triggered.connect(self.toggle_show_names)

menu_Vehicle = menu_bar.addMenu("Vehicle Analysis")
# show_vehicles_action = menu_Vehicle.addAction("Show Vehicle")
# show_vehicles_action.setCheckable(True)
# show_vehicles_action.setChecked(False)
# show_vehicles_action.triggered.connect(self.toggle_show_vehicles)
action_show_vehicle = menu_Vehicle.addAction("Highlight Vehicle by ID")
action_show_vehicle.triggered.connect(self.show_vehicle_by_id)

menu_Animation = menu_bar.addMenu("Export Results")
action_csv = menu_Animation.addAction("Export Results to CSV files")
action_csv.triggered.connect(lambda: self.W.analyzer.output_data())
Expand All @@ -328,18 +355,24 @@ def __init__(self, W, nodes, edges, vehicle_list, tmax):
action_network_anim_fancy = menu_Animation.addAction("Export Network Animation (vehicle-level)")
action_network_anim_fancy.triggered.connect(lambda: self.W.analyzer.network_fancy())

# menu_Vehicle = menu_bar.addMenu("Vehicle Analysis")
# show_vehicles_action = menu_Vehicle.addAction("Show Vehicles")
# show_vehicles_action.setCheckable(True)
# show_vehicles_action.setChecked(True)
# show_vehicles_action.triggered.connect(self.toggle_show_vehicles)

self.update_graph()

def save_world(self):
import pickle
with open("World.pkl", mode="wb") as f:
pickle.dump(self.W, f)
def show_dataframe(self, title, df):
viewer = DataFrameViewer(df, title, self)
viewer.show()

def save_world(self, default_filename='untitled.pkl_dill'):
#TODO: do something about "maximum recursion depth exceeded in comparison" error
import dill as pickle
filename, _ = QFileDialog.getSaveFileName(None, 'Save the world', default_filename, 'Pickle (by Dill package) Files (*.pkl_dill);;All Files (*)')

if filename:
try:
with open(filename, 'wb') as file:
pickle.dump(self.W, file)
print(f'World saved successfully: {filename}')
except Exception as e:
print(f'Error saving object: {str(e)}')

def update_graph(self):
t = self.t_slider.value()
Expand Down Expand Up @@ -377,6 +410,60 @@ def toggle_show_names(self):
def toggle_show_vehicles(self):
self.show_vehicles = not self.show_vehicles
self.graph_widget.set_show_vehicles(self.show_vehicles)

def show_vehicle_by_id(self):
vehicle_id, ok = QInputDialog.getText(self, "Highlight Vehicle", "<b>Enter Vehicle ID</b><br>Note that fast vehicles will be plotted as multiple dots in the animation.")
if ok and vehicle_id:
self.vehicle_id = vehicle_id
if vehicle_id not in self.W.VEHICLES:
QMessageBox.warning(self, "Vehicle ID not found", "The specified Vehicle ID was not found.")
return
veh = self.W.VEHICLES[vehicle_id]
self.graph_widget.vehicle_list = [(int(veh.log_t[i]/self.dt), veh.log_link[i].name, veh.log_x[i]/veh.log_link[i].length) for i in range(len(veh.log_t)) if veh.log_link[i] != -1]
print(veh, self.graph_widget.vehicle_list)

self.graph_widget.set_vehice_items()

self.graph_widget.set_show_vehicles(True)


class PandasModel(QAbstractTableModel):
def __init__(self, data):
super(PandasModel, self).__init__()
self._data = data

def rowCount(self, parent=None):
return self._data.shape[0]

def columnCount(self, parent=None):
return self._data.shape[1]

def data(self, index, role=Qt.DisplayRole):
if index.isValid() and role == Qt.DisplayRole:
return str(self._data.iloc[index.row(), index.column()])
return None

def headerData(self, section, orientation, role=Qt.DisplayRole):
if role == Qt.DisplayRole:
if orientation == Qt.Horizontal:
return str(self._data.columns[section])
elif orientation == Qt.Vertical:
return str(self._data.index[section])
return None


class DataFrameViewer(QDialog):
def __init__(self, data, title, parent=None):
super(DataFrameViewer, self).__init__(parent)
self.setWindowTitle(title)
self.setLayout(QVBoxLayout())
self.model = PandasModel(data)
self.view = QTableView()
self.view.setModel(self.model)
self.layout().addWidget(self.view)

self.resize(1200, 600)


def launch_World_viewer(W, return_app_window=False):
"""
Expand Down Expand Up @@ -424,10 +511,10 @@ def launch_World_viewer(W, return_app_window=False):
node[2] = (maxy - node[2]) / (maxy - miny) * (xysize - xybuffer*2) + xybuffer

edges = [[l.name, l.start_node.name, l.end_node.name, l.k_mat, l] for l in W.LINKS]
vehicle_list = None
dt = W.LINKS[0].edie_dt

app = QApplication(sys.argv)
window = MainWindow(W, nodes, edges, vehicle_list, tmax)
window = MainWindow(W, nodes, edges, None, tmax, dt)
window.show()
if return_app_window:
return app, window
Expand Down
29 changes: 28 additions & 1 deletion uxsim/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def time_space_diagram_traj_links(s, linkslist, figsize=(12,4), plot_signal=True
figsize : tuple of int, optional
The size of the figure to be plotted, default is (12,4).
plot_signal : bool, optional
Plot the signal red light.
Plot the signal red light.
"""
if s.W.vehicle_logging_timestep_interval != 1:
warnings.warn("vehicle_logging_timestep_interval is not 1. The plot is not exactly accurate.", LoggingWarning)
Expand Down Expand Up @@ -1191,6 +1191,33 @@ def log_vehicles_to_pandas(s):
"""
return s.vehicles_to_pandas()

def vehicle_trip_to_pandas(s):
"""
Converts the vehicle trip top to a pandas DataFrame.

Returns
-------
pd.DataFrame
A DataFrame containing the top of the vehicle trip logs, with the columns:
- 'name': the name of the vehicle (platoon).
- 'orig': the origin node of the vehicle's trip.
- 'dest': the destination node of the vehicle's trip.
- 'departure_time': the departure time of the vehicle.
- 'final_state': the final state of the vehicle.
- 'travel_time': the travel time of the vehicle.
- 'average_speed': the average speed of the vehicle.
"""
out = [["name", "orig", "dest", "departure_time", "final_state", "travel_time", "average_speed"]]
for veh in s.W.VEHICLES.values():
veh_dest_name = veh.dest.name if veh.dest != None else None
veh_state = veh.log_state[-1]
veh_ave_speed = np.average([v for v in veh.log_v if v != -1])

out.append([veh.name, veh.orig.name, veh_dest_name, veh.departure_time*s.W.DELTAT, veh_state, veh.travel_time, veh_ave_speed])

s.df_vehicle_trip = pd.DataFrame(out[1:], columns=out[0])
return s.df_vehicle_trip

def basic_to_pandas(s):
"""
Converts the basic stats to a pandas DataFrame.
Expand Down