diff --git a/src/av2/datasets/sensor/sensor_dataloader.py b/src/av2/datasets/sensor/sensor_dataloader.py index fe221149..9e152c01 100644 --- a/src/av2/datasets/sensor/sensor_dataloader.py +++ b/src/av2/datasets/sensor/sensor_dataloader.py @@ -1,4 +1,5 @@ # + """Dataloader for the Argoverse 2 (AV2) sensor dataset.""" from __future__ import annotations @@ -19,13 +20,14 @@ from av2.structures.cuboid import CuboidList from av2.structures.sweep import Sweep from av2.structures.timestamped_image import TimestampedImage +from av2.utils.constants import HOME from av2.utils.io import TimestampedCitySE3EgoPoses, read_city_SE3_ego, read_feather, read_img from av2.utils.metric_time import TimeUnit, to_metric_time logger = logging.Logger(__name__) -LIDAR_PATTERN: Final[str] = "*/sensors/lidar/*.feather" -CAMERA_PATTERN: Final[str] = "*/sensors/cameras/*/*.jpg" +LIDAR_PATTERN: Final[str] = "**/sensors/lidar/*.feather" +CAMERA_PATTERN: Final[str] = "**/sensors/cameras/*/*.jpg" Millisecond = TimeUnit.Millisecond Nanosecond = TimeUnit.Nanosecond @@ -56,13 +58,13 @@ class SynchronizedSensorData: Enables motion compensation between the sweep and associated images. Args: - sweep: lidar sweep. - timestamp_city_SE3_ego_dict: mapping from vehicle timestamp to the egovehicle's pose in the city frame. - log_id: unique identifier for the AV2 vehicle log. - sweep_number: index of the sweep in [0, N-1], of all N sweeps in the log. - num_sweeps_in_log: number of sweeps in the log. - annotations: cuboids that have been annotated within the sweep, or None. - synchronized_imagery: mapping from camera name to timestamped imagery, or None. + sweep: Lidar sweep. + timestamp_city_SE3_ego_dict: Mapping from vehicle timestamp to the egovehicle's pose in the city frame. + log_id: Unique identifier for the AV2 vehicle log. + sweep_number: Index of the sweep in [0, N-1], of all N sweeps in the log. + num_sweeps_in_log: Number of sweeps in the log. 
+ annotations: Cuboids that have been annotated within the sweep, or None. + synchronized_imagery: Mapping from camera name to timestamped imagery, or None. """ sweep: Sweep @@ -70,40 +72,38 @@ class SynchronizedSensorData: log_id: str sweep_number: int num_sweeps_in_log: int - annotations: Optional[CuboidList] = None synchronized_imagery: Optional[Dict[str, TimestampedImage]] = None @dataclass class SensorDataloader: - """ - Sensor dataloader for the Argoverse 2 sensor dataset. + """Sensor dataloader for the Argoverse 2 sensor dataset. NOTE: We build a cache of sensor records and synchronization information to reduce I/O overhead. Args: - sensor_dataset_dir: Path to the sensor dataset directory. + dataset_dir: Path to the sensor dataset directory. with_annotations: Flag to return annotations in the __getitem__ method. with_cams: Flag to load and return synchronized imagery in the __getitem__ method. with_cache: Flag to enable file directory caching. matching_criterion: either "nearest" or "forward". Returns: - AV2 Sensor dataset. + Argoverse 2 sensor dataloader. """ - sensor_dataset_dir: Path + dataset_dir: Path with_annotations: bool = True with_cache: bool = True - cam_names: Tuple[Union[RingCameras, StereoCameras], ...] = tuple(RingCameras) + tuple(StereoCameras) + cam_names: Tuple[Union[RingCameras, StereoCameras], ...] = tuple(RingCameras) matching_criterion = "nearest" - sensor_records: pd.DataFrame = field(init=False) + sensor_cache: pd.DataFrame = field(init=False) # Initialize synchronized metadata variable. # This is only populated when self.use_imagery is set. - sync_records: Optional[pd.DataFrame] = None + synchronization_cache: Optional[pd.DataFrame] = None def __post_init__(self) -> None: """Index the dataset for fast sensor data lookup. @@ -122,7 +122,7 @@ def __post_init__(self) -> None: lidar 315971436260099000 lidar 315971436359632000 lidar 315971436459828000 - ... ... + ... ... 
ff0dbfc5-8a7b-3a6e-8936-e5e812e45408 stereo_front_right 315972918949927214 stereo_front_right 315972918999927217 stereo_front_right 315972919049927212 @@ -143,30 +143,32 @@ def __post_init__(self) -> None: 315972918960050000 315972918949927220 ... 315972918949927214 315972919060249000 315972919049927214 ... 315972919049927212 """ - # Load log_id, sensor_type, and timestamp_ns information. - self.sensor_records = self._load_sensor_records() + # Load split, log_id, sensor_type, and timestamp_ns information. + self.sensor_cache = self._build_sensor_cache() # Populate synchronization database. if self.cam_names: - sync_records_path = self.sensor_dataset_dir / "._sync_records" + synchronization_cache_path = HOME / ".cache" / "av2" / "synchronization_cache.feather" + synchronization_cache_path.parent.mkdir(parents=True, exist_ok=True) # If caching is enabled AND the path exists, then load from the cache file. - if self.with_cache and sync_records_path.exists(): - self.sync_records = read_feather(sync_records_path) + if self.with_cache and synchronization_cache_path.exists(): + self.synchronization_cache = read_feather(synchronization_cache_path) else: - self.sync_records = self._build_sync_records() + self.synchronization_cache = self._build_synchronization_cache() # If caching is enabled and we haven't created the cache, then save to disk. - if self.with_cache and not sync_records_path.exists(): - self.sync_records.to_feather(str(sync_records_path)) + if self.with_cache and not synchronization_cache_path.exists(): + self.synchronization_cache.to_feather(str(synchronization_cache_path)) # Finally, create a MultiIndex set the sync records index and sort it. 
- self.sync_records = self.sync_records.set_index(keys=["log_id", "sensor_name", "timestamp_ns"]).sort_index() + self.synchronization_cache.set_index(keys=["split", "log_id", "sensor_name"], inplace=True) + self.synchronization_cache.sort_index(inplace=True) @cached_property def num_logs(self) -> int: """Return the number of unique logs.""" - return len(self.sensor_records.index.unique("log_id")) + return len(self.sensor_cache.index.unique("log_id")) @cached_property def num_sweeps(self) -> int: @@ -176,7 +178,7 @@ def num_sweeps(self) -> int: @cached_property def sensor_counts(self) -> pd.Series: """Return the number of records for each sensor.""" - sensor_counts: pd.Series = self.sensor_records.index.get_level_values("sensor_name").value_counts() + sensor_counts: pd.Series = self.sensor_cache.index.get_level_values("sensor_name").value_counts() return sensor_counts @property @@ -184,7 +186,7 @@ def num_sensors(self) -> int: """Return the number of sensors present throughout the dataset.""" return len(self.sensor_counts) - def _load_sensor_records(self) -> pd.DataFrame: + def _build_sensor_cache(self) -> pd.DataFrame: """Load the sensor records from the root directory. We glob the filesystem for all LiDAR and camera filepaths, and then convert each file path @@ -201,11 +203,12 @@ def _load_sensor_records(self) -> pd.DataFrame: logger.info("Building metadata ...") # Create the cache file path. - sensor_records_path = self.sensor_dataset_dir / "._sensor_records" + sensor_cache_path = HOME / ".cache" / "av2" / "sensor_cache.feather" + sensor_cache_path.parent.mkdir(parents=True, exist_ok=True) - if sensor_records_path.exists(): + if self.with_cache and sensor_cache_path.exists(): logger.info("Cache found. Loading from disk ...") - sensor_records = read_feather(sensor_records_path) + sensor_cache = read_feather(sensor_cache_path) else: lidar_records = self.populate_lidar_records() # Load camera records if enabled. 
@@ -213,20 +216,21 @@ def _load_sensor_records(self) -> pd.DataFrame: logger.info("Loading camera data ...") cam_records = self.populate_image_records() # Concatenate lidar and camera records. - sensor_records = pd.concat([lidar_records, cam_records]) + sensor_cache = pd.concat([lidar_records, cam_records]) else: - sensor_records = lidar_records + sensor_cache = lidar_records # Save the metadata if caching is enable. if self.with_cache: - sensor_records.reset_index(drop=True).to_feather(str(sensor_records_path)) + sensor_cache.reset_index(drop=True).to_feather(sensor_cache_path) - # Set index as tuples of the form: (log_id, sensor_name, timestamp_ns) and sort the index. - # sorts by log_id, and then by sensor name, and then by timestamp. - sensor_records = sensor_records.set_index(["log_id", "sensor_name", "timestamp_ns"]).sort_index() + # Set index as tuples of the form: (split, log_id, sensor_name, timestamp_ns) and sort the index. + # sorts by split, log_id, and then by sensor name, and then by timestamp. + sensor_cache.set_index(["split", "log_id", "sensor_name", "timestamp_ns"], inplace=True) + sensor_cache.sort_index(inplace=True) # Return all of the sensor records. - return sensor_records + return sensor_cache def populate_lidar_records(self) -> pd.DataFrame: """Obtain (log_id, sensor_name, timestamp_ns) 3-tuples for all LiDAR sweeps in the dataset. @@ -236,7 +240,7 @@ def populate_lidar_records(self) -> pd.DataFrame: N is the number of sweeps for all logs in the dataset, and the `sensor_name` column should be populated with `lidar` in every entry. """ - lidar_paths = sorted(self.sensor_dataset_dir.glob(LIDAR_PATTERN), key=lambda x: int(x.stem)) + lidar_paths = sorted(self.dataset_dir.glob(LIDAR_PATTERN), key=lambda x: int(x.stem)) lidar_record_list = [ convert_path_to_named_record(x) for x in track(lidar_paths, description="Loading lidar records ...") ] @@ -255,7 +259,7 @@ def populate_image_records(self) -> pd.DataFrame: every entry. 
""" # Get sorted list of camera paths. - cam_paths = sorted(self.sensor_dataset_dir.glob(CAMERA_PATTERN), key=lambda x: int(x.stem)) + cam_paths = sorted(self.dataset_dir.glob(CAMERA_PATTERN), key=lambda x: int(x.stem)) # Load entire set of camera records. cam_record_list = [ @@ -301,20 +305,19 @@ def __getitem__(self, idx: int) -> SynchronizedSensorData: """ # Grab the lidar record at the specified index. # Selects data at a particular level of a MultiIndex. - record: Tuple[str, int] = self.sensor_records.xs(key="lidar", level=1).iloc[idx].name + record: Tuple[str, str, int] = self.sensor_cache.xs(key="lidar", level=2).iloc[idx].name # Grab the identifying record fields. - log_id, timestamp_ns = record - log_lidar_records = self.sensor_records.xs((log_id, "lidar")).index + split, log_id, timestamp_ns = record + log_lidar_records = self.sensor_cache.xs((split, log_id, "lidar")).index num_frames = len(log_lidar_records) idx = np.where(log_lidar_records == timestamp_ns)[0].item() - sensor_dir = self.sensor_dataset_dir / log_id / "sensors" - lidar_feather_path = sensor_dir / "lidar" / f"{str(timestamp_ns)}.feather" + log_dir = self.dataset_dir / split / log_id + sensor_dir = log_dir / "sensors" + lidar_feather_path = sensor_dir / "lidar" / f"{timestamp_ns}.feather" sweep = Sweep.from_feather(lidar_feather_path=lidar_feather_path) - - log_dir = self.sensor_dataset_dir / log_id timestamp_city_SE3_ego_dict = read_city_SE3_ego(log_dir=log_dir) # Construct output datum. @@ -328,16 +331,16 @@ def __getitem__(self, idx: int) -> SynchronizedSensorData: # Load annotations if enabled. if self.with_annotations: - datum.annotations = self._load_annotations(log_id, timestamp_ns) + datum.annotations = self._load_annotations(split, log_id, timestamp_ns) # Load camera imagery if enabled. 
if self.cam_names: - datum.synchronized_imagery = self._load_synchronized_cams(sensor_dir, log_id, timestamp_ns) + datum.synchronized_imagery = self._load_synchronized_cams(split, sensor_dir, log_id, timestamp_ns) # Return datum at the specified index. return datum - def _build_sync_records(self) -> pd.DataFrame: + def _build_synchronization_cache(self) -> pd.DataFrame: """Build the synchronization records for lidar-camera synchronization. This function builds a set of records to efficiently associate auxiliary sensors @@ -348,129 +351,153 @@ def _build_sync_records(self) -> pd.DataFrame: NOTE: This function is NOT intended to be used outside of SensorDataset initialization. Returns: - (self.num_sweeps, self.num_sensors) DataFrame where each row corresponds to the nanosecond camera - timestamp that is closest (in absolute value) to the corresponding nanonsecond lidar sweep timestamp. + (self.num_sweeps, self.num_sensors) DataFrame where each row corresponds to the nanosecond camera timestamp + that is closest (in absolute value) to the corresponding nanonsecond lidar sweep timestamp. """ logger.info("Building synchronization database ...") - # Get unique log ids from the entire set of sensor data records. - log_ids: List[str] = self.sensor_records.index.unique(level="log_id").to_list() - # Create list to store synchronized data frames. sync_list: List[pd.DataFrame] = [] - - # Iterate over all log ids. - for log_id in track(log_ids, description="Building sync records ..."): - - # Select records associated with the current log id. - log_sensor_records = self.sensor_records.xs(key=log_id, level=0, drop_level=False) - - # Get unique sensor names for a particular log. - # If the entire dataset is available, each log should have 7 ring cameras - # and 2 stereo cameras. The uniqueness check is required in case a subset of - # the data is being used by the end-user. 
- sensor_names: List[str] = log_sensor_records.index.unique(level="sensor_name").tolist() - - # Remove lidar since we're using it as the reference sensor. - sensor_names.remove("lidar") - - # Get lidar records for the selected log. - target_records = log_sensor_records.xs(key="lidar", level=1, drop_level=False).reset_index() - for sensor_name in sensor_names: - # Obtain tuples, convert tuples back to DataFrame, then rename `timestamp_ns' col, to sensor name, - # and finally remove the `log_id` column, to leave only a single column of timestamps. - src_records: pd.DataFrame = ( - log_sensor_records.xs(sensor_name, level=1) - .reset_index() - .rename({"timestamp_ns": sensor_name}, axis=1) - .drop(["log_id"], axis=1) + unique_sensor_names: List[str] = self.sensor_cache.index.unique(level=2).tolist() + + # Associate a 'source' sensor to a 'target' sensor for all available sensors. + # For example, we associate the lidar sensor with each ring camera which + # produces a mapping from lidar -> all-other-sensors. + for src_sensor_name in unique_sensor_names: + src_records = self.sensor_cache.xs(src_sensor_name, level=2, drop_level=False).reset_index() + src_records = src_records.rename({"timestamp_ns": src_sensor_name}, axis=1).sort_values(src_sensor_name) + + # _Very_ important to convert to timedelta. Tolerance below causes precision loss otherwise. + src_records[src_sensor_name] = pd.to_timedelta(src_records[src_sensor_name]) + for target_sensor_name in unique_sensor_names: + if src_sensor_name == target_sensor_name: + continue + target_records = self.sensor_cache.xs(target_sensor_name, level=2).reset_index() + target_records = target_records.rename({"timestamp_ns": target_sensor_name}, axis=1).sort_values( + target_sensor_name ) - # Match on the closest nanosecond timestamp. - # we do not pad the values, as NaN entries are meaningful. - target_records = pd.merge_asof( - target_records, + # Merge based on matching criterion. + # _Very_ important to convert to timedelta. 
Tolerance below causes precision loss otherwise. + target_records[target_sensor_name] = pd.to_timedelta(target_records[target_sensor_name]) + tolerance = pd.to_timedelta(CAM_SHUTTER_INTERVAL_MS / 2 * 1e6) + if "ring" in src_sensor_name: + tolerance = pd.to_timedelta(LIDAR_SWEEP_INTERVAL_W_BUFFER_NS / 2) + src_records = pd.merge_asof( src_records, - left_on="timestamp_ns", - right_on=sensor_name, + target_records, + left_on=src_sensor_name, + right_on=target_sensor_name, + by=["split", "log_id"], direction=self.matching_criterion, - tolerance=int(LIDAR_SWEEP_INTERVAL_W_BUFFER_NS), + tolerance=tolerance, ) - - sync_list.append(target_records) - return pd.concat(sync_list).reset_index(drop=True) - - def get_closest_img_fpath(self, log_id: str, cam_name: str, lidar_timestamp_ns: int) -> Optional[Path]: - """Find the filepath to the image from a particular a camera, w/ closest timestamp to a lidar sweep timestamp. + sync_list.append(src_records) + sync_records = pd.concat(sync_list).reset_index(drop=True) + return sync_records + + def find_closest_target_fpath( + self, + split: str, + log_id: str, + src_sensor_name: str, + src_timestamp_ns: int, + target_sensor_name: str, + ) -> Optional[Path]: + """Find the file path to the target sensor from a source sensor. Args: - log_id: unique ID of vehicle log. - cam_name: name of camera. - lidar_timestamp_ns: integer timestamp of LiDAR sweep capture, in nanoseconds + split: Dataset split. + log_id: Vehicle log uuid. + src_sensor_name: Name of the source sensor. + src_timestamp_ns: Nanosecond timestamp of the source sensor (vehicle time). + target_sensor_name: Name of the target sensor. Returns: - img_fpath, string representing path to image, or else None. + The target sensor file path if it exists, otherwise None. Raises: RuntimeError: if the synchronization database (sync_records) has not been created. 
""" - if self.sync_records is None: + if self.synchronization_cache is None: raise RuntimeError("Requested synchronized data, but the synchronization database has not been created.") - if lidar_timestamp_ns not in self.sync_records.loc[(log_id, "lidar")].index: - # this timestamp does not correspond to any LiDAR sweep. + src_timedelta_ns = pd.Timedelta(src_timestamp_ns) + src_to_target_records = self.synchronization_cache.loc[(split, log_id, src_sensor_name)].set_index( + src_sensor_name + ) + index = src_to_target_records.index + if src_timedelta_ns not in index: + # This timestamp does not correspond to any lidar sweep. return None - # Create synchronization key. - key = (log_id, "lidar", lidar_timestamp_ns) - # Grab the synchronization record. - timestamp_ns = self.sync_records.loc[key, cam_name] - - if pd.isna(timestamp_ns): - # no match was found within tolerance. + target_timestamp_ns = src_to_target_records.loc[src_timedelta_ns, target_sensor_name] + if pd.isna(target_timestamp_ns): + # No match was found within tolerance. return None - sensor_dir = self.sensor_dataset_dir / log_id - img_path = sensor_dir / "cameras" / str(cam_name) / f"{int(timestamp_ns)}.jpg" - return img_path + sensor_dir = self.dataset_dir / split / log_id / "sensors" + valid_cameras = [x.value for x in list(RingCameras)] + [x.value for x in list(StereoCameras)] + timestamp_ns_str = str(target_timestamp_ns.asm8.item()) + if target_sensor_name in valid_cameras: + target_path = sensor_dir / "cameras" / target_sensor_name / f"{timestamp_ns_str}.jpg" + else: + target_path = sensor_dir / target_sensor_name / f"{timestamp_ns_str}.feather" + return target_path - def get_closest_lidar_fpath(self, log_id: str, cam_name: str, cam_timestamp_ns: int) -> Optional[Path]: - """Get file path for lidar sweep accumulated to a timestamp closest to a camera timestamp. 
+ def get_closest_img_fpath(self, split: str, log_id: str, cam_name: str, lidar_timestamp_ns: int) -> Optional[Path]: + """Find the file path to the closest image from the reference camera name to the lidar timestamp. Args: - log_id: unique ID of vehicle log. - cam_name: name of camera. - cam_timestamp_ns: integer timestamp of image capture, in nanoseconds + split: Dataset split. + log_id: Vehicle log uuid. + cam_name: Name of the camera. + lidar_timestamp_ns: Nanosecond timestamp of the lidar sweep (vehicle time). Returns: - lidar_fpath: path representing path to .feather file, or else None. - - Raises: - RuntimeError: if the synchronization database (sync_records) has not been created. + File path to observation from the camera (if it exists), otherwise None. """ - if self.sync_records is None: - raise RuntimeError("Requested synchronized data, but the synchronization database has not been created.") + return self.find_closest_target_fpath( + split=split, + log_id=log_id, + src_sensor_name="lidar", + src_timestamp_ns=lidar_timestamp_ns, + target_sensor_name=cam_name, + ) - idx = np.argwhere(self.sync_records.xs(log_id)[cam_name].values == cam_timestamp_ns) - if len(idx) == 0: - # There is no image within the requested interval (50 ms). - return None + def get_closest_lidar_fpath(self, split: str, log_id: str, cam_name: str, cam_timestamp_ns: int) -> Optional[Path]: + """Find the file path to the closest lidar sweep from the reference camera timestamp. - lidar_timestamp_ns = self.sync_records.xs(log_id).loc["lidar"].index[int(idx)] - return self.sensor_dataset_dir / log_id / "sensors" / "lidar" / f"{lidar_timestamp_ns}.feather" + Args: + split: Dataset split. + log_id: Vehicle log uuid. + cam_name: Name of the camera. + cam_timestamp_ns: Nanosecond timestamp of the camera image (vehicle time). - def _load_annotations(self, log_id: str, sweep_timestamp_ns: int) -> CuboidList: + Returns: + File path to observation from the lidar (if it exists), otherwise None. 
+ """ + return self.find_closest_target_fpath( + split=split, + log_id=log_id, + src_sensor_name=cam_name, + src_timestamp_ns=cam_timestamp_ns, + target_sensor_name="lidar", + ) + + def _load_annotations(self, split: str, log_id: str, sweep_timestamp_ns: int) -> CuboidList: """Load the sweep annotations at the provided timestamp. Args: + split: Split name. log_id: Log unique id. sweep_timestamp_ns: Nanosecond timestamp. Returns: Cuboid list of annotations. """ - annotations_feather_path = self.sensor_dataset_dir / log_id / "annotations.feather" + annotations_feather_path = self.dataset_dir / split / log_id / "annotations.feather" # Load annotations from disk. # NOTE: This contains annotations for the ENTIRE sequence. @@ -480,11 +507,12 @@ def _load_annotations(self, log_id: str, sweep_timestamp_ns: int) -> CuboidList: return CuboidList(cuboids=cuboids) def _load_synchronized_cams( - self, sensor_dir: Path, log_id: str, sweep_timestamp_ns: int + self, split: str, sensor_dir: Path, log_id: str, sweep_timestamp_ns: int ) -> Optional[Dict[str, TimestampedImage]]: """Load the synchronized imagery for a lidar sweep. Args: + split: Dataset split. sensor_dir: Sensor directory. log_id: Log unique id. sweep_timestamp_ns: Nanosecond timestamp. @@ -495,16 +523,21 @@ def _load_synchronized_cams( Raises: RuntimeError: if the synchronization database (sync_records) has not been created. 
""" - if self.sync_records is None: + if self.synchronization_cache is None: raise RuntimeError("Requested synchronized data, but the synchronization database has not been created.") cam_paths = [ - self.get_closest_img_fpath(log_id=log_id, cam_name=cam_name, lidar_timestamp_ns=sweep_timestamp_ns) + self.find_closest_target_fpath( + split=split, + log_id=log_id, + src_sensor_name="lidar", + target_sensor_name=cam_name.value, + src_timestamp_ns=sweep_timestamp_ns, + ) for cam_name in self.cam_names ] log_dir = sensor_dir.parent - cams: Dict[str, TimestampedImage] = {} for p in cam_paths: if p is not None: diff --git a/src/av2/datasets/sensor/utils.py b/src/av2/datasets/sensor/utils.py index 8316abbf..2996c473 100644 --- a/src/av2/datasets/sensor/utils.py +++ b/src/av2/datasets/sensor/utils.py @@ -18,12 +18,16 @@ def convert_path_to_named_record(path: Path) -> Dict[str, Union[str, int]]: Returns: Mapping of name to record field. """ - sensor_name = path.parent.stem + sensor_path = path.parent + sensor_name = sensor_path.stem + log_path = sensor_path.parent.parent if sensor_name == "lidar" else sensor_path.parent.parent.parent # log_id is 2 directories up for the lidar filepaths, but 3 levels up for images # {log_id}/sensors/cameras/ring_*/*.jpg vs. 
# {log_id}/sensors/lidar/*.feather - parent_idx = 2 if sensor_name == "lidar" else 3 - log_id = path.parents[parent_idx].stem - sensor_name, timestamp_ns = path.parent.stem, int(path.stem) - return {"log_id": log_id, "sensor_name": sensor_name, "timestamp_ns": timestamp_ns} + return { + "split": log_path.parent.stem, + "log_id": log_path.stem, + "sensor_name": sensor_name, + "timestamp_ns": int(path.stem), + } diff --git a/src/av2/geometry/camera/pinhole_camera.py b/src/av2/geometry/camera/pinhole_camera.py index d2380de8..4e52d928 100644 --- a/src/av2/geometry/camera/pinhole_camera.py +++ b/src/av2/geometry/camera/pinhole_camera.py @@ -181,9 +181,10 @@ def project_cam_to_img( is_valid_points: boolean indicator of valid cheirality and within image boundary, as boolean Numpy array of shape (N,). """ - uv = self.intrinsics.K @ points_cam[:3, :] - uv = uv.T - points_cam = points_cam.T + points_cam = points_cam.transpose() + uv: NDArrayFloat = self.intrinsics.K @ points_cam + uv = uv.transpose() + points_cam = points_cam.transpose() if remove_nan: uv, points_cam = remove_nan_values(uv, points_cam) @@ -241,7 +242,7 @@ def project_ego_to_img_motion_compensated( boolean Numpy array of shape (N,). Raises: - ValueError: If `city_SE3_ego_cam_t` or `city_SE3_ego_lidar_t` is `None`. + ValueError: If `city_SE3_egovehicle_cam_t` or `city_SE3_egovehicle_lidar_t` is `None`. """ if city_SE3_ego_cam_t is None: raise ValueError("city_SE3_ego_cam_t cannot be `None`!") @@ -406,6 +407,25 @@ def compute_pixel_ray_directions(self, uv: Union[NDArrayFloat, NDArrayInt]) -> N raise RuntimeError("Ray directions must be (N,3)") return ray_dirs + def scale(self, scale: float) -> PinholeCamera: + """Scale the intrinsics and image size. + + Args: + scale: Scaling factor. + + Returns: + The scaled pinhole camera model. 
+ """ + intrinsics = Intrinsics( + self.intrinsics.fx_px * scale, + self.intrinsics.fy_px * scale, + self.intrinsics.cx_px * scale, + self.intrinsics.cy_px * scale, + round(self.intrinsics.width_px * scale), + round(self.intrinsics.height_px * scale), + ) + return PinholeCamera(ego_SE3_cam=self.ego_SE3_cam, intrinsics=intrinsics, cam_name=self.cam_name) + def remove_nan_values(uv: NDArrayFloat, points_cam: NDArrayFloat) -> Tuple[NDArrayFloat, NDArrayFloat]: """Remove NaN values from camera coordinates and image plane coordinates (accepts corrupt array). diff --git a/src/av2/geometry/utm.py b/src/av2/geometry/utm.py index 16f54404..41562e12 100644 --- a/src/av2/geometry/utm.py +++ b/src/av2/geometry/utm.py @@ -83,7 +83,6 @@ def convert_city_coords_to_utm(points_city: Union[NDArrayFloat, NDArrayInt], cit latitude, longitude = CITY_ORIGIN_LATLONG_DICT[city_name] # get (easting, northing) of origin origin_utm = convert_gps_to_utm(latitude=latitude, longitude=longitude, city_name=city_name) - points_utm: NDArrayFloat = points_city.astype(float) + np.array(origin_utm, dtype=float) return points_utm diff --git a/src/av2/rendering/color.py b/src/av2/rendering/color.py index 665ef7c0..f5b62ba7 100644 --- a/src/av2/rendering/color.py +++ b/src/av2/rendering/color.py @@ -2,12 +2,14 @@ """Colormap related constants and functions.""" +from enum import Enum, unique from typing import Final, Sequence, Tuple +import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import LinearSegmentedColormap -from av2.utils.typing import NDArrayFloat +from av2.utils.typing import NDArrayByte, NDArrayFloat RED_HEX: Final[str] = "#df0101" GREEN_HEX: Final[str] = "#31b404" @@ -31,6 +33,31 @@ TRAFFIC_YELLOW1_BGR: Final[Tuple[int, int, int]] = TRAFFIC_YELLOW1_RGB[::-1] +@unique +class ColorFormats(str, Enum): + """Color channel formats.""" + + BGR = "BGR" + RGB = "RGB" + + +def create_range_map(points_xyz: NDArrayFloat) -> NDArrayByte: + """Generate an RGB colormap as a function of 
the lidar range. + + Args: + points_xyz: (N,3) Points (x,y,z). + + Returns: + (N,3) RGB colormap. + """ + range = points_xyz[..., 2] + range = np.round(range).astype(int) # type: ignore + color = plt.get_cmap("turbo")(np.arange(0, range.max() + 1)) + color = color[range] + range_cmap: NDArrayByte = (color * 255.0).astype(np.uint8) + return range_cmap + + def create_colormap(color_list: Sequence[str], n_colors: int) -> NDArrayFloat: """Create hex colorscale to interpolate between requested colors. diff --git a/src/av2/rendering/video.py b/src/av2/rendering/video.py index 48e964aa..eab48e97 100644 --- a/src/av2/rendering/video.py +++ b/src/av2/rendering/video.py @@ -2,20 +2,46 @@ """Rendering tools for video visualizations.""" +from __future__ import annotations + +from enum import Enum, unique from pathlib import Path -from typing import Dict, Final, Union +from typing import Dict, Final, Mapping, Optional, Set, Union import av -import cv2 import numpy as np import pandas as pd +from av2.rendering.color import ColorFormats from av2.utils.typing import NDArrayByte +COLOR_FORMAT_TO_PYAV_COLOR_FORMAT: Final[Dict[ColorFormats, str]] = { + ColorFormats.RGB: "rgb24", + ColorFormats.BGR: "bgr24", +} FFMPEG_OPTIONS: Final[Dict[str, str]] = {"crf": "27"} -def tile_cameras(named_sensors: Dict[str, Union[NDArrayByte, pd.DataFrame]]) -> NDArrayByte: +@unique +class VideoCodecs(str, Enum): + """Available video codecs for encoding mp4 videos. + + NOTE: The codecs available are dependent on the FFmpeg build that + you are using. We recommend defaulting to LIBX264. + """ + + LIBX264 = "libx264" # https://en.wikipedia.org/wiki/Advanced_Video_Coding + LIBX265 = "libx265" # https://en.wikipedia.org/wiki/High_Efficiency_Video_Coding + HEVC_VIDEOTOOLBOX = "hevc_videotoolbox" # macOS GPU acceleration. 
+ + +HIGH_EFFICIENCY_VIDEO_CODECS: Final[Set[VideoCodecs]] = set([VideoCodecs.LIBX265, VideoCodecs.HEVC_VIDEOTOOLBOX]) + + +def tile_cameras( + named_sensors: Mapping[str, Union[NDArrayByte, pd.DataFrame]], + bev_img: Optional[NDArrayByte] = None, +) -> NDArrayByte: """Combine ring cameras into a tiled image. NOTE: Images are expected in BGR ordering. @@ -32,43 +58,62 @@ def tile_cameras(named_sensors: Dict[str, Union[NDArrayByte, pd.DataFrame]]) -> Args: named_sensors: Dictionary of camera names to the (width, height, 3) images. + bev_img: (H,W,3) Bird's-eye view image. Returns: Tiled image. """ - landscape_width = 2048 - landscape_height = 1550 + landscape_height = 2048 + landscape_width = 1550 + for _, v in named_sensors.items(): + landscape_width = max(v.shape[0], v.shape[1]) + landscape_height = min(v.shape[0], v.shape[1]) + break height = landscape_height + landscape_height + landscape_height width = landscape_width + landscape_height + landscape_width tiled_im_bgr: NDArrayByte = np.zeros((height, width, 3), dtype=np.uint8) - ring_rear_left = named_sensors["ring_rear_left"] - ring_side_left = named_sensors["ring_side_left"] - ring_front_center = named_sensors["ring_front_center"] - ring_front_left = named_sensors["ring_front_left"] - ring_front_right = named_sensors["ring_front_right"] - ring_side_right = named_sensors["ring_side_right"] - ring_rear_right = named_sensors["ring_rear_right"] + if "ring_front_left" in named_sensors: + ring_front_left = named_sensors["ring_front_left"] + tiled_im_bgr[:landscape_height, :landscape_width] = ring_front_left + + if "ring_front_center" in named_sensors: + ring_front_center = named_sensors["ring_front_center"] + tiled_im_bgr[:landscape_width, landscape_width : landscape_width + landscape_height] = ring_front_center + + if "ring_front_right" in named_sensors: + ring_front_right = named_sensors["ring_front_right"] + tiled_im_bgr[:landscape_height, landscape_width + landscape_height :] = ring_front_right - 
tiled_im_bgr[:landscape_height, :landscape_width] = ring_front_left - tiled_im_bgr[:landscape_width, landscape_width : landscape_width + landscape_height] = ring_front_center - tiled_im_bgr[:landscape_height, landscape_width + landscape_height :] = ring_front_right + if "ring_side_left" in named_sensors: + ring_side_left = named_sensors["ring_side_left"] + tiled_im_bgr[landscape_height : 2 * landscape_height, :landscape_width] = ring_side_left - tiled_im_bgr[landscape_height:3100, :landscape_width] = ring_side_left - tiled_im_bgr[landscape_height:3100, landscape_width + landscape_height :] = ring_side_right + if "ring_side_right" in named_sensors: + ring_side_right = named_sensors["ring_side_right"] + tiled_im_bgr[landscape_height : 2 * landscape_height, landscape_width + landscape_height :] = ring_side_right - start = (width - 4096) // 2 - tiled_im_bgr[3100:4650, start : start + landscape_width] = np.fliplr(ring_rear_left) # type: ignore - tiled_im_bgr[3100:4650, start + landscape_width : start + 4096] = np.fliplr(ring_rear_right) # type: ignore - tiled_im_rgb: NDArrayByte = cv2.cvtColor(tiled_im_bgr, cv2.COLOR_BGR2RGB) - return tiled_im_rgb + if bev_img is not None: + tiled_im_bgr[ + landscape_width : 2 * landscape_width, landscape_width : landscape_width + landscape_height + ] = bev_img + + if "ring_rear_left" in named_sensors: + ring_rear_left = named_sensors["ring_rear_left"] + tiled_im_bgr[2 * landscape_height : 3 * landscape_height, :landscape_width] = ring_rear_left + + if "ring_rear_right" in named_sensors: + ring_rear_right = named_sensors["ring_rear_right"] + tiled_im_bgr[2 * landscape_height : 3 * landscape_height, width - landscape_width :] = ring_rear_right + return tiled_im_bgr def write_video( video: NDArrayByte, dst: Path, - codec: str = "libx264", + color_format: ColorFormats = ColorFormats.RGB, + codec: VideoCodecs = VideoCodecs.LIBX264, fps: int = 10, crf: int = 27, preset: str = "veryfast", @@ -78,14 +123,15 @@ def write_video( Reference: 
https://github.com/PyAV-Org/PyAV Args: - video: (N,H,W,3) array representing N RGB frames of identical dimensions. - dst: path to save folder. - codec: the name of a codec. - fps: the frame rate for video. - crf: constant rate factor (CRF) parameter of video, controlling the quality. + video: (N,H,W,3) Array representing N RGB frames of identical dimensions. + dst: Path to the output video file. + color_format: Format of the color channels. + codec: Name of the codec. + fps: Frame rate for video. + crf: Constant rate factor (CRF) parameter of video, controlling the quality. Lower values would result in better quality, at the expense of higher file sizes. For x264, the valid Constant Rate Factor (crf) range is 0-51. - preset: file encoding speed. Options range from "ultrafast", ..., "fast", ..., "medium", ..., "slow", ... + preset: File encoding speed. Options range from "ultrafast", ..., "fast", ..., "medium", ..., "slow", ... Higher compression efficiency often translates to slower video encoding speed, at file write time. 
""" _, H, W, _ = video.shape @@ -98,6 +144,8 @@ def write_video( dst.parent.mkdir(parents=True, exist_ok=True) with av.open(str(dst), "w") as output: stream = output.add_stream(codec, fps) + if codec in HIGH_EFFICIENCY_VIDEO_CODECS: + stream.codec_tag = "hvc1" stream.width = W stream.height = H stream.options = { @@ -106,10 +154,11 @@ def write_video( "movflags": "+faststart", "preset": preset, "profile:v": "main", - "tag": "hvc1", } + + format = COLOR_FORMAT_TO_PYAV_COLOR_FORMAT[color_format] for _, img in enumerate(video): - frame = av.VideoFrame.from_ndarray(img) + frame = av.VideoFrame.from_ndarray(img, format=format) output.mux(stream.encode(frame)) output.mux(stream.encode(None)) diff --git a/tests/datasets/sensor/test_sensor_dataloader.py b/tests/datasets/sensor/test_sensor_dataloader.py index 98a31e91..da83c508 100644 --- a/tests/datasets/sensor/test_sensor_dataloader.py +++ b/tests/datasets/sensor/test_sensor_dataloader.py @@ -2,13 +2,45 @@ """Unit tests on sensor data synchronization utilities.""" -import os +import tempfile from pathlib import Path +from typing import Dict, Final, List +from av2.datasets.sensor.av2_sensor_dataloader import AV2SensorDataLoader +from av2.datasets.sensor.constants import RingCameras from av2.datasets.sensor.sensor_dataloader import SensorDataloader - -def test_sensor_data_loader_milliseconds(tmpdir: "os.PathLike[str]") -> None: +SENSOR_TIMESTAMPS_MS_DICT: Final[Dict[str, List[int]]] = { + "ring_rear_left": [0, 50, 100, 150, 200, 250, 300, 350, 400, 450], + "ring_side_left": [15, 65, 115, 165, 215, 265, 315, 365, 415, 465], + "ring_front_left": [30, 80, 130, 180, 230, 280, 330, 380, 430, 480], + "ring_front_center": [42, 92, 142, 192, 242, 292, 342, 392, 442, 492], + "ring_front_right": [5, 55, 105, 155, 205, 255, 305, 355, 405, 455], + "ring_side_right": [20, 70, 120, 170, 220, 270, 320, 370, 420, 470], + "ring_rear_right": [35, 85, 135, 185, 235, 285, 335, 385, 435, 485], + "lidar": [2, 102, 202, 303, 402, 502, 603, 702, 
802, 903], +} + + +def _create_dummy_sensor_dataloader(log_id: str) -> SensorDataloader: + """Create a dummy sensor dataloader.""" + with Path(tempfile.TemporaryDirectory().name) as sensor_dataset_dir: + for sensor_name, timestamps_ms in SENSOR_TIMESTAMPS_MS_DICT.items(): + for t in timestamps_ms: + if "ring" in sensor_name: + fpath = Path( + sensor_dataset_dir, "dummy", log_id, "sensors", "cameras", sensor_name, f"{int(t*1e6)}.jpg" + ) + Path(fpath).parent.mkdir(exist_ok=True, parents=True) + fpath.open("w").close() + elif "lidar" in sensor_name: + fpath = Path(sensor_dataset_dir, "dummy", log_id, "sensors", sensor_name, f"{int(t*1e6)}.feather") + Path(fpath).parent.mkdir(exist_ok=True, parents=True) + fpath.open("w").close() + return SensorDataloader(dataset_dir=sensor_dataset_dir, with_cache=False) + + +def test_sensor_data_loader_milliseconds() -> None: """Test that the sensor dataset dataloader can synchronize lidar and image data. Given toy data in milliseconds, we write out dummy files at corresponding timestamps. @@ -26,69 +58,95 @@ def test_sensor_data_loader_milliseconds(tmpdir: "os.PathLike[str]") -> None: 7 lidar 702000000 NaN 8 lidar 802000000 NaN 9 lidar 903000000 NaN - - Args: - tmpdir: Temp directory used in the test (provided via built-in fixture). """ - tmpdir = Path(tmpdir) - log_id = "00a6ffc1-6ce9-3bc3-a060-6006e9893a1a" - # 7x10 images, and 10 sweeps. Timestamps below given in human-readable milliseconds. 
- sensor_timestamps_ms_dict = { - "ring_rear_left": [0, 50, 100, 150, 200, 250, 300, 350, 400, 450], - "ring_side_left": [15, 65, 115, 165, 215, 265, 315, 365, 415, 465], - "ring_front_left": [30, 80, 130, 180, 230, 280, 330, 380, 430, 480], - "ring_front_center": [42, 92, 142, 192, 242, 292, 342, 392, 442, 492], - "ring_front_right": [5, 55, 105, 155, 205, 255, 305, 355, 405, 455], - "ring_side_right": [20, 70, 120, 170, 220, 270, 320, 370, 420, 470], - "ring_rear_right": [35, 85, 135, 185, 235, 285, 335, 385, 435, 485], - "lidar": [2, 102, 202, 303, 402, 502, 603, 702, 802, 903], - } - - for sensor_name, timestamps_ms in sensor_timestamps_ms_dict.items(): - for t in timestamps_ms: - if "ring" in sensor_name: - fpath = tmpdir / log_id / "sensors" / "cameras" / sensor_name / f"{int(t*1e6)}.jpg" - elif "lidar" in sensor_name: - fpath = tmpdir / log_id / "sensors" / sensor_name / f"{int(t*1e6)}.feather" - fpath.parent.mkdir(exist_ok=True, parents=True) - # create an empty file - f = open(fpath, "w") - f.close() - - loader = SensorDataloader(sensor_dataset_dir=tmpdir, with_cache=False) + log_id = "00a6ffc1-6ce9-3bc3-a060-6006e9893a1a" + loader = _create_dummy_sensor_dataloader(log_id=log_id) # LiDAR 402 -> matches to ring front center 392. 
- img_fpath = loader.get_closest_img_fpath( - log_id=log_id, cam_name="ring_front_center", lidar_timestamp_ns=int(402 * 1e6) + img_fpath = loader.find_closest_target_fpath( + split="dummy", + log_id=log_id, + src_sensor_name="lidar", + src_timestamp_ns=int(402 * 1e6), + target_sensor_name="ring_front_center", ) + assert isinstance(img_fpath, Path) # result should be 392 milliseconds (and then a conversion to nanoseconds by adding 6 zeros) - print(img_fpath) assert img_fpath.name == "392" + "000000" + ".jpg" # nothing should be within bounds for this (valid lidar timestamp 903) - img_fpath = loader.get_closest_img_fpath( - log_id=log_id, cam_name="ring_front_center", lidar_timestamp_ns=int(903 * 1e6) + img_fpath = loader.find_closest_target_fpath( + split="dummy", + log_id=log_id, + src_sensor_name="lidar", + target_sensor_name="ring_front_center", + src_timestamp_ns=int(903 * 1e6), ) assert img_fpath is None # nothing should be within bounds for this (invalid lidar timestamp 904) - img_fpath = loader.get_closest_img_fpath( - log_id=log_id, cam_name="ring_front_center", lidar_timestamp_ns=int(904 * 1e6) + img_fpath = loader.find_closest_target_fpath( + split="dummy", + log_id=log_id, + src_sensor_name="lidar", + target_sensor_name="ring_front_center", + src_timestamp_ns=int(904 * 1e6), ) assert img_fpath is None # ring front center 392 -> matches to LiDAR 402. 
- lidar_fpath = loader.get_closest_lidar_fpath( - log_id=log_id, cam_name="ring_front_center", cam_timestamp_ns=int(392 * 1e6) + lidar_fpath = loader.find_closest_target_fpath( + split="dummy", + log_id=log_id, + src_sensor_name="ring_front_center", + target_sensor_name="lidar", + src_timestamp_ns=int(392 * 1e6), ) + assert isinstance(lidar_fpath, Path) # result should be 402 milliseconds (and then a conversion to nanoseconds by adding 6 zeros) assert lidar_fpath.name == "402" + "000000.feather" # way outside of bounds - lidar_fpath = loader.get_closest_lidar_fpath( - log_id=log_id, cam_name="ring_front_center", cam_timestamp_ns=int(7000 * 1e6) + lidar_fpath = loader.find_closest_target_fpath( + split="dummy", + log_id=log_id, + src_sensor_name="ring_front_center", + target_sensor_name="lidar", + src_timestamp_ns=int(7000 * 1e6), ) assert lidar_fpath is None + + # use the non-pandas implementation as a "brute-force" (BF) check. + # read out the dataset root from the other dataloader's attributes. + bf_loader = AV2SensorDataLoader(data_dir=loader.dataset_dir / "dummy", labels_dir=loader.dataset_dir / "dummy") + + # for every image, make sure query result matches the brute-force query result. + for ring_camera_enum in RingCameras: + ring_camera_name = ring_camera_enum.value + for cam_timestamp_ms in SENSOR_TIMESTAMPS_MS_DICT[ring_camera_name]: + cam_timestamp_ns = int(cam_timestamp_ms * 1e6) + result = loader.get_closest_lidar_fpath( + split="dummy", log_id=log_id, cam_name=ring_camera_name, cam_timestamp_ns=cam_timestamp_ns + ) + bf_result = bf_loader.get_closest_lidar_fpath(log_id=log_id, cam_timestamp_ns=cam_timestamp_ns) + assert result == bf_result + + # for every lidar sweep, make sure query result matches the brute-force query result. 
+ for lidar_timestamp_ms in SENSOR_TIMESTAMPS_MS_DICT["lidar"]: + lidar_timestamp_ns = int(lidar_timestamp_ms * 1e6) + for ring_camera_enum in list(RingCameras): + ring_camera_name = ring_camera_enum.value + result = loader.get_closest_img_fpath( + split="dummy", log_id=log_id, cam_name=ring_camera_name, lidar_timestamp_ns=lidar_timestamp_ns + ) + bf_result = bf_loader.get_closest_img_fpath( + log_id=log_id, cam_name=ring_camera_name, lidar_timestamp_ns=lidar_timestamp_ns + ) + assert result == bf_result + + +if __name__ == "__main__": + test_sensor_data_loader_milliseconds() diff --git a/tutorials/generate_sensor_dataset_visualizations.py b/tutorials/generate_sensor_dataset_visualizations.py new file mode 100644 index 00000000..caa9a515 --- /dev/null +++ b/tutorials/generate_sensor_dataset_visualizations.py @@ -0,0 +1,152 @@ +# + +"""Example script for loading data from the AV2 sensor dataset.""" + +from pathlib import Path +from typing import Final, List, Tuple, Union + +import click +import numpy as np +from rich.progress import track + +from av2.datasets.sensor.constants import RingCameras, StereoCameras +from av2.datasets.sensor.sensor_dataloader import SensorDataloader +from av2.rendering.color import ColorFormats, create_range_map +from av2.rendering.rasterize import draw_points_xy_in_img +from av2.rendering.video import tile_cameras, write_video +from av2.structures.ndgrid import BEVGrid +from av2.utils.typing import NDArrayByte, NDArrayInt + +# Bird's-eye view parameters. +MIN_RANGE_M: Tuple[float, float] = (-102.4, -77.5) +MAX_RANGE_M: Tuple[float, float] = (+102.4, +77.5) +RESOLUTION_M_PER_CELL: Tuple[float, float] = (+0.1, +0.1) + +# Model an xy grid in the Bird's-eye view. 
+BEV_GRID: Final[BEVGrid] = BEVGrid( + min_range_m=MIN_RANGE_M, max_range_m=MAX_RANGE_M, resolution_m_per_cell=RESOLUTION_M_PER_CELL +) + + +def generate_sensor_dataset_visualizations( + dataset_dir: Path, + with_annotations: bool, + cam_names: Tuple[Union[RingCameras, StereoCameras], ...], +) -> None: + """Create a video of a point cloud in the ego-view. Annotations may be overlaid. + + Args: + dataset_dir: Path to the dataset directory. + with_annotations: Boolean flag to enable loading of annotations. + cam_names: Set of camera names to render. + """ + dataset = SensorDataloader( + dataset_dir, + with_annotations=with_annotations, + with_cache=True, + cam_names=cam_names, + ) + + tiled_cams_list: List[NDArrayByte] = [] + for _, datum in enumerate(track(dataset, "Creating sensor tutorial videos ...")): + sweep = datum.sweep + annotations = datum.annotations + + timestamp_city_SE3_ego_dict = datum.timestamp_city_SE3_ego_dict + synchronized_imagery = datum.synchronized_imagery + if synchronized_imagery is not None: + cam_name_to_img = {} + for cam_name, cam in synchronized_imagery.items(): + if ( + cam.timestamp_ns in timestamp_city_SE3_ego_dict + and sweep.timestamp_ns in timestamp_city_SE3_ego_dict + ): + city_SE3_ego_cam_t = timestamp_city_SE3_ego_dict[cam.timestamp_ns] + city_SE3_ego_lidar_t = timestamp_city_SE3_ego_dict[sweep.timestamp_ns] + + uv, points_cam, is_valid_points = cam.camera_model.project_ego_to_img_motion_compensated( + sweep.xyz, + city_SE3_ego_cam_t=city_SE3_ego_cam_t, + city_SE3_ego_lidar_t=city_SE3_ego_lidar_t, + ) + + uv_int: NDArrayInt = np.round(uv[is_valid_points]).astype(int) # type: ignore + colors = create_range_map(points_cam[is_valid_points, :3]) + img = draw_points_xy_in_img( + cam.img, uv_int, colors=colors, alpha=0.85, diameter=5, sigma=1.0, with_anti_alias=True + ) + if annotations is not None: + img = annotations.project_to_cam( + img, cam.camera_model, city_SE3_ego_cam_t, city_SE3_ego_lidar_t + ) + cam_name_to_img[cam_name] = 
img + if len(cam_name_to_img) < len(cam_names): + continue + tiled_img = tile_cameras(cam_name_to_img, bev_img=None) + tiled_cams_list.append(tiled_img) + + if datum.sweep_number == datum.num_sweeps_in_log - 1: + video: NDArrayByte = np.stack(tiled_cams_list) + dst_path = Path("videos") / f"{datum.log_id}.mp4" + dst_path.parent.mkdir(parents=True, exist_ok=True) + write_video(video, dst_path, crf=30, color_format=ColorFormats.BGR) + tiled_cams_list = [] + + +@click.command(help="Generate visualizations from the Argoverse 2 Sensor Dataset.") +@click.option( + "-d", + "--dataset-dir", + required=True, + help="Path to local directory where the Argoverse 2 Sensor Dataset is stored.", + type=click.Path(exists=True), +) +@click.option( + "-a", + "--with-annotations", + default=True, + help="Boolean flag to return annotations from the dataloader.", + type=bool, +) +@click.option( + "-c", + "--cam_names", + default=tuple(x.value for x in RingCameras), + help="List of cameras to load for each lidar sweep.", + multiple=True, + type=str, +) +def run_generate_sensor_dataset_visualizations( + dataset_dir: str, with_annotations: bool, cam_names: Tuple[str, ...] +) -> None: + """Click entry point for Argoverse Sensor Dataset visualization. + + Args: + dataset_dir: Dataset directory. + with_annotations: Boolean flag to return annotations. + cam_names: Tuple of camera names to load. + + Raises: + ValueError: If no valid camera names are provided. 
+ """ + valid_ring_cams = set([x.value for x in RingCameras]) + valid_stereo_cams = set([x.value for x in StereoCameras]) + + cam_enums: List[Union[RingCameras, StereoCameras]] = [] + for cam_name in cam_names: + if cam_name in valid_ring_cams: + cam_enums.append(RingCameras(cam_name)) + elif cam_name in valid_stereo_cams: + cam_enums.append(StereoCameras(cam_name)) + else: + raise ValueError("Must provide _valid_ camera names!") + + generate_sensor_dataset_visualizations( + Path(dataset_dir), + with_annotations, + tuple(cam_enums), + ) + + +if __name__ == "__main__": + run_generate_sensor_dataset_visualizations()