Fly Tracker (Colab)#
This notebook tracks a fruit fly against a light background in a video, calculates its position and velocity over time, and produces path and velocity plots.
Original MATLAB code by Jeff Stafford, modified by A. Juavinett for BIPN 145.
How to use#
Run the Setup cell to install dependencies.
Upload your video (.avi, .mp4, etc.) using the upload cell. Set your parameters (dish diameter, frame rate).
Draw your ROI on the first frame.
Run the remaining cells to track the fly and view results.
Setup#
#@title Install & import packages { display-mode: "form" }
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from scipy.spatial.distance import euclidean
from google.colab import files, output
from IPython.display import display, clear_output, HTML, Image as IPImage
import ipywidgets as widgets
import os
import base64
# Accumulators for multi-video sessions: created the first time this cell
# runs, then left untouched on re-runs so earlier results survive.
if 'all_corrected' not in globals():
    all_corrected = []  # one Nx3 [time_s, x_cm, y_cm] array per video
    all_velocity = []   # one binned-velocity array (mm/s) per video
    print('Initialized result containers.')
else:
    print(f'Existing results: {len(all_corrected)} video(s) processed so far.')
print('All packages loaded successfully!')
Note: if running the imports produces an error like the one below, OpenCV is not installed in your runtime — run `!pip install opencv-python-headless` in a cell and re-run the import cell.
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[1], line 3
1 #@title Install & import packages { display-mode: "form" }
2 import numpy as np
----> 3 import cv2
4 import matplotlib.pyplot as plt
5 import matplotlib.colors as mcolors
ModuleNotFoundError: No module named 'cv2'
Mount your drive#
Run the cell below to mount your drive, which should contain your video files.
# Mount Google Drive at /content/drive so the video files stored there are
# readable by the cells below (prompts for authorization on first run).
from google.colab import drive
drive.mount('/content/drive')
Select Video File#
Update the path below to your video file. You’ll process one video at a time, and results will accumulate in all_corrected and all_velocity.
Tip: In the window at left, you can use the three vertical dots to copy and paste the path to your video file.
# Set the current video file to process.
# Edit this path for each new video; results accumulate across runs
# in all_corrected / all_velocity.
current_video = '/content/drive/MyDrive/video1.avi'
print(f'Current video: {current_video}')
Set Parameters#
Modify the frame_rate below as needed.
# Frame rate of the video
frame_rate = 30  # frames per second; must match how the video was recorded
diameter = 4  # Diameter of the dish in centimeters
search_size = 20  # Search size for fly detection (in pixels)
per_pixel_threshold = 1.5  # intensity threshold per search-window pixel (used by fly_finder)
bin_size = 1  # Bin size for velocity calculation (in seconds)
# The ROI is assumed to span the dish in both axes, so the tracked area
# is one diameter wide and one diameter tall.
height = diameter
width = diameter
print(f'Frame rate: {frame_rate} fps')
Helper Functions#
Just run the cell below to define a set of functions to run our analysis.
#@title Define helper functions (fly_finder, dist_filter, interpolate_pos) { display-mode: "form" }
def fly_finder(roi_image, half_search, threshold, flip=True):
    """
    Locate a fly (dark spot on a light background) in a grayscale image.

    The extreme-intensity pixel seeds a square search window of side
    roughly 2*half_search; the intensity-weighted centroid of that window
    is the reported fly position.

    Parameters
    ----------
    roi_image : 2-D array of grayscale pixel values.
    half_search : half-width of the search window, in pixels.
    threshold : minimum summed (inverted) intensity to accept a detection.
    flip : if True, seed from the darkest pixel and weight by darkness.

    Returns
    -------
    (x, y) position in pixels, or (nan, nan) when the window is too dim.
    """
    extreme = np.nanmin(roi_image) if flip else np.nanmax(roi_image)
    rows, cols = np.where(roi_image == extreme)
    seed_x = np.mean(cols)
    seed_y = np.mean(rows)

    n_rows, n_cols = roi_image.shape
    x_lo = max(int(round(seed_x) - half_search), 0)
    x_hi = min(int(round(seed_x) + half_search), n_cols - 1)
    y_lo = max(int(round(seed_y) - half_search), 0)
    y_hi = min(int(round(seed_y) + half_search), n_rows - 1)

    window = roi_image[y_lo:y_hi + 1, x_lo:x_hi + 1].astype(np.float64)
    if flip:
        # Invert so that dark pixels contribute large weights.
        window = 255.0 - window

    mass = np.sum(window)
    if mass < threshold:
        # Not enough signal in the window -> no fly found this frame.
        return np.nan, np.nan

    # Intensity-weighted center of mass, shifted back to image coordinates.
    col_idx = np.arange(window.shape[1])
    row_idx = np.arange(window.shape[0])
    x = np.sum(window @ col_idx) / mass + x_lo
    y = np.sum(window.T @ row_idx) / mass + y_lo
    return x, y
