Skip to content

Commit

Permalink
Merge pull request #2389 from AllenInstitute/ticket/2382
Browse files Browse the repository at this point in the history
Removes unnecessary validation on non-behavior files
  • Loading branch information
aamster authored Apr 25, 2022
2 parents e51c39c + 5d51dd5 commit e3d4fc9
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 85 deletions.
46 changes: 5 additions & 41 deletions allensdk/brain_observatory/behavior/data_files/stimulus_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,44 +36,6 @@ def from_lims_cache_key(cls, db, behavior_session_id: int):
return hashkey(behavior_session_id, cls.file_path_key())


class _NumFramesMixin(object):
"""
Mixin to implement num_frames for a generic (i.e. non-behavior)
StimulusFile
"""

def _validate_frame_data(self) -> None:
"""
Check that self.data['intervalsms'] is present and
that self.data['items']['behavior']['intervalsms'] is empty
"""
msg = ""
if 'intervalsms' not in self.data:
msg += "self.data['intervalsms'] not present\n"
if "items" in self.data:
if "behavior" in self.data["items"]:
if "intervalsms" in self.data["items"]["behavior"]:
val = self.data["items"]["behavior"]["intervalsms"]
if len(val) > 0:
msg += ("len(self.data['items']['behavior']"
f"['intervalsms'] == {len(val)}; "
"expected zero\n")
if len(msg) > 0:
full_msg = f"When getting num_frames from {type(self)}\n"
full_msg += msg
full_msg += f"\nfilepath: {self.filepath}"
raise RuntimeError(full_msg)

return None

def num_frames(self) -> int:
"""
Return the number of frames associated with this StimulusFile
"""
self._validate_frame_data()
return len(self.data['intervalsms']) + 1


class _StimulusFile(DataFile):
"""A DataFile which contains methods for accessing and loading visual
behavior stimulus *.pkl files.
Expand Down Expand Up @@ -120,11 +82,12 @@ def load_data(filepath: Union[str, Path]) -> dict:
filepath = safe_system_path(file_name=filepath)
return pd.read_pickle(filepath)

@property
def num_frames(self) -> int:
"""
Return the number of frames associated with this StimulusFile
"""
raise NotImplementedError()
return len(self.data['intervalsms']) + 1


class BehaviorStimulusFile(_StimulusFile):
Expand Down Expand Up @@ -171,6 +134,7 @@ def _validate_frame_data(self):

return None

@property
def num_frames(self) -> int:
"""
Return the number of frames associated with this StimulusFile
Expand All @@ -179,14 +143,14 @@ def num_frames(self) -> int:
return len(self.data['items']['behavior']['intervalsms']) + 1


class ReplayStimulusFile(_NumFramesMixin, _StimulusFile):
class ReplayStimulusFile(_StimulusFile):

@classmethod
def file_path_key(cls) -> str:
return "replay_stimulus_file"


class MappingStimulusFile(_NumFramesMixin, _StimulusFile):
class MappingStimulusFile(_StimulusFile):

@classmethod
def file_path_key(cls) -> str:
Expand Down
10 changes: 5 additions & 5 deletions allensdk/brain_observatory/ecephys/write_nwb/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,6 @@ class BaseNeuropixelsSchema(ArgSchema):
help="""file at this path contains information about the optogenetic
stimulation applied during this experiment"""
)
running_speed_path = String(
required=True,
help="""data collected about the running behavior of the experiment's
subject""",
)
eye_tracking_rig_geometry = Dict(
required=False,
help="""Mapping containing information about session rig geometry used
Expand Down Expand Up @@ -259,6 +254,11 @@ class Meta:
allow_none=True,
required=False,
help="miscellaneous information describing this session""")
running_speed_path = String(
required=True,
help="""data collected about the running behavior of the experiment's
subject""",
)


