Skip to content

Commit

Permalink
Replace empty by missing
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Dec 5, 2024
1 parent 8f34dbc commit 3bce790
Showing 1 changed file with 105 additions and 157 deletions.
262 changes: 105 additions & 157 deletions examples/load_and_upsample_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
print(ds.time)

# %%
#
# In the following sections of the notebook we will explore options to upsample
# the dataset by filling in values for the video frames with no data.

Expand All @@ -80,6 +79,8 @@
# Let's inspect the first 6 frames of the video for which we have
# annotations, and plot the annotated bounding box and centroid at each frame.

# sphinx_gallery_thumbnail_number = 1

# set last frame to plot
end_frame_idx = 25
# create list of frames to loop over with step=5
Expand Down Expand Up @@ -126,7 +127,7 @@
ax.scatter(
x=ds.position.sel(time=slice(0, frame_idx - 1), space="x"),
y=ds.position.sel(time=slice(0, frame_idx - 1), space="y"),
s=5,
s=10,
color="tab:blue",
label="past frames",
)
Expand All @@ -135,11 +136,12 @@
ax.scatter(
x=ds.position.sel(time=slice(frame_idx + 1, end_frame_idx), space="x"),
y=ds.position.sel(time=slice(frame_idx + 1, end_frame_idx), space="y"),
s=5,
s=10,
color="white",
label="future frames",
)

# set title and labels
ax.set_title(f"Frame {frame_idx}")
ax.set_xlabel("x (pixels)")
ax.set_ylabel("y (pixels)")
Expand All @@ -160,7 +162,7 @@
# %%
# Fill in empty values with forward filling
# ----------------------------------------------------
# We can fill in the frames with missing values for the ``position`` and
# We can fill in the frames with empty values for the ``position`` and
# ``shape`` arrays by taking the last valid value in time. In this way, a
# box's position and shape stay constant if for a current frame the box
# has no annotation defined.
Expand All @@ -171,87 +173,81 @@
)

# %%
# We can verify with a plot that the missing values have been filled in
# We can verify with a plot that the empty values have been filled in
# using the last valid value in time.

# In the plot below, the original position and shape data is shown in black,
# while the forward-filled values are shown in blue.
# %%
# In the plot below, the original ``position`` and ``shape`` data is shown
# in black, while the forward-filled values are shown in green.


# We define a convenience function to plot the ``position`` and ``shape``
# space coordinates for the input dataset and a filled one.
def plot_position_and_shape_xy_coords(ds_input_data, ds_filled, color_filled):
"""Compare the x and y coordinates of the position and shape arrays in time
for the input and filled datasets.
"""
fig, axs = plt.subplots(2, 2, figsize=(8, 6))
for row in range(axs.shape[0]):
space_coord = ["x", "y"][row]
for col in range(axs.shape[1]):
ax = axs[row, col]
data_array_str = ["position", "shape"][col]

# plot original data
ax.scatter(
x=ds_input_data.time,
y=ds_input_data[data_array_str].sel(
individuals="id_1", space=space_coord
),
marker="o",
color="black",
label="original data",
)

fig, axs = plt.subplots(2, 2, figsize=(8, 6))
for row in range(axs.shape[0]):
space_coord = ["x", "y"][row]
for col in range(axs.shape[1]):
ax = axs[row, col]
data_array_str = ["position", "shape"][col]
# plot original data
ax.scatter(
x=ds.time,
y=ds[data_array_str].sel(individuals="id_1", space=space_coord),
marker="o",
color="black",
label="original data",
)
# plot forward filled data
ax.plot(
ds_ff.time,
ds_ff[data_array_str].sel(individuals="id_1", space=space_coord),
marker=".",
linewidth=1,
color="tab:green",
label="upsampled data",
)
ax.set_ylabel(f"{space_coord} (pixels)")
if row == 0:
ax.set_title(f"Bounding box {data_array_str}")
if col == 1:
ax.legend()
if row == 1:
ax.set_xlabel("time (frames)")
# plot filled (upsampled) data
ax.plot(
ds_filled.time,
ds_filled[data_array_str].sel(
individuals="id_1", space=space_coord
),
marker=".",
linewidth=1,
color=color_filled,
label="upsampled data",
)

