
Commit

Merge pull request SainsburyWellcomeCentre#336 from ttngu207/datajoint_pipeline

add BlockSubjectPatch analysis & other minor fixes
JaerongA authored Feb 16, 2024
2 parents a6725ad + 9aaae2b commit 2a7261d
Showing 3 changed files with 169 additions and 6 deletions.
156 changes: 151 additions & 5 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -5,6 +5,7 @@
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from matplotlib import path as mpl_path

from aeon.analysis import utils as analysis_utils
from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
@@ -226,14 +227,159 @@ class Patch(dj.Part):
in_patch_time: float # total seconds spent in this patch for this block
pellet_count: int
pellet_timestamps: longblob
wheel_cumsum_distance_travelled: longblob # wheel's cumulative distance travelled
"""

class Preference(dj.Part):
definition = """ # Measure of preference for a particular patch from a particular subject
-> master
-> BlockAnalysis.Patch
-> BlockAnalysis.Subject
---
cumulative_preference_by_wheel: longblob
windowed_preference_by_wheel: longblob
cumulative_preference_by_pellet: longblob
windowed_preference_by_pellet: longblob
cumulative_preference_by_time: longblob
windowed_preference_by_time: longblob
preference_score: float # one representative preference score for the entire block
"""

def make(self, key):
block_patches = (BlockAnalysis.Patch & key).fetch(as_dict=True)
block_subjects = (BlockAnalysis.Subject & key).fetch(as_dict=True)
subject_names = [s["subject_name"] for s in block_subjects]
# Construct subject position dataframe
subjects_positions_df = pd.concat(
[
pd.DataFrame(
{"subject_name": [s["subject_name"]] * len(s["position_timestamps"])}
| {
k: s[k]
for k in (
"position_timestamps",
"position_x",
"position_y",
"position_likelihood",
)
}
)
for s in block_subjects
]
)
subjects_positions_df.set_index("position_timestamps", inplace=True)
# Get frame rate of CameraTop
camera_fps = int(
(
streams.SpinnakerVideoSource * streams.SpinnakerVideoSource.Attribute
& key
& 'attribute_name = "SamplingFrequency"'
& 'spinnaker_video_source_name = "CameraTop"'
& f'spinnaker_video_source_install_time < "{key["block_start"]}"'
).fetch("attribute_value", order_by="spinnaker_video_source_install_time DESC", limit=1)[0]
)

self.insert1(key)
for i, patch in enumerate(block_patches):
cum_wheel_dist = pd.Series(
index=patch["wheel_timestamps"], data=patch["wheel_cumsum_distance_travelled"]
)
# Get distance-to-patch at each pose data timestep
patch_region = (
acquisition.EpochActiveRegion.Region
& key
& {"region_name": f"{patch['patch_name']}Region"}
& f'epoch_start < "{key["block_start"]}"'
).fetch("region_data", order_by="epoch_start DESC", limit=1)[0]
patch_xy = list(zip(*[(int(p["X"]), int(p["Y"])) for p in patch_region["ArrayOfPoint"]]))
patch_center = np.mean(patch_xy[0]).astype(np.uint32), np.mean(patch_xy[1]).astype(np.uint32)
subjects_xy = subjects_positions_df[["position_x", "position_y"]].values
dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float))
dist_to_patch_df = subjects_positions_df[["subject_name"]].copy()
dist_to_patch_df["dist_to_patch"] = dist_to_patch
# Assign pellets and wheel timestamps to subjects
if len(block_subjects) == 1:
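# Single-subject block: all wheel distance and pellets belong to the lone subject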
cum_wheel_dist_dm = cum_wheel_dist.to_frame(name=subject_names[0])
patch_df_for_pellets_df = pd.DataFrame(
index=patch["pellet_timestamps"], data={"subject_name": subject_names[0]}
)
else:
# Assign id based on which subject was closest to patch at time of event
# Get distance-to-patch at each wheel ts and pel del ts, organized by subject
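# merge_asof(direction="forward") picks, per event, the first pose sample at or after it,
# but only within the tolerance window; events with no sample in range stay NaN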
dist_to_patch_wheel_ts_id_df = pd.DataFrame(
index=cum_wheel_dist.index, columns=subject_names
)
dist_to_patch_pel_ts_id_df = pd.DataFrame(
index=patch["pellet_timestamps"], columns=subject_names
)
for subject_name in subject_names:
# Find closest match between pose_df indices and wheel indices
if not dist_to_patch_wheel_ts_id_df.empty:
dist_to_patch_wheel_ts_subj = pd.merge_asof(
left=dist_to_patch_wheel_ts_id_df[subject_name],
right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name],
left_index=True,
right_index=True,
direction="forward",
tolerance=pd.Timedelta("100ms"),
)
dist_to_patch_wheel_ts_id_df[subject_name] = dist_to_patch_wheel_ts_subj[
"dist_to_patch"
]
# Find closest match between pose_df indices and pel indices
if not dist_to_patch_pel_ts_id_df.empty:
dist_to_patch_pel_ts_subj = pd.merge_asof(
left=dist_to_patch_pel_ts_id_df[subject_name],
right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name],
left_index=True,
right_index=True,
direction="forward",
tolerance=pd.Timedelta("200ms"),
)
dist_to_patch_pel_ts_id_df[subject_name] = dist_to_patch_pel_ts_subj[
"dist_to_patch"
]
# Get closest subject to patch at each pel del timestep
patch_df_for_pellets_df = pd.DataFrame(
index=patch["pellet_timestamps"],
data={"subject_name": dist_to_patch_pel_ts_id_df.idxmin(axis=1).values},
)

# Get closest subject to patch at each wheel timestep
cum_wheel_dist_subj_df = pd.DataFrame(
index=cum_wheel_dist.index, columns=subject_names, data=0.0
)
closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)
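# Per-timestep wheel displacement: diff() leaves the first entry NaN, so seed it with the initial cumulative value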
wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0])
# Assign wheel dist to closest subject for each wheel timestep
for subject_name in subject_names:
subj_idxs = cum_wheel_dist_subj_df[closest_subjects == subject_name].index
cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[subj_idxs]
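# Re-accumulate the assigned increments into per-subject cumulative distances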
cum_wheel_dist_dm = cum_wheel_dist_subj_df.cumsum(axis=0)

# In-patch time: point-in-polygon test of each pose sample against the patch region
patch_bbox = mpl_path.Path(list(zip(*patch_xy)))
in_patch = subjects_positions_df.apply(
lambda row: patch_bbox.contains_point((row["position_x"], row["position_y"])), axis=1
)
# Insert data
for subject_name in subject_names:
pellets = patch_df_for_pellets_df[patch_df_for_pellets_df["subject_name"] == subject_name]
subject_in_patch = subjects_positions_df[
in_patch & (subjects_positions_df["subject_name"] == subject_name)
]
self.Patch.insert1(
key
| dict(
patch_name=patch["patch_name"],
subject_name=subject_name,
in_patch_timestamps=subject_in_patch.index.values,
in_patch_time=len(subject_in_patch) / camera_fps,
pellet_count=len(pellets),
pellet_timestamps=pellets.index.values,
wheel_cumsum_distance_travelled=cum_wheel_dist_dm[subject_name].values,
)
)


@schema
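The multi-animal branch of make() above hinges on nearest-timestamp matching: for every wheel or pellet timestamp, pd.merge_asof looks up each subject's distance to the patch from the pose stream, and idxmin picks the closest subject. A minimal, self-contained sketch of that pattern with synthetic data (subj_a, subj_b, and all timestamps are made up; only the merge_asof arguments mirror the diff):

import numpy as np
import pandas as pd

# Two hypothetical subjects with pose samples every 100 ms
pose_ts = pd.date_range("2024-02-16 12:00:00", periods=10, freq="100ms")
pose_df = pd.concat(
    [
        pd.DataFrame({"subject_name": "subj_a", "dist_to_patch": np.linspace(50.0, 5.0, 10)}, index=pose_ts),
        pd.DataFrame({"subject_name": "subj_b", "dist_to_patch": np.linspace(5.0, 50.0, 10)}, index=pose_ts),
    ]
)

# Two pellet deliveries somewhere inside the pose window
pellet_ts = pd.DatetimeIndex(["2024-02-16 12:00:00.250", "2024-02-16 12:00:00.800"])

dist_by_subject = pd.DataFrame(index=pellet_ts, columns=["subj_a", "subj_b"], dtype=float)
for name in dist_by_subject.columns:
    matched = pd.merge_asof(
        left=pd.DataFrame({"event": [0, 1]}, index=pellet_ts),
        right=pose_df[pose_df["subject_name"] == name],  # per-subject slice is time-sorted
        left_index=True,
        right_index=True,
        direction="forward",              # first pose sample at or after the event
        tolerance=pd.Timedelta("200ms"),  # same tolerance the pipeline uses for pellets
    )
    dist_by_subject[name] = matched["dist_to_patch"].values

# Closest subject to the patch at each pellet delivery:
# 12:00:00.250 -> subj_b, 12:00:00.800 -> subj_a
print(dist_by_subject.idxmin(axis=1))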
11 changes: 11 additions & 0 deletions aeon/dj_pipeline/utils/load_metadata.py
@@ -186,6 +186,9 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath

device_type_mapper, _ = get_device_mapper(devices_schema, metadata_yml_filepath)

# Retrieve video controller
video_controller = epoch_config["metadata"].pop("VideoController", {})

# Insert into each device table
epoch_device_types = []
device_list = []
@@ -221,6 +224,14 @@
}
for attribute_name, attribute_value in device_config.items()
]
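# Cameras record a trigger name rather than a frame rate; resolve it via the
# epoch's VideoController mapping and store it as a SamplingFrequency attribute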
if "TriggerFrequency" in device_config:
table_attribute_entry.append(
{
**table_entry,
"attribute_name": "SamplingFrequency",
"attribute_value": video_controller[device_config["TriggerFrequency"]],
}
)

"""Check if this device is currently installed. If the same device serial number is currently installed check for any changes in configuration. If not, skip this"""
current_device_query = table - table.RemovalTime & experiment_key & device_key
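In the metadata YAML, a camera device stores only the name of its trigger source; the epoch-level VideoController block maps trigger names to actual frequencies. A minimal sketch of the lookup the new lines perform (the dictionary contents are hypothetical, not real experiment metadata):

# Hypothetical epoch metadata (illustrative values only)
video_controller = {"GlobalTriggerFrequency": 50, "LocalTriggerFrequency": 125}
device_config = {"TriggerFrequency": "GlobalTriggerFrequency", "SerialNumber": "12345678"}

table_attribute_entry = [
    {"attribute_name": k, "attribute_value": v} for k, v in device_config.items()
]
if "TriggerFrequency" in device_config:
    # Resolve the trigger name to a frame rate via the VideoController block
    table_attribute_entry.append(
        {
            "attribute_name": "SamplingFrequency",
            "attribute_value": video_controller[device_config["TriggerFrequency"]],
        }
    )
print(table_attribute_entry[-1])  # {'attribute_name': 'SamplingFrequency', 'attribute_value': 50}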
8 changes: 7 additions & 1 deletion aeon/io/reader.py
@@ -126,7 +126,13 @@ def __init__(self, pattern, columns, dtype=None, extension="csv"):

def read(self, file):
"""Reads data from the specified CSV text file."""
return pd.read_csv(
file,
header=0,
names=self.columns,
dtype=self.dtype,
index_col=0 if file.stat().st_size else None,
)


class Subject(Csv):
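The guard matters for zero-byte files, which acquisition epochs can produce: with explicit names, pandas can return an empty frame for an empty file, but asking it to also promote the first column to an index may fail. A sketch of the pattern (the standalone function and file name are illustrative; the exact failure mode without the guard depends on the pandas version):

import pandas as pd
from pathlib import Path

def read_csv_guarded(path: Path, columns: list[str]) -> pd.DataFrame:
    # Only set the first column as the index when the file has content;
    # a zero-byte file then yields an empty frame with the declared columns
    return pd.read_csv(
        path,
        header=0,
        names=columns,
        index_col=0 if path.stat().st_size else None,
    )

empty = Path("empty_chunk.csv")
empty.touch()  # zero-byte file
df = read_csv_guarded(empty, ["time", "value"])
print(df.empty, list(df.columns))  # True ['time', 'value']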