def dist_filter(array, tele_dist_threshold, num_avg=5):
    """
    Remove "teleporting" points from a tracked path.

    A point is discarded (set to NaN) when it lies farther than
    tele_dist_threshold from the mean of the valid points in the num_avg
    frames on either side. The first and last few frames get an extra,
    stricter frame-to-frame jump check (half the threshold).

    Parameters
    ----------
    array : Nx3 array of [time, x, y] rows.
    tele_dist_threshold : maximum plausible distance from the local mean.
    num_avg : number of neighbouring frames averaged on each side.

    Returns
    -------
    A filtered copy of `array`; removed points are NaN.
    """
    out = array.copy()
    removed = 0

    # Interior frames: compare each point against the means of the windows
    # just before and just after it.
    for k in range(num_avg, len(out) - num_avg):
        p = out[k, 1:3]
        if np.any(np.isnan(p)):
            continue
        before = out[k - num_avg:k, 1:3]
        before = before[~np.isnan(before[:, 0])]
        if len(before) == 0:
            continue
        after = out[k + 1:k + 1 + num_avg, 1:3]
        after = after[~np.isnan(after[:, 0])]
        if len(after) == 0:
            continue
        jumped = (euclidean(p, np.mean(before, axis=0)) > tele_dist_threshold
                  or euclidean(p, np.mean(after, axis=0)) > tele_dist_threshold)
        if jumped:
            out[k, 1:3] = np.nan
            removed += 1

    # Stricter frame-to-frame check near the start and end of the recording.
    edge_indices = (list(range(0, min(5, len(out) - 1)))
                    + list(range(max(0, len(out) - 6), len(out) - 1)))
    for k in edge_indices:
        if np.any(np.isnan(out[k, 1:3])) or np.any(np.isnan(out[k + 1, 1:3])):
            continue
        if euclidean(out[k, 1:3], out[k + 1, 1:3]) > tele_dist_threshold / 2:
            out[k, 1:3] = np.nan
            removed += 1

    print(f'{removed} points removed by the teleportation filter.')
    return out
def interpolate_pos(array, inter_dist_threshold):
    """
    Fill NaN gaps in a tracked path by linear interpolation.

    A gap is filled only when the valid points bracketing it are no more
    than inter_dist_threshold apart; otherwise the gap is left as NaN
    (the fly was presumably lost, not briefly occluded).

    Parameters
    ----------
    array : Nx3 array of [time, x, y] rows.
    inter_dist_threshold : maximum endpoint separation for filling a gap.

    Returns
    -------
    A copy of `array` with eligible gaps filled in.
    """
    out = array.copy()
    recovered = 0
    cx, cy = 1, 2  # x and y live in columns 1 and 2
    i = 0
    while i < len(out):
        if not np.isnan(out[i, cx]):
            i += 1
            continue
        if i == 0:
            # Leading NaN run: jump straight to the first valid sample.
            valid = np.where(~np.isnan(out[:, cx]))[0]
            if len(valid) == 0:
                break
            i = valid[0]
            continue
        # A gap starts at i; the last valid point sits directly before it.
        prev_idx = i - 1
        prev_pt = out[prev_idx, cx:cy + 1]
        ahead = np.where(~np.isnan(out[i:, cx]))[0]
        if len(ahead) == 0:
            break  # trailing NaNs: no far anchor to interpolate toward
        next_idx = ahead[0] + i
        next_pt = out[next_idx, cx:cy + 1]
        gap = next_idx - i
        if euclidean(prev_pt, next_pt) <= inter_dist_threshold:
            # Evenly spread the gap frames along the segment prev -> next.
            for step in range(1, gap + 1):
                frac = step / (gap + 1)
                out[prev_idx + step, cx:cy + 1] = prev_pt + (next_pt - prev_pt) * frac
            recovered += gap
        i = next_idx
    print(f'{recovered} points recovered through interpolation.')
    return out
