diff --git a/OTAnalytics/adapter_ui/view_model.py b/OTAnalytics/adapter_ui/view_model.py index 4f3d19550..b4fe4b8dc 100644 --- a/OTAnalytics/adapter_ui/view_model.py +++ b/OTAnalytics/adapter_ui/view_model.py @@ -417,3 +417,7 @@ def get_weather_types(self) -> ColumnResources: @abstractmethod def set_svz_metadata_frame(self, frame: AbstractFrameSvzMetadata) -> None: raise NotImplementedError + + @abstractmethod + def get_save_path_suggestion(self, file_type: str, context_file_type: str) -> Path: + raise NotImplementedError diff --git a/OTAnalytics/application/application.py b/OTAnalytics/application/application.py index 502627881..22f48adb7 100644 --- a/OTAnalytics/application/application.py +++ b/OTAnalytics/application/application.py @@ -54,6 +54,7 @@ GetSectionsById, ) from OTAnalytics.application.use_cases.start_new_project import StartNewProject +from OTAnalytics.application.use_cases.suggest_save_path import SavePathSuggester from OTAnalytics.application.use_cases.track_repository import ( GetAllTrackFiles, TrackRepositorySize, @@ -129,6 +130,7 @@ def __init__( load_otconfig: LoadOtconfig, config_has_changed: ConfigHasChanged, export_road_user_assignments: ExportRoadUserAssignments, + file_name_suggester: SavePathSuggester, ) -> None: self._datastore: Datastore = datastore self.track_state: TrackState = track_state @@ -168,6 +170,7 @@ def __init__( self._load_otconfig = load_otconfig self._config_has_changed = config_has_changed self._export_road_user_assignments = export_road_user_assignments + self._file_name_suggester = file_name_suggester def connect_observers(self) -> None: """ @@ -640,6 +643,30 @@ def get_road_user_export_formats( ) -> Iterable[ExportFormat]: return self._export_road_user_assignments.get_supported_formats() + def suggest_save_path(self, file_type: str, context_file_type: str = "") -> Path: + """Suggests a save path based on the given file type and an optional + related file type. + + The suggested path is in the following format: + /.. + + The base folder will be determined in the following precedence: + 1. First loaded config file (otconfig or otflow) + 2. First loaded track file (ottrk) + 3. First loaded video file + 4. Default: Current working directory + + The file stem suggestion will be determined in the following precedence: + 1. The file stem of the loaded config file (otconfig or otflow) + 2. _ + 3. Default: + + Args: + file_type (str): the file type. + context_file_type (str): the context file type. + """ + return self._file_name_suggester.suggest(file_type, context_file_type) + class MissingTracksError(Exception): pass diff --git a/OTAnalytics/application/config.py b/OTAnalytics/application/config.py index edc5f73e7..4d85ec209 100644 --- a/OTAnalytics/application/config.py +++ b/OTAnalytics/application/config.py @@ -13,7 +13,6 @@ CLI_CUTTING_SECTION_MARKER: str = "#clicut" DEFAULT_EVENTLIST_FILE_STEM: str = "events" DEFAULT_EVENTLIST_FILE_TYPE: str = "otevents" -DEFAULT_COUNTS_FILE_STEM: str = "counts" DEFAULT_COUNTS_FILE_TYPE: str = "csv" DEFAULT_COUNT_INTERVAL_TIME_UNIT: str = "min" DEFAULT_TRACK_FILE_TYPE: str = "ottrk" @@ -23,6 +22,15 @@ DEFAULT_PROGRESSBAR_STEP_PERCENTAGE: int = 5 DEFAULT_NUM_PROCESSES = 4 + +# File Types +CONTEXT_FILE_TYPE_ROAD_USER_ASSIGNMENTS = "road_user_assignments" +CONTEXT_FILE_TYPE_EVENTS = "events" +CONTEXT_FILE_TYPE_COUNTS = "counts" +OTCONFIG_FILE_TYPE = "otconfig" +OTFLOW_FILE_TYPE = "otflow" + + # OTConfig Default Values DEFAULT_DO_EVENTS = True DEFAULT_DO_COUNTING = True diff --git a/OTAnalytics/application/use_cases/suggest_save_path.py b/OTAnalytics/application/use_cases/suggest_save_path.py new file mode 100644 index 000000000..20ae98c51 --- /dev/null +++ b/OTAnalytics/application/use_cases/suggest_save_path.py @@ -0,0 +1,113 @@ +from datetime import datetime +from pathlib import Path +from typing import Callable + +from OTAnalytics.application.state import FileState +from OTAnalytics.application.use_cases.get_current_project import GetCurrentProject +from OTAnalytics.application.use_cases.track_repository import GetAllTrackFiles +from OTAnalytics.application.use_cases.video_repository import GetAllVideos + +DATETIME_FORMAT = "%Y-%m-%d_%H-%M-%S" + + +class SavePathSuggester: + """ + Class for suggesting save paths based on the config file, otflow file, + the first track file, and video file. + + Args: + file_state (FileState): Holds information on files loaded in application. + get_all_track_files (GetAllTrackFiles): A use case that retrieves + all track files. + get_all_videos (GetAllVideos): A use case that retrieves all + video files. + get_project (GetCurrentProject): A use case that retrieves + the current project. + """ + + @property + def __config_file(self) -> Path | None: + """The path to the last loaded or saved configuration file.""" + if config_file := self._file_state.last_saved_config.get(): + return config_file.file + return None + + @property + def __first_track_file(self) -> Path | None: + """The path to the first track file.""" + + if track_files := self._get_all_track_files(): + return next(iter(track_files)) + return None + + @property + def __first_video_file(self) -> Path | None: + """The path to the first video file.""" + + if video_files := self._get_all_videos.get(): + return video_files[0].get_path() + return None + + def __init__( + self, + file_state: FileState, + get_all_track_files: GetAllTrackFiles, + get_all_videos: GetAllVideos, + get_project: GetCurrentProject, + provide_datetime: Callable[[], datetime] = datetime.now, + ) -> None: + self._file_state = file_state + self._get_all_track_files = get_all_track_files + self._get_all_videos = get_all_videos + self._get_project = get_project + self._provide_datetime = provide_datetime + + def suggest(self, file_type: str, context_file_type: str = "") -> Path: + """Suggests a save path based on the given file type and an optional + related file type. + + The suggested path is in the following format: + /.. + + The base folder will be determined in the following precedence: + 1. First loaded config file (otconfig or otflow) + 2. First loaded track file (ottrk) + 3. First loaded video file + 4. Default: Current working directory + + The file stem suggestion will be determined in the following precedence: + 1. The file stem of the loaded config file (otconfig or otflow) + 2. _ + 3. Default: + + Args: + file_type (str): the file type. + context_file_type (str): the context file type. + """ + + base_folder = self._retrieve_base_folder() + file_stem = self._suggest_file_stem() + if context_file_type: + return base_folder / f"{file_stem}.{context_file_type}.{file_type}" + return base_folder / f"{file_stem}.{file_type}" + + def _retrieve_base_folder(self) -> Path: + """Returns the base folder for suggesting a new file name.""" + if self.__config_file: + return self.__config_file.parent + if self.__first_track_file: + return self.__first_track_file.parent + if self.__first_video_file: + return self.__first_video_file.parent + return Path.cwd() + + def _suggest_file_stem(self) -> str: + """Generates a suggestion for the file stem.""" + + if self.__config_file: + return f"{self.__config_file.stem}" + + current_time = self._provide_datetime().strftime(DATETIME_FORMAT) + if project_name := self._get_project.get().name: + return f"{project_name}_{current_time}" + return current_time diff --git a/OTAnalytics/plugin_ui/cli.py b/OTAnalytics/plugin_ui/cli.py index 091d6f082..f67a0d5a9 100644 --- a/OTAnalytics/plugin_ui/cli.py +++ b/OTAnalytics/plugin_ui/cli.py @@ -6,8 +6,9 @@ CountingSpecificationDto, ) from OTAnalytics.application.config import ( + CONTEXT_FILE_TYPE_COUNTS, + CONTEXT_FILE_TYPE_ROAD_USER_ASSIGNMENTS, DEFAULT_COUNT_INTERVAL_TIME_UNIT, - DEFAULT_COUNTS_FILE_STEM, DEFAULT_COUNTS_FILE_TYPE, DEFAULT_SECTIONS_FILE_TYPE, DEFAULT_TRACK_FILE_TYPE, @@ -244,7 +245,9 @@ def _export_events(self, sections: Iterable[Section], save_path: Path) -> None: event_list_exporter.export(events, sections, actual_save_path) logger().info(f"Event list saved at '{actual_save_path}'") - assignment_path = save_path.with_suffix(".road_user_assignment.csv") + assignment_path = save_path.with_suffix( + f".{CONTEXT_FILE_TYPE_ROAD_USER_ASSIGNMENTS}.csv" + ) specification = ExportSpecification( save_path=assignment_path, format=CSV_FORMAT.name ) @@ -267,7 +270,7 @@ def _do_export_counts(self, save_path: Path) -> None: raise ValueError("modes is None but has to be defined for exporting counts") for count_interval in self._run_config.count_intervals: output_file = save_path.with_suffix( - f".{DEFAULT_COUNTS_FILE_STEM}_{count_interval}" + f".{CONTEXT_FILE_TYPE_COUNTS}_{count_interval}" f"{DEFAULT_COUNT_INTERVAL_TIME_UNIT}." f"{DEFAULT_COUNTS_FILE_TYPE}" ) diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py index 04d92c17b..4cef7f10e 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py @@ -61,6 +61,8 @@ from OTAnalytics.application.config import ( CUTTING_SECTION_MARKER, DEFAULT_COUNTING_INTERVAL_IN_MINUTES, + OTCONFIG_FILE_TYPE, + OTFLOW_FILE_TYPE, ) from OTAnalytics.application.logger import logger from OTAnalytics.application.parser.flow_parser import FlowParser @@ -174,14 +176,12 @@ LINE_SECTION: str = "line_section" TO_SECTION = "to_section" FROM_SECTION = "from_section" -OTFLOW = "otflow" MISSING_TRACK_FRAME_MESSAGE = "tracks frame" MISSING_VIDEO_FRAME_MESSAGE = "videos frame" MISSING_VIDEO_CONTROL_FRAME_MESSAGE = "video control frame" MISSING_SECTION_FRAME_MESSAGE = "sections frame" MISSING_FLOW_FRAME_MESSAGE = "flows frame" MISSING_ANALYSIS_FRAME_MESSAGE = "analysis frame" -OTCONFIG = "otconfig" class MissingInjectedInstanceError(Exception): @@ -515,16 +515,17 @@ def _show_current_project(self) -> None: self._frame_project.update(name=project.name, start_date=project.start_date) def save_otconfig(self) -> None: - title = "Save configuration as" - file_types = [(f"{OTCONFIG} file", f"*.{OTCONFIG}")] - defaultextension = f".{OTCONFIG}" - initialfile = f"config.{OTCONFIG}" - otconfig_file: Path = ask_for_save_file_path( - title, file_types, defaultextension, initialfile=initialfile + suggested_save_path = self._application.suggest_save_path(OTCONFIG_FILE_TYPE) + configuration_file = ask_for_save_file_path( + title="Save configuration as", + filetypes=[(f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}")], + defaultextension=f".{OTCONFIG_FILE_TYPE}", + initialfile=suggested_save_path.name, + initialdir=suggested_save_path.parent, ) - if not otconfig_file: + if not configuration_file: return - self._save_otconfig(otconfig_file) + self._save_otconfig(configuration_file) def _save_otconfig(self, otconfig_file: Path) -> None: logger().info(f"Config file to save: {otconfig_file}") @@ -574,10 +575,10 @@ def load_otconfig(self) -> None: askopenfilename( title="Load configuration file", filetypes=[ - (f"{OTFLOW} file", f"*.{OTFLOW}"), - (f"{OTCONFIG} file", f"*.{OTCONFIG}"), + (f"{OTFLOW_FILE_TYPE} file", f"*.{OTFLOW_FILE_TYPE}"), + (f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}"), ], - defaultextension=f".{OTFLOW}", + defaultextension=f".{OTFLOW_FILE_TYPE}", ) ) if not otconfig_file: @@ -596,7 +597,7 @@ def _load_otconfig(self, otconfig_file: Path) -> None: ) if proceed.canceled: return - logger().info(f"{OTCONFIG} file to load: {otconfig_file}") + logger().info(f"{OTCONFIG_FILE_TYPE} file to load: {otconfig_file}") self._application.load_otconfig(file=Path(otconfig_file)) self._show_current_project() self._show_current_svz_metadata() @@ -727,17 +728,17 @@ def load_configuration(self) -> None: # sourcery skip: avoid-builtin-shadow askopenfilename( title="Load sections file", filetypes=[ - (f"{OTFLOW} file", f"*.{OTFLOW}"), - (f"{OTCONFIG} file", f"*.{OTCONFIG}"), + (f"{OTFLOW_FILE_TYPE} file", f"*.{OTFLOW_FILE_TYPE}"), + (f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}"), ], - defaultextension=f".{OTFLOW}", + defaultextension=f".{OTFLOW_FILE_TYPE}", ) ) if not configuration_file.stem: return - elif configuration_file.suffix == f".{OTFLOW}": + elif configuration_file.suffix == f".{OTFLOW_FILE_TYPE}": self._load_otflow(configuration_file) - elif configuration_file.suffix == f".{OTCONFIG}": + elif configuration_file.suffix == f".{OTCONFIG_FILE_TYPE}": self._load_otconfig(configuration_file) else: raise ValueError("Configuration file to load has unknown file extension") @@ -764,25 +765,22 @@ def _load_otflow(self, otflow_file: Path) -> None: self.refresh_items_on_canvas() def save_configuration(self) -> None: - initial_dir = Path.cwd() - if config_file := self._application.file_state.last_saved_config.get(): - initial_dir = config_file.file.parent - + suggested_save_path = self._application.suggest_save_path(OTFLOW_FILE_TYPE) configuration_file = ask_for_save_file_path( title="Save configuration as", filetypes=[ - (f"{OTFLOW} file", f"*.{OTFLOW}"), - (f"{OTCONFIG} file", f"*.{OTCONFIG}"), + (f"{OTFLOW_FILE_TYPE} file", f"*.{OTFLOW_FILE_TYPE}"), + (f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}"), ], - defaultextension=f".{OTFLOW}", - initialfile=f"flows.{OTFLOW}", - initialdir=initial_dir, + defaultextension=f".{OTFLOW_FILE_TYPE}", + initialfile=suggested_save_path.name, + initialdir=suggested_save_path.parent, ) if not configuration_file.stem: return - elif configuration_file.suffix == f".{OTFLOW}": + elif configuration_file.suffix == f".{OTFLOW_FILE_TYPE}": self._save_otflow(configuration_file) - elif configuration_file.suffix == f".{OTCONFIG}": + elif configuration_file.suffix == f".{OTCONFIG_FILE_TYPE}": self._save_otconfig(configuration_file) else: raise ValueError("Configuration file to save has unknown file extension") @@ -1398,6 +1396,7 @@ def _configure_event_exporter( initial_position=(50, 50), input_values=default_values, export_format_extensions=export_format_extensions, + viewmodel=self, ).get_data() file = input_values[toplevel_export_events.EXPORT_FILE] export_format = input_values[toplevel_export_events.EXPORT_FORMAT] @@ -1760,6 +1759,7 @@ def export_road_user_assignments(self) -> None: input_values=default_values, export_format_extensions=export_formats, initial_file_stem="road_user_assignments", + viewmodel=self, ).get_data() logger().debug(export_values) save_path = export_values[toplevel_export_events.EXPORT_FILE] @@ -1833,3 +1833,6 @@ def _show_current_svz_metadata(self) -> None: self._frame_svz_metadata.update(metadata=metadata.to_dict()) else: self._frame_svz_metadata.update({}) + + def get_save_path_suggestion(self, file_type: str, context_file_type: str) -> Path: + return self._application.suggest_save_path(file_type, context_file_type) diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py index 0858beffa..18d43436f 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py @@ -5,6 +5,10 @@ from customtkinter import CTkEntry, CTkLabel, CTkOptionMenu from OTAnalytics.adapter_ui.view_model import ViewModel +from OTAnalytics.application.config import ( + CONTEXT_FILE_TYPE_COUNTS, + DEFAULT_COUNT_INTERVAL_TIME_UNIT, +) from OTAnalytics.plugin_ui.customtkinter_gui.constants import PADX, PADY, STICKY from OTAnalytics.plugin_ui.customtkinter_gui.frame_filter import DateRow from OTAnalytics.plugin_ui.customtkinter_gui.helpers import ask_for_save_file_name @@ -18,7 +22,6 @@ END = "end" EXPORT_FORMAT = "export_format" EXPORT_FILE = "export_file" -INITIAL_FILE_STEM = "counts" class CancelExportCounts(Exception): @@ -130,11 +133,17 @@ def _create_frame_content(self, master: Any) -> FrameContent: def _choose_file(self) -> None: export_format = self._input_values[EXPORT_FORMAT] # export_extension = self._export_formats[export_format] + suggested_save_path = self._viewmodel.get_save_path_suggestion( + export_extension[1:], + f"{CONTEXT_FILE_TYPE_COUNTS}" + f"_{self._input_values[INTERVAL]}{DEFAULT_COUNT_INTERVAL_TIME_UNIT}", + ) export_file = ask_for_save_file_name( title="Save counts as", filetypes=[(export_format, export_extension)], defaultextension=export_extension, - initialfile=INITIAL_FILE_STEM, + initialfile=suggested_save_path.name, + initialdir=suggested_save_path.parent, ) self._input_values[EXPORT_FILE] = export_file if export_file == "": diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py index c6f58295e..50ff7391f 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py @@ -3,6 +3,8 @@ from customtkinter import CTkLabel, CTkOptionMenu +from OTAnalytics.adapter_ui.view_model import ViewModel +from OTAnalytics.application.config import CONTEXT_FILE_TYPE_EVENTS from OTAnalytics.plugin_ui.customtkinter_gui.constants import PADX, PADY, STICKY from OTAnalytics.plugin_ui.customtkinter_gui.helpers import ask_for_save_file_name from OTAnalytics.plugin_ui.customtkinter_gui.toplevel_template import ( @@ -12,7 +14,6 @@ EXPORT_FORMAT = "export_format" EXPORT_FILE = "export_file" -INITIAL_FILE_STEM = "events" class CancelExportEvents(Exception): @@ -70,11 +71,13 @@ def _is_int_above_zero(self, entry_value: Any) -> bool: class ToplevelExportEvents(ToplevelTemplate): def __init__( self, + viewmodel: ViewModel, export_format_extensions: dict[str, str], input_values: dict, - initial_file_stem: str = INITIAL_FILE_STEM, + initial_file_stem: str = CONTEXT_FILE_TYPE_EVENTS, **kwargs: Any, ) -> None: + self._viewmodel = viewmodel self._input_values = input_values self._export_format_extensions = export_format_extensions self._initial_file_stem = initial_file_stem @@ -89,12 +92,17 @@ def _create_frame_content(self, master: Any) -> FrameContent: def _choose_file(self) -> None: export_format = self._input_values[EXPORT_FORMAT] # - export_extension = f"*{self._export_format_extensions[export_format]}" + export_file_type = self._export_format_extensions[export_format][1:] + export_extension = f"*.{export_file_type}" + suggested_save_path = self._viewmodel.get_save_path_suggestion( + export_file_type, context_file_type=self._initial_file_stem + ) export_file = ask_for_save_file_name( title="Save counts as", filetypes=[(export_format, export_extension)], defaultextension=export_extension, - initialfile=self._initial_file_stem, + initialfile=suggested_save_path.name, + initialdir=suggested_save_path.parent, ) self._input_values[EXPORT_FILE] = export_file if export_file == "": diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py index fbb18230f..690d9232c 100644 --- a/OTAnalytics/plugin_ui/main_application.py +++ b/OTAnalytics/plugin_ui/main_application.py @@ -127,6 +127,7 @@ RemoveSection, ) from OTAnalytics.application.use_cases.start_new_project import StartNewProject +from OTAnalytics.application.use_cases.suggest_save_path import SavePathSuggester from OTAnalytics.application.use_cases.track_repository import ( AddAllTracks, ClearAllTracks, @@ -474,13 +475,15 @@ def start_gui(self, run_config: RunConfiguration) -> None: load_track_files, parse_json, ) + get_all_videos = GetAllVideos(video_repository) + get_current_project = GetCurrentProject(datastore) config_has_changed = ConfigHasChanged( OtconfigHasChanged( config_parser, get_sections, get_flows, - GetCurrentProject(datastore), - GetAllVideos(video_repository), + get_current_project, + get_all_videos, get_all_track_files, ), OtflowHasChanged(flow_parser, get_sections, get_flows), @@ -493,6 +496,9 @@ def start_gui(self, run_config: RunConfiguration) -> None: flow_repository, create_events, ) + save_path_suggester = SavePathSuggester( + file_state, get_all_track_files, get_all_videos, get_current_project + ) application = OTAnalyticsApplication( datastore, track_state, @@ -526,6 +532,7 @@ def start_gui(self, run_config: RunConfiguration) -> None: load_otconfig, config_has_changed, export_road_user_assignments, + save_path_suggester, ) section_repository.register_sections_observer(cut_tracks_intersecting_section) section_repository.register_section_changed_observer( diff --git a/tests/OTAnalytics/application/use_cases/test_suggest_save_path.py b/tests/OTAnalytics/application/use_cases/test_suggest_save_path.py new file mode 100644 index 000000000..6864b822d --- /dev/null +++ b/tests/OTAnalytics/application/use_cases/test_suggest_save_path.py @@ -0,0 +1,166 @@ +from datetime import datetime +from pathlib import Path +from unittest.mock import Mock + +import pytest + +from OTAnalytics.application.state import ConfigurationFile, FileState +from OTAnalytics.application.use_cases.get_current_project import GetCurrentProject +from OTAnalytics.application.use_cases.suggest_save_path import ( + DATETIME_FORMAT, + SavePathSuggester, +) +from OTAnalytics.application.use_cases.track_repository import GetAllTrackFiles +from OTAnalytics.application.use_cases.video_repository import GetAllVideos + +FIRST_TRACK_FILE = Path("path/to/tracks/first.ottrk") +SECOND_TRACK_FILE = Path("path/to/tracks/second.ottrk") +FIRST_VIDEO_FILE = Path("path/to/videos/first.mp4") +SECOND_VIDEO_FILE = Path("path/to/videos/second.mp4") +PROJECT_NAME = "My Project Name" +DATETIME_NOW = datetime(2024, 1, 2, 3, 4, 5) +DATETIME_NOW_FORMATTED = DATETIME_NOW.strftime(DATETIME_FORMAT) +LAST_SAVED_OTCONFIG = Path("path/to/config/last.otconfig") +LAST_SAVED_OTFLOW = Path("path/to/otflow/last.otflow") + + +def create_file_state(last_saved_config_file: Path | None = None) -> FileState: + state = FileState() + if last_saved_config_file: + state.last_saved_config.set(ConfigurationFile(last_saved_config_file, {})) + return state + + +def create_track_file_provider( + track_files: set[Path] | None = None, +) -> GetAllTrackFiles: + if track_files: + return Mock(return_value=track_files) + return Mock(return_value=set()) + + +def create_video_provider(video_files: list[Path] | None = None) -> GetAllVideos: + videos = [] + if video_files: + for video_file in video_files: + video = Mock() + video.get_path.return_value = video_file + videos.append(video) + get_videos = Mock() + get_videos.get.return_value = videos + return get_videos + + +def create_project_provider(project_name: str = "") -> GetCurrentProject: + project = Mock() + project.name = project_name + get_project = Mock() + get_project.get.return_value = project + return get_project + + +def create_suggestor( + project_name: str = "", + last_saved_config: Path | None = None, + track_files: set[Path] | None = None, + video_files: list[Path] | None = None, +) -> SavePathSuggester: + get_project = create_project_provider(project_name) + file_state = create_file_state(last_saved_config) + get_track_files = create_track_file_provider(track_files) + get_videos = create_video_provider(video_files) + return SavePathSuggester( + file_state, + get_track_files, + get_videos, + get_project, + provide_datetime, + ) + + +def provide_datetime() -> datetime: + return DATETIME_NOW + + +class TestSavePathSuggester: + @pytest.mark.parametrize( + ( + "project_name,last_saved_config,track_files,video_files," + "context_file_type,file_type,expected" + ), + [ + ( + "", + None, + None, + None, + "", + "otconfig", + Path.cwd() / f"{DATETIME_NOW_FORMATTED}.otconfig", + ), + ( + PROJECT_NAME, + LAST_SAVED_OTCONFIG, + {FIRST_TRACK_FILE, SECOND_TRACK_FILE}, + [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE], + "", + "otconfig", + LAST_SAVED_OTCONFIG.with_name(f"{LAST_SAVED_OTCONFIG.stem}.otconfig"), + ), + ( + PROJECT_NAME, + LAST_SAVED_OTCONFIG, + {FIRST_TRACK_FILE, SECOND_TRACK_FILE}, + [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE], + "events", + "csv", + LAST_SAVED_OTCONFIG.with_name(f"{LAST_SAVED_OTCONFIG.stem}.events.csv"), + ), + ( + PROJECT_NAME, + None, + {FIRST_TRACK_FILE, SECOND_TRACK_FILE}, + [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE], + "events", + "csv", + FIRST_TRACK_FILE.with_name( + f"{PROJECT_NAME}_{DATETIME_NOW_FORMATTED}.events.csv" + ), + ), + ( + PROJECT_NAME, + None, + None, + [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE], + "events", + "csv", + FIRST_VIDEO_FILE.with_name( + f"{PROJECT_NAME}_{DATETIME_NOW_FORMATTED}.events.csv" + ), + ), + ( + PROJECT_NAME, + LAST_SAVED_OTCONFIG, + None, + [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE], + "events", + "csv", + LAST_SAVED_OTCONFIG.with_name(f"{LAST_SAVED_OTCONFIG.stem}.events.csv"), + ), + ], + ) + def test_suggest_default( + self, + project_name: str, + last_saved_config: Path | None, + track_files: set[Path] | None, + video_files: list[Path] | None, + context_file_type: str, + file_type: str, + expected: Path, + ) -> None: + suggestor = create_suggestor( + project_name, last_saved_config, track_files, video_files + ) + suggestion = suggestor.suggest(file_type, context_file_type) + assert suggestion == expected diff --git a/tests/OTAnalytics/plugin_ui/test_cli.py b/tests/OTAnalytics/plugin_ui/test_cli.py index d22e1de0b..d151a1358 100644 --- a/tests/OTAnalytics/plugin_ui/test_cli.py +++ b/tests/OTAnalytics/plugin_ui/test_cli.py @@ -17,8 +17,8 @@ CountingSpecificationDto, ) from OTAnalytics.application.config import ( + CONTEXT_FILE_TYPE_COUNTS, DEFAULT_COUNT_INTERVAL_TIME_UNIT, - DEFAULT_COUNTS_FILE_STEM, DEFAULT_COUNTS_FILE_TYPE, DEFAULT_EVENTLIST_FILE_TYPE, DEFAULT_NUM_PROCESSES, @@ -564,7 +564,7 @@ def test_use_video_start_and_end_for_counting( interval = 15 filename = "filename" expected_output_file = ( - test_data_tmp_dir / f"{filename}.{DEFAULT_COUNTS_FILE_STEM}_{interval}" + test_data_tmp_dir / f"{filename}.{CONTEXT_FILE_TYPE_COUNTS}_{interval}" f"{DEFAULT_COUNT_INTERVAL_TIME_UNIT}." f"{DEFAULT_COUNTS_FILE_TYPE}" )