# set axes labels and legend
ax.set_ylabel(f"{space_coord} (pixels)")
if row == 0:
ax.set_title(f"Bounding box {data_array_str}")
if col == 1:
ax.legend()
if row == 1:
ax.set_xlabel("time (frames)")


# plot
plot_position_and_shape_xy_coords(
ds, ds_filled=ds_ff, color_filled="tab:green"
)

# %%
# Fill in empty values with NaN
# ----------------------------------------------------
# Alternatively, we can fill in the missing frames with NaN values.
# This can be useful if we want to interpolate the missing values later.
# Alternatively, we can fill in the empty frames with NaN values.
# This can be useful if we want to interpolate later.
ds_nan = ds.reindex(
{"time": list(range(ds.time[-1].item()))},
method=None, # default
)

# %%
# Like before, we can verify with a plot that the missing values have been
# Like before, we can verify with a plot that the empty values have been
# filled with NaN values.
fig, axs = plt.subplots(2, 2, figsize=(8, 6))
for row in range(axs.shape[0]):
space_coord = ["x", "y"][row]
for col in range(axs.shape[1]):
ax = axs[row, col]
data_array_str = ["position", "shape"][col]
# plot original data
ax.scatter(
x=ds.time,
y=ds[data_array_str].sel(individuals="id_1", space=space_coord),
marker="o",
color="black",
label="original data",
)
# plot NaN filled data
ax.plot(
ds_nan.time,
ds_nan[data_array_str].sel(individuals="id_1", space=space_coord),
marker=".",
linewidth=1,
color="tab:blue",
label="upsampled data",
)
ax.set_ylabel(f"{space_coord} (pixels)")
if row == 0:
ax.set_title(f"Bounding box {data_array_str}")
if col == 1:
ax.legend()
if row == 1:
ax.set_xlabel("time (frames)")
plot_position_and_shape_xy_coords(
ds, ds_filled=ds_nan, color_filled="tab:blue"
)

# %%
# We can further confirm we have NaNs where expected by printing the first few
Expand All @@ -265,7 +261,7 @@
# %%
# Linearly interpolate NaN values
# ----------------------------------------------------------
# We can instead fill in the missing values in the dataset by linearly
# We can instead fill in the empty values in the dataset by linearly
# interpolating the ``position`` and ``shape`` data arrays. In this way,
# we would be assuming that the centroid of the bounding box moves linearly
# between the two annotated values, and its width and height change linearly
Expand All @@ -284,119 +280,78 @@
)

# %%
# Like before, we can visually check that the missing data has been imputed as
# expected by plotting the x and y coordinates of the position and shape arrays
# Like before, we can visually check that the empty data has been imputed as
# expected by plotting the x and y coordinates of the ``position``
# and ``shape`` arrays
# in time.

fig, axs = plt.subplots(2, 2, figsize=(8, 6))
for row in range(axs.shape[0]):
space_coord = ["x", "y"][row]
for col in range(axs.shape[1]):
ax = axs[row, col]
data_array_str = ["position", "shape"][col]
# plot original data
ax.scatter(
x=ds.time,
y=ds[data_array_str].sel(individuals="id_1", space=space_coord),
marker="o",
color="black",
label="original data",
)
# plot linearly interpolated data
ax.plot(
ds_interp.time,
ds_interp[data_array_str].sel(
individuals="id_1", space=space_coord
),
marker=".",
linewidth=1,
color="tab:orange",
label="upsampled data",
)
ax.set_ylabel(f"{space_coord} (pixels)")
if row == 0:
ax.set_title(f"Bounding box {data_array_str}")
if col == 1:
ax.legend()
if row == 1:
ax.set_xlabel("time (frames)")
plot_position_and_shape_xy_coords(
ds, ds_filled=ds_interp, color_filled="tab:orange"
)

# %%
# The plot above shows that between the original data points (in black),
# the data is assumed to evolve linearly (in blue).
# the data is assumed to evolve linearly (in orange).