def boundary_filter(array, diameter):
    """
    Discard tracked points that fall outside the circular dish.

    The dish is modelled as a circle of the given diameter centered at
    (diameter/2, diameter/2), i.e. inscribed in the square ROI.

    Parameters
    ----------
    array : Nx3 array of [time, x, y] rows (same units as diameter).
    diameter : dish diameter.

    Returns
    -------
    A copy of `array` with out-of-dish points set to NaN.
    """
    out = array.copy()
    radius = diameter / 2
    dx = out[:, 1] - radius
    dy = out[:, 2] - radius
    # Points that are already NaN are not re-counted as violations.
    bad = (np.sqrt(dx ** 2 + dy ** 2) > radius) & ~np.isnan(out[:, 1])
    out[bad, 1:3] = np.nan
    n_bad = np.sum(bad)
    if n_bad > 0:
        print(f" {n_bad} points removed by the boundary filter (outside dish).")
    return out
print('Helper functions defined.')
Select ROI (Region of Interest)#
Run the cell below to display the first frame of your video. Click and drag directly on the image to draw a red rectangle around the dish. You can redraw as many times as you like. When you’re happy with the ROI, click Confirm ROI.
#@title Interactive ROI selector — click and drag on the image { display-mode: "form" }
# Read first frame and encode as base64 for display
cap = cv2.VideoCapture(current_video)
ret, frame = cap.read()
cap.release()
if not ret:
    raise ValueError(f'Could not read first frame of {current_video}')
# OpenCV decodes frames as BGR; convert to RGB for display purposes.
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img_h, img_w = frame_rgb.shape[:2]
# Encode frame as PNG base64 (imencode expects BGR, so convert back first).
_, buf = cv2.imencode('.png', cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
img_b64 = base64.b64encode(buf).decode('utf-8')
print(f'Image size: {img_w} x {img_h} pixels')
print('Click and drag on the image below to select the ROI, then click "Confirm ROI".\n')
# JavaScript/HTML interactive ROI selector.
# The first frame is embedded as a base64 PNG on a <canvas>; dragging the
# mouse draws a red rectangle, and the Confirm button reports the ROI back
# to Python via google.colab.kernel.invokeFunction('set_roi', ...).
# Mouse coordinates are rescaled from CSS pixels to canvas pixels (getScale),
# so the ROI is in original-image coordinates even when the canvas is
# displayed scaled down (max-width:100%).
html_code = f"""
<div id="roi-container" style="position:relative; display:inline-block; cursor:crosshair;">
<canvas id="roi-canvas" width="{img_w}" height="{img_h}"
style="max-width:100%; height:auto; border:1px solid #ccc;"></canvas>
</div>
<br>
<button id="roi-confirm-btn" style="font-size:16px; padding:8px 20px; background:#4CAF50;
color:white; border:none; border-radius:4px; cursor:pointer; margin-top:8px;">
Confirm ROI
</button>
<p id="roi-info" style="font-size:14px; color:#333;"></p>
<script>
(function() {{
var canvas = document.getElementById('roi-canvas');
var ctx = canvas.getContext('2d');
var img = new window.Image();
img.src = 'data:image/png;base64,{img_b64}';
var drawing = false;
var x0 = 0, y0 = 0, x1 = 0, y1 = 0;
var hasROI = false;
// Scale factor: canvas CSS size vs actual pixel size
function getScale() {{
var rect = canvas.getBoundingClientRect();
return {{ sx: canvas.width / rect.width, sy: canvas.height / rect.height }};
}}
img.onload = function() {{
ctx.drawImage(img, 0, 0);
}};
function redraw() {{
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(img, 0, 0);
if (hasROI) {{
var rx = Math.min(x0, x1);
var ry = Math.min(y0, y1);
var rw = Math.abs(x1 - x0);
var rh = Math.abs(y1 - y0);
ctx.strokeStyle = 'red';
ctx.lineWidth = 3;
ctx.strokeRect(rx, ry, rw, rh);
document.getElementById('roi-info').innerText =
'ROI: x=' + Math.round(rx) + ', y=' + Math.round(ry) +
', w=' + Math.round(rw) + ', h=' + Math.round(rh);
}}
}}
canvas.addEventListener('mousedown', function(e) {{
var rect = canvas.getBoundingClientRect();
var scale = getScale();
x0 = (e.clientX - rect.left) * scale.sx;
y0 = (e.clientY - rect.top) * scale.sy;
drawing = true;
hasROI = false;
}});
canvas.addEventListener('mousemove', function(e) {{
if (!drawing) return;
var rect = canvas.getBoundingClientRect();
var scale = getScale();
x1 = (e.clientX - rect.left) * scale.sx;
y1 = (e.clientY - rect.top) * scale.sy;
hasROI = true;
redraw();
}});
canvas.addEventListener('mouseup', function(e) {{
drawing = false;
}});
document.getElementById('roi-confirm-btn').addEventListener('click', function() {{
if (!hasROI) {{
document.getElementById('roi-info').innerText =
'No ROI drawn yet! Click and drag on the image first.';
return;
}}
var rx = Math.round(Math.min(x0, x1));
var ry = Math.round(Math.min(y0, y1));
var rw = Math.round(Math.abs(x1 - x0));
var rh = Math.round(Math.abs(y1 - y0));
// Send ROI back to Python via Colab's kernel eval
google.colab.kernel.invokeFunction('set_roi', [rx, ry, rw, rh], {{}});
document.getElementById('roi-info').innerText =
'ROI CONFIRMED: x=' + rx + ', y=' + ry + ', w=' + rw + ', h=' + rh;
document.getElementById('roi-confirm-btn').style.background = '#888';
document.getElementById('roi-confirm-btn').innerText = 'ROI Confirmed!';
}});
}})();
</script>
"""
# Register a callback so JavaScript can send the ROI to Python
roi_result = {'value': None}  # filled in by set_roi() when the user confirms
def set_roi(x, y, w, h):
    # Store the confirmed ROI (x, y, w, h in image pixels) for later cells.
    roi_result['value'] = (int(x), int(y), int(w), int(h))
    print(f'ROI received: x={int(x)}, y={int(y)}, w={int(w)}, h={int(h)}')
output.register_callback('set_roi', set_roi)
display(HTML(html_code))
# Run this cell AFTER you have confirmed your ROI above.
# Pull the ROI recorded by the set_roi callback; fail loudly if the user
# never clicked "Confirm ROI".
roi = roi_result['value']
if roi is None:
    raise ValueError("No ROI selected! Go back and draw a rectangle on the image, then click 'Confirm ROI'.")
rx, ry, rw, rh = roi
print(f'Using ROI: x={rx}, y={ry}, width={rw}, height={rh}')
Track Fly in Video(s)#
This processes each video frame-by-frame:
Creates a background image from 100 random frames
Subtracts the background from each frame
Finds the fly position using center-of-mass of dark pixels
Applies teleportation filter and interpolation
Converts pixel positions to centimeters
#@title Fly tracking engine — process_video() { display-mode: "form" }
def process_video(video_path, roi, diameter, frame_rate, search_size, per_pixel_threshold):
    """
    Track a fly through a single video.

    Builds a mean background image from up to 100 random frames, divides
    each frame by the background to flatten the lighting, crops to the ROI,
    and locates the fly with fly_finder(). Pixel positions are scaled to
    centimeters using the dish diameter, then cleaned with the boundary,
    teleportation, and interpolation filters.

    Parameters
    ----------
    video_path : path to the video file.
    roi : (x, y, w, h) rectangle in pixels enclosing the dish.
    diameter : dish diameter in cm (the ROI is assumed to span the dish).
    frame_rate : frames per second; None to auto-detect from the file.
    search_size : side length (pixels) of the fly search window.
    per_pixel_threshold : per-pixel intensity needed to accept a detection.

    Returns
    -------
    (corrected_array, frame_rate) where corrected_array is an Nx3 array of
    [time_s, x_cm, y_cm] rows.

    Raises
    ------
    ValueError if the video yields no readable frames.
    """
    height = diameter
    width = diameter
    roi_x, roi_y, roi_w, roi_h = roi
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    nfrm = total_frames - 1
    # Auto-detect frame rate from the container metadata if not specified
    if frame_rate is None:
        if fps > 0:
            frame_rate = round(fps)
            print(f"\nProcessing: {os.path.basename(video_path)}")
            print(f" Auto-detected frame rate: {frame_rate} fps ({total_frames} frames)")
        else:
            frame_rate = 30
            print(f"\nProcessing: {os.path.basename(video_path)}")
            print(f" WARNING: Could not detect frame rate, using default: {frame_rate} fps")
            print(f" Re-run with --frame-rate to set the correct value.")
    else:
        print(f"\nProcessing: {os.path.basename(video_path)}")
        print(f" Video FPS: {fps}, using override: {frame_rate} fps ({total_frames} frames)")
    # Sanity check: warn if video duration seems unusual
    if frame_rate > 0:
        duration = total_frames / frame_rate
        print(f" Estimated duration: {duration:.1f} seconds")
        if duration < 50 or duration > 70:
            print(f" WARNING: Expected ~60s video but got {duration:.1f}s.")
            print(f" If this looks wrong, re-run with --frame-rate to override.")
    # --- Create background from up to 100 random frames ---
    print(" Calculating background...")
    bg_number = min(100, nfrm)
    bg_indices = sorted(np.random.choice(nfrm, bg_number, replace=False))
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    ret, sample = cap.read()
    if not ret:
        # Guard against an unreadable/empty video before touching the frame.
        cap.release()
        raise ValueError(f'Could not read frames from {video_path}')
    gray_sample = cv2.cvtColor(sample, cv2.COLOR_BGR2GRAY)
    bg_array = np.zeros((*gray_sample.shape, bg_number), dtype=np.uint8)
    for idx, frame_num in enumerate(bg_indices):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
        if ret:
            bg_array[:, :, idx] = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    background = np.mean(bg_array, axis=2).astype(np.uint8)
    # --- Process each frame ---
    print(" Tracking fly positions...")
    # Total-intensity threshold for the whole search window.
    threshold = (search_size ** 2) * per_pixel_threshold
    half_search = round(search_size / 2)
    pos_array = np.zeros((nfrm, 3))
    for nofr in range(nfrm):
        cap.set(cv2.CAP_PROP_POS_FRAMES, nofr)
        ret, frame = cap.read()
        if not ret:
            pos_array[nofr] = [nofr, np.nan, np.nan]
            continue
        frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float64)
        # Background subtraction using the GIMP division formula:
        # flattens uneven illumination; the +1 avoids division by zero.
        frame_div = np.clip(
            (256.0 * frame_gray) / (background.astype(np.float64) + 1), 0, 255
        ).astype(np.uint8)
        # Crop to ROI
        frame_crop = frame_div[roi_y:roi_y + roi_h, roi_x:roi_x + roi_w]
        # Find fly in the cropped, normalized frame
        fx, fy = fly_finder(frame_crop, half_search, threshold, flip=True)
        pos_array[nofr] = [nofr, fx, fy]
        # Progress update every ~10%
        if (nofr + 1) % max(1, nfrm // 10) == 0:
            pct = (nofr + 1) / nfrm * 100
            print(f" {pct:.0f}% complete ({nofr + 1}/{nfrm} frames)")
    cap.release()
    # --- Convert frame indices/pixels to seconds/centimeters ---
    xscale = width / roi_w
    yscale = height / roi_h
    corrected_array = np.column_stack([
        pos_array[:, 0] / frame_rate,
        pos_array[:, 1] * xscale,
        pos_array[:, 2] * yscale,
    ])
    skipped = np.sum(np.isnan(corrected_array[:, 1]))
    print(f" {skipped} points skipped out of {nfrm}.")
    # Apply filters and interpolation.
    # BUG FIX: the original also called displacement_filter() here, but no
    # such function is defined anywhere in this notebook, so the call raised
    # NameError at runtime; it has been removed.
    corrected_array = boundary_filter(corrected_array, diameter)
    corrected_array = dist_filter(corrected_array, 2)
    corrected_array = interpolate_pos(corrected_array, 2)
    return corrected_array, frame_rate
# --- Process the current video ---
# BUG FIX: process_video() returns a (positions, frame_rate) tuple. The
# original appended the whole tuple to all_corrected, which broke the
# plotting and velocity cells (they index entries as corrected[:, 1]).
# Unpack the tuple so all_corrected holds plain Nx3 arrays, and keep the
# (possibly auto-detected) frame rate for the velocity calculation.
corrected, frame_rate = process_video(current_video, roi, diameter, frame_rate,
                                      search_size, per_pixel_threshold)
all_corrected.append(corrected)
print(f'\nDone! Total videos processed: {len(all_corrected)}')
print('To process another video, go back to "Select Video File" cell and change `current_video`, then run this cell again.')
Prepare for Another Video (Optional)#
If you want to process another video with the same ROI:
Go back to the Select Video File cell and update current_video to the new file path, then re-run that cell.
Skip the ROI selection to reuse the same ROI, or re-run it to draw a new one.
Run the Track Fly cell again — results automatically append to all_corrected.
Or, to apply a different ROI for the next video, re-run the ROI selection cells before tracking.
Plot Fly Path(s)#
Color-coded by time (light blue = start, dark purple = end).
# Plot each tracked path, colour-coded by elapsed time.
for video_num, path in enumerate(all_corrected, start=1):
    times = path[:, 0]
    xs = path[:, 1]
    ys = path[:, 2]
    fig, ax = plt.subplots(figsize=(6, 6))
    # Keep only frames where the fly was actually located.
    ok = ~np.isnan(xs) & ~np.isnan(ys)
    pts = np.array([xs[ok], ys[ok]]).T.reshape(-1, 1, 2)
    segs = np.concatenate([pts[:-1], pts[1:]], axis=1)
    path_lines = LineCollection(segs, cmap='BuPu', linewidth=2)
    path_lines.set_array(times[ok][:-1])
    ax.add_collection(path_lines)
    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)  # Invert y to match video orientation
    ax.set_aspect('equal')
    ax.set_xlabel('X-coordinate (cm)', fontsize=11)
    ax.set_ylabel('Y-coordinate (cm)', fontsize=11)
    ax.set_title(f'Fly Path — Video {video_num}')
    cbar = fig.colorbar(path_lines, ax=ax, orientation='horizontal', pad=0.1)
    cbar.set_label('Time (s)')
    plt.tight_layout()
    plt.show()