class ProbeOutputs(RaisingSchema):
Expand Down
4 changes: 0 additions & 4 deletions allensdk/brain_observatory/ecephys/write_nwb/vbn/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ class _VBNSessionDataSchema(BaseBehaviorSessionDataSchema,
required=True,
description='path to stimulus presentations csv file'
)
raw_running_speed_path = argschema.fields.InputFile(
required=True,
description='path to raw running speed h5 file'
)
raw_eye_tracking_video_meta_data = argschema.fields.InputFile(
required=True,
description='path to eye tracking metadata'
Expand Down
2 changes: 1 addition & 1 deletion allensdk/brain_observatory/sync_stim_aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def get_stim_timestamps_from_stimulus_blocks(
data=sync_data,
sync_lines=raw_frame_time_lines)

frame_count_list = [s.num_frames() for s in stimulus_files]
frame_count_list = [s.num_frames for s in stimulus_files]
start_frames = _get_start_frames(
data=sync_data,
raw_frame_times=raw_frame_times,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_behavior_num_frames(
str_path = str(behavior_pkl_fixture['path'].resolve().absolute())
dict_repr = {"behavior_stimulus_file": str_path}
beh_stim = BehaviorStimulusFile.from_json(dict_repr=dict_repr)
assert beh_stim.num_frames() == behavior_pkl_fixture['expected_frames']
assert beh_stim.num_frames == behavior_pkl_fixture['expected_frames']


def test_replay_num_frames(
Expand All @@ -140,7 +140,7 @@ def test_replay_num_frames(
str_path = str(general_pkl_fixture['path'].resolve().absolute())
dict_repr = {"replay_stimulus_file": str_path}
rep_stim = ReplayStimulusFile.from_json(dict_repr=dict_repr)
assert rep_stim.num_frames() == general_pkl_fixture['expected_frames']
assert rep_stim.num_frames == general_pkl_fixture['expected_frames']


def test_mapping_num_frames(
Expand All @@ -152,7 +152,7 @@ def test_mapping_num_frames(
str_path = str(general_pkl_fixture['path'].resolve().absolute())
dict_repr = {"mapping_stimulus_file": str_path}
map_stim = MappingStimulusFile.from_json(dict_repr=dict_repr)
assert map_stim.num_frames() == general_pkl_fixture['expected_frames']
assert map_stim.num_frames == general_pkl_fixture['expected_frames']


def test_malformed_behavior_pkl(
Expand All @@ -166,35 +166,7 @@ def test_malformed_behavior_pkl(
stim = BehaviorStimulusFile.from_json(dict_repr=dict_repr)
with pytest.raises(RuntimeError,
match="When getting num_frames from"):
stim.num_frames()


def test_malformed_replay_pkl(
behavior_pkl_fixture):
"""
Test that the correct error is raised when a replay pickle file
is mal-formed and num_frames is called
"""
str_path = str(behavior_pkl_fixture['path'].resolve().absolute())
dict_repr = {"replay_stimulus_file": str_path}
stim = ReplayStimulusFile.from_json(dict_repr=dict_repr)
with pytest.raises(RuntimeError,
match="When getting num_frames from"):
stim.num_frames()


def test_malformed_mapping_pkl(
behavior_pkl_fixture):
"""
Test that the correct error is raised when a mapping pickle file
is mal-formed and num_frames is called
"""
str_path = str(behavior_pkl_fixture['path'].resolve().absolute())
dict_repr = {"mapping_stimulus_file": str_path}
stim = MappingStimulusFile.from_json(dict_repr=dict_repr)
with pytest.raises(RuntimeError,
match="When getting num_frames from"): # noqa W605
stim.num_frames()
_ = stim.num_frames


def test_stimulus_file_lookup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class DummyStim(object):
def __init__(self, n_frames):
self._n_frames = n_frames

@property
def num_frames(self):
return self._n_frames

Expand Down Expand Up @@ -254,7 +255,7 @@ def test_user_facing_get_stim_timestamps_smoke(
raw_idx = line_to_edges_fixture['vsync_stim'][f'{edge_type}_idx']
raw_times = sync_sample_fixture[raw_idx]/sync_freq_fixture
idx0 = expected_start_frames_fixture[edge_type][ii]
expected = raw_times[idx0: idx0+this_stim.num_frames()]
expected = raw_times[idx0: idx0+this_stim.num_frames]
np.testing.assert_array_equal(this_array, expected)
assert this_start_frame == expected_start_frames_fixture[edge_type][ii]

Expand Down Expand Up @@ -300,6 +301,6 @@ def test_user_facing_get_stim_timestamps(
raw_idx = line_to_edges_fixture['vsync_stim'][f'{edge_type}_idx']
raw_times = sync_sample_fixture[raw_idx]/sync_freq_fixture
idx0 = expected_start[ii]
expected = raw_times[idx0: idx0+this_stim.num_frames()]
expected = raw_times[idx0: idx0+this_stim.num_frames]
np.testing.assert_array_equal(this_array, expected)
assert this_start_frame == expected_start[ii]

0 comments on commit e3d4fc9

Please sign in to comment.