# %%
# Compare methods
# ----------------
# We can now qualitatively compare the three different methods of filling
# in the missing frames, by plotting the bounding boxes
# for the first few frames of the video.
# We can now qualitatively compare the bounding boxes computed
# with the three different filling methods we have seen: forward filling,
# NaN filling and linear interpolation.
#
# Remember that not all frames of the video are annotated in the original
# dataset. The original data are plotted in black, while the forward filled
# values are plotted in orange and the linearly interpolated values in green.

# sphinx_gallery_thumbnail_number = 4
# In the plot below, the NaN-filled data are plotted in blue, the forward
# filled values are plotted in green, and the linearly interpolated values
# are shown in orange.

# initialise figure
fig = plt.figure(figsize=(8, 8))

list_colors = ["tab:blue", "tab:green", "tab:orange"]

# loop over frames
for frame_n in range(6):
for frame_idx in range(6):
# add subplot axes
ax = plt.subplot(3, 2, frame_n + 1)
ax = plt.subplot(3, 2, frame_idx + 1)

# plot frame
# note: the video is indexed at every frame, so
# we use the frame number as index
ax.imshow(video[frame_n])
ax.imshow(video[frame_idx])

# plot bounding box for each dataset
for ds_i, ds_one in enumerate(
[ds_nan, ds_ff, ds_interp]
): # blue, green , orange
for ds_i, ds_filled in enumerate([ds_nan, ds_ff, ds_interp]):
# plot box
top_left_corner = (
ds_one.position.sel(time=frame_n, individuals="id_1").data
- ds_one.shape.sel(time=frame_n, individuals="id_1").data / 2
)
ds_filled.position.sel(time=frame_idx).data
- ds_filled.shape.sel(time=frame_idx).data / 2
).squeeze()

bbox = plt.Rectangle(
xy=tuple(top_left_corner),
width=ds_one.shape.sel(
time=frame_n, individuals="id_1", space="x"
).data,
height=ds_one.shape.sel(
time=frame_n, individuals="id_1", space="y"
).data,
width=ds_filled.shape.sel(time=frame_idx, space="x").item(),
height=ds_filled.shape.sel(time=frame_idx, space="y").item(),
edgecolor=list_colors[ds_i],
facecolor="none",
# make line for NaN dataset thicker and dotted
linewidth=[5, 1.5, 1.5][ds_i],
linestyle=["dotted", "solid", "solid"][ds_i],
label=["nan", "ffill", "linear"][ds_i],
linewidth=[8, 2.5, 2.5][ds_i],
linestyle=["dotted", "solid", "solid"][ds_i],
)
ax.add_patch(bbox)

# plot centroid
ax.scatter(
x=ds_one.position.sel(
time=frame_n, individuals="id_1", space="x"
).data,
y=ds_one.position.sel(
time=frame_n, individuals="id_1", space="y"
).data,
s=5,
x=ds_filled.position.sel(time=frame_idx, space="x"),
y=ds_filled.position.sel(time=frame_idx, space="y"),
s=20,
color=list_colors[ds_i],
)

# add legend to first frame
if frame_n == 0:
ax.legend()
ax.set_title(f"Frame {frame_n}")
# set title and labels
ax.set_title(f"Frame {frame_idx}")
ax.set_xlabel("x (pixels)")
ax.set_ylabel("y (pixels)")
if frame_idx == 0:
ax.legend()

fig.tight_layout()

Expand All @@ -419,23 +374,16 @@
writer = csv.writer(file)

# write the header
writer.writerow(
["frame_idx", "bbox_ID", "x", "y", "width", "height", "confidence"]
)
writer.writerow(["frame", "ID", "x", "y", "width", "height"])

# write the data
for individual in ds.individuals.data:
for frame in ds.time.data:
x, y = ds.position.sel(time=frame, individuals=individual).data
width, height = ds.shape.sel(
time=frame, individuals=individual
).data
confidence = ds.confidence.sel(
for individual in ds_ff.individuals.data:
for frame in ds_ff.time.data:
x, y = ds_ff.position.sel(time=frame, individuals=individual).data
width, height = ds_ff.shape.sel(
time=frame, individuals=individual
).data
writer.writerow(
[frame, individual, x, y, width, height, confidence]
)
writer.writerow([frame, individual, x, y, width, height])

# %%
# Clean-up
Expand Down

0 comments on commit 3bce790

Please sign in to comment.