Calculate & Plot Velocity#
# Calculate velocity for videos that haven't been processed yet
start_idx = len(all_velocity)  # Only calculate for newly added videos
for idx in range(start_idx, len(all_corrected)):
    corrected = all_corrected[idx]
    x = corrected[:, 1]
    y = corrected[:, 2]
    total_time = len(x) / frame_rate
    total_bins = int(np.floor(total_time / bin_size))
    # Calculate velocity per bin.
    # corrected[1, 0] is the timestamp of frame 1 (= 1/frame_rate seconds,
    # see process_video), so round(1/corrected[1, 0]) recovers the frame
    # rate this video was tracked with, even if the global frame_rate was
    # changed afterwards; the fallback uses the global value.
    data_rate = round(1.0 / corrected[1, 0]) * bin_size if corrected[1, 0] > 0 else frame_rate * bin_size
    data_rate = int(data_rate)
    if data_rate < 1:
        raise ValueError('bin_size is smaller than the minimum data rate.')
    velocity = np.zeros(total_bins)
    # Each bin's velocity is the straight-line displacement between the
    # bin's first and last frame — not the total path length within it.
    for row in range(0, len(corrected) - data_rate, data_rate):
        bin_idx = row // data_rate
        if bin_idx >= total_bins:
            break
        p1 = corrected[row, 1:3]
        p2 = corrected[row + data_rate, 1:3]
        if np.any(np.isnan(p1)) or np.any(np.isnan(p2)):
            velocity[bin_idx] = np.nan  # fly position unknown at a bin edge
        else:
            # 10x converts cm to mm
            velocity[bin_idx] = 10.0 * euclidean(p1, p2)
    # Convert from mm/bin to mm/s
    velocity = velocity / bin_size
    all_velocity.append(velocity)
    # Warn about absurd velocities
    if np.nanmax(velocity) > 30:
        print(f'WARNING (Video {idx + 1}): Absurdly high velocities detected.')
        print(' Consider changing the ROI or re-recording the video.')
    mean_vel = np.nanmean(velocity)
    std_vel = np.nanstd(velocity)
    # NOTE(review): current_video names only the most recently processed
    # file; for earlier entries in all_corrected this label may be wrong.
    print(f'\n--- Video {idx + 1} (from: {os.path.basename(current_video)}) ---')
    print(f' Mean velocity: {mean_vel:.2f} mm/s')
    print(f' Std deviation: {std_vel:.2f} mm/s')
    # Plot velocity over time
    time_axis = np.arange(total_bins) * bin_size
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(time_axis, velocity, linewidth=1.5)
    ax.set_xlim(0, time_axis[-1] + bin_size if len(time_axis) > 0 else 1)
    ax.set_ylim(0, np.nanmax(velocity) * 1.5 if np.nanmax(velocity) > 0 else 1)
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title(f'Fly Velocity — Video {idx + 1}')
    plt.tight_layout()
    plt.show()
