diff --git a/src/keras_video/generator.py b/src/keras_video/generator.py index 912dd98..0f6bcd1 100644 --- a/src/keras_video/generator.py +++ b/src/keras_video/generator.py @@ -391,16 +391,17 @@ def _get_classname(self, video: str) -> str: return classname def _get_frames( - self, video, nbframe, shape, force_no_headers=False + self, video, nbframe, shape, seq_time=0, force_no_headers=False ) -> Optional[Iterable]: cap = cv.VideoCapture(video) total_frames = self.count_frames(cap, video, force_no_headers) + fps = cap.get(cv.CAP_PROP_FPS) orig_total = total_frames if total_frames % 2 != 0: total_frames += 1 - frame_step = floor(total_frames / (nbframe - 1)) + frame_step = floor(total_frames / (nbframe - 1)) if seq_time ==0 else floor((seq_time * fps) / (nbframe - 1)) # TODO: fix that, a tiny video can have a frame_step that is # under 1 frame_step = max(1, frame_step) diff --git a/src/keras_video/sliding.py b/src/keras_video/sliding.py index e577dbe..58f6a7f 100644 --- a/src/keras_video/sliding.py +++ b/src/keras_video/sliding.py @@ -77,6 +77,9 @@ def __init_length(self): else: seqtime = int(frame_count) + # add an asssert to check if nbframe is possible in sequence_time + assert self.nbframe < seqtime, "ERROR nbframe > sequence_time change parameter and restart !" + stop_at = int(seqtime - self.nbframe) step = np.ceil(seqtime / self.nbframe).astype(np.int) - 1 i = 0 @@ -172,7 +175,7 @@ def __getitem__(self, idx): video_id = vid["id"] if video_id not in self.__frame_cache: - frames: Iterable = self._get_frames(video, nbframe, shape) + frames: Iterable = self._get_frames(video, nbframe, shape, self.sequence_time if self.sequence_time != None else 0) else: frames: Iterable = self.__frame_cache[video_id]