print(f'\nVelocity analysis complete. Total videos analyzed: {len(all_velocity)}')
Summary Across All Videos#
num_files = len(all_velocity)
if num_files > 1:
    # Pad all per-video velocity traces to a common length with NaN so
    # they fit in one matrix for overlay plotting.
    longest = max(len(v) for v in all_velocity)
    velocity_matrix = np.full((num_files, longest), np.nan)
    for row, trace in enumerate(all_velocity):
        velocity_matrix[row, :len(trace)] = trace
    # Overlay every fly's velocity trace on a single set of axes.
    fig, ax = plt.subplots(figsize=(10, 5))
    time_axis = np.arange(longest) * bin_size
    for row in range(num_files):
        ax.plot(time_axis, velocity_matrix[row], linewidth=2,
                label=f'Fly {row + 1}')
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title('All Fly Velocities')
    ax.legend()
    plt.tight_layout()
    plt.show()
    # Summary statistics: mean and SD of the per-video mean velocities.
    per_video_means = [np.nanmean(v) for v in all_velocity]
    mean_across = np.mean(per_video_means)
    sd_across = np.std(per_video_means)
    print(f'\n=== Summary Across {num_files} Videos ===')
    print(f'Mean velocity across videos: {mean_across:.2f} mm/s')
    print(f'SD of mean velocity across videos: {sd_across:.2f} mm/s')
else:
    # Single video: report its mean and SD directly.
    mean_vel = np.nanmean(all_velocity[0])
    sd_vel = np.nanstd(all_velocity[0])
    print(f'\n=== Summary (1 Video) ===')
    print(f'Mean velocity: {mean_vel:.2f} mm/s')
    print(f'SD of velocity: {sd_vel:.2f} mm/s')
Download Results (Optional)#
Download the tracking data as CSV files.
# Download the tracking and velocity data as CSV files.
# BUG FIX: the original indexed an undefined `video_files` list, raising
# NameError. This notebook only tracks `current_video` (one path at a time),
# so output files are named by video index instead.
for idx, corrected in enumerate(all_corrected):
    csv_name = f'video{idx + 1}_tracking.csv'
    np.savetxt(csv_name, corrected, delimiter=',',
               header='Time_s,X_cm,Y_cm', comments='')
    files.download(csv_name)
    print(f'Downloaded {csv_name}')
# Also save velocity data (one row per time bin)
for idx, vel in enumerate(all_velocity):
    csv_name = f'video{idx + 1}_velocity.csv'
    time_axis = np.arange(len(vel)) * bin_size
    vel_data = np.column_stack([time_axis, vel])
    np.savetxt(csv_name, vel_data, delimiter=',',
               header='Time_s,Velocity_mm_per_s', comments='')
    files.download(csv_name)
    print(f'Downloaded {csv_name}')