Fly Tracker (Colab)#

This notebook tracks a fruit fly against a light background in a video, calculates its position and velocity over time, and produces path and velocity plots.

Original MATLAB code by Jeff Stafford, modified by A. Juavinett for BIPN 145.

How to use#

  1. Run the Setup cell to import the required packages (they are preinstalled on Colab; if an import fails, install the missing package with `!pip install`).

  2. Upload your video (.avi, .mp4, etc.) using the upload cell.

  3. Set your parameters (dish diameter, frame rate).

  4. Draw your ROI on the first frame.

  5. Run the remaining cells to track the fly and view results.

Setup#

#@title Install & import packages { display-mode: "form" }
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from scipy.spatial.distance import euclidean
from google.colab import files, output
from IPython.display import display, clear_output, HTML, Image as IPImage
import ipywidgets as widgets
import os
import base64

# Initialize result containers once. Re-running this cell must NOT wipe
# results accumulated from earlier videos, so check the notebook's module
# namespace explicitly with globals() (clearer and more robust than the
# previous dir() check, which depends on the enclosing scope).
if 'all_corrected' not in globals():
    all_corrected = []   # one Nx3 [time_s, x_cm, y_cm] array per processed video
    all_velocity = []    # one velocity-per-bin array per analyzed video
    print('Initialized result containers.')
else:
    print(f'Existing results: {len(all_corrected)} video(s) processed so far.')

print('All packages loaded successfully!')
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 3
      1 #@title Install & import packages { display-mode: "form" }
      2 import numpy as np
----> 3 import cv2
      4 import matplotlib.pyplot as plt
      5 import matplotlib.colors as mcolors

ModuleNotFoundError: No module named 'cv2'

(Note: the traceback above is stale output from a previous run. If you see this error, OpenCV is missing from the environment — run `!pip install opencv-python` in a new cell, then re-run the setup cell.)

Mount your drive#

Run the cell below to mount your drive, which should contain your video files.

# Mount Google Drive at /content/drive so video files stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

Select Video File#

Update the path below to your video file. You’ll process one video at a time, and results will accumulate in all_corrected and all_velocity.

Tip: In the window at left, you can use the three vertical dots to copy and paste the path to your video file.

# Set the current video file to process.
# Edit this path between runs to analyze additional videos; results
# accumulate in all_corrected / all_velocity across runs.
current_video = '/content/drive/MyDrive/video1.avi'
print(f'Current video: {current_video}')

Set Parameters#

Modify the frame_rate below as needed.

# Frame rate of the video (frames per second; must match the recording —
# it scales both the time axis and the velocity bins)
frame_rate = 30 

diameter = 4  # Diameter of the dish in centimeters
search_size = 20 # Search size for fly detection (in pixels)
per_pixel_threshold = 1.5  # min mean inverted intensity per pixel for a detection to count
bin_size = 1 # Bin size for velocity calculation (in seconds)

# World coordinates are square: both axes span one dish diameter (cm).
height = diameter
width = diameter

print(f'Frame rate: {frame_rate} fps')

Helper Functions#

Just run the cell below to define a set of functions to run our analysis.

#@title Define helper functions (fly_finder, dist_filter, interpolate_pos) { display-mode: "form" }

def fly_finder(roi_image, half_search, threshold, flip=True):
    """
    Locate a fly (dark region) in a grayscale image.

    Seeds the search at the extreme-intensity pixel(s), extracts a
    square window around that seed, and returns the intensity-weighted
    centroid of the window.

    Returns (x, y) position or (NaN, NaN) if not found.
    """
    # Seed at the darkest pixel(s) when flip=True, brightest otherwise.
    extreme = np.nanmin(roi_image) if flip else np.nanmax(roi_image)
    rows, cols = np.where(roi_image == extreme)
    seed_x = np.mean(cols)
    seed_y = np.mean(rows)

    n_rows, n_cols = roi_image.shape
    left = max(int(round(seed_x) - half_search), 0)
    right = min(int(round(seed_x) + half_search), n_cols - 1)
    top = max(int(round(seed_y) - half_search), 0)
    bottom = min(int(round(seed_y) + half_search), n_rows - 1)

    window = roi_image[top:bottom + 1, left:right + 1].astype(np.float64)
    if flip:
        # Invert so the (dark) fly carries the weight.
        window = 255.0 - window

    mass = np.sum(window)
    if not mass >= threshold:
        # Too little signal in the window: no detection.
        return np.nan, np.nan

    # Center of mass within the window, mapped back to ROI coordinates.
    col_idx = np.arange(window.shape[1])
    row_idx = np.arange(window.shape[0])
    x = np.sum(window @ col_idx) / mass + left
    y = np.sum(window.T @ row_idx) / mass + top
    return x, y


def dist_filter(array, tele_dist_threshold, num_avg=5):
    """
    Teleport filter: drop spurious points whose position jumps far from
    the mean of the surrounding frames.

    array: Nx3 array [time, x, y]
    """
    out = array.copy()
    removed = 0
    n = len(out)

    # NOTE: out is mutated in place as we scan, so later windows see
    # earlier removals — this matches the sequential behavior required.
    for i in range(num_avg, n - num_avg):
        pt = out[i, 1:3]
        if np.isnan(pt).any():
            continue

        prev_pts = out[i - num_avg:i, 1:3]
        prev_pts = prev_pts[~np.isnan(prev_pts[:, 0])]
        if len(prev_pts) == 0:
            continue

        next_pts = out[i + 1:i + 1 + num_avg, 1:3]
        next_pts = next_pts[~np.isnan(next_pts[:, 0])]
        if len(next_pts) == 0:
            continue

        far_from_prev = euclidean(pt, prev_pts.mean(axis=0)) > tele_dist_threshold
        far_from_next = euclidean(pt, next_pts.mean(axis=0)) > tele_dist_threshold
        if far_from_prev or far_from_next:
            out[i, 1:3] = np.nan
            removed += 1

    # First and last few frames get a stricter consecutive-pair check.
    edge_indices = list(range(0, min(5, n - 1)))
    edge_indices += list(range(max(0, n - 6), n - 1))
    for idx in edge_indices:
        a = out[idx, 1:3]
        b = out[idx + 1, 1:3]
        if np.isnan(a).any() or np.isnan(b).any():
            continue
        if euclidean(a, b) > tele_dist_threshold / 2:
            out[idx, 1:3] = np.nan
            removed += 1

    print(f'{removed} points removed by the teleportation filter.')
    return out


def interpolate_pos(array, inter_dist_threshold):
    """
    Fill NaN gaps in the track by linear interpolation, provided the
    gap's endpoints are no farther apart than inter_dist_threshold.

    array: Nx3 array [time, x, y]
    """
    out = array.copy()
    recovered = 0

    for cx, cy in [(1, 2)]:  # (x, y) column pair
        n = len(out)
        i = 0
        while i < n:
            if not np.isnan(out[i, cx]):
                i += 1
                continue
            if i == 0:
                # Leading NaNs: skip ahead to the first valid sample.
                valid = np.where(~np.isnan(out[:, cx]))[0]
                if len(valid) == 0:
                    break
                i = valid[0]
                continue
            # A gap starts at i; its left endpoint is the previous row.
            start_idx = i - 1
            start_pt = out[start_idx, cx:cy + 1]
            later_valid = np.where(~np.isnan(out[i:, cx]))[0]
            if len(later_valid) == 0:
                break  # trailing NaNs: nothing to anchor the right end
            end_idx = later_valid[0] + i
            end_pt = out[end_idx, cx:cy + 1]
            gap = end_idx - i

            if euclidean(start_pt, end_pt) <= inter_dist_threshold:
                # gap + 1 segments span start_idx .. end_idx.
                for step in range(1, gap + 1):
                    frac = step / (gap + 1)
                    out[start_idx + step, cx:cy + 1] = (
                        start_pt + (end_pt - start_pt) * frac
                    )
                recovered += gap

            i = end_idx

    print(f'{recovered} points recovered through interpolation.')
    return out

def boundary_filter(array, diameter):
    """
    Blank out points lying outside the circular dish.

    The dish is modeled as a circle of the given diameter, centered at
    (diameter/2, diameter/2) in world (cm) coordinates.
    """
    out = array.copy()
    center = diameter / 2
    radius = diameter / 2

    dx = out[:, 1] - center
    dy = out[:, 2] - center
    beyond = np.sqrt(dx ** 2 + dy ** 2) > radius
    # Points that were already NaN are not counted as removals.
    beyond = beyond & ~np.isnan(out[:, 1])
    out[beyond, 1:3] = np.nan

    n_removed = np.sum(beyond)
    if n_removed > 0:
        print(f"  {n_removed} points removed by the boundary filter (outside dish).")
    return out

# Confirmation message for the notebook user.
print('Helper functions defined.')

Select ROI (Region of Interest)#

Run the cell below to display the first frame of your video. Click and drag directly on the image to draw a red rectangle around the dish. You can redraw as many times as you like. When you’re happy with the ROI, click Confirm ROI.

#@title Interactive ROI selector — click and drag on the image { display-mode: "form" }

# Read the first frame of the current video; it is shown in the browser
# so the user can draw the ROI on it.
cap = cv2.VideoCapture(current_video)
ret, frame = cap.read()
cap.release()
if not ret:
    raise ValueError(f'Could not read first frame of {current_video}')

img_h, img_w = frame.shape[:2]

# Encode the frame as base64 PNG for the HTML canvas. cv2.imencode
# expects BGR — exactly what cap.read() returned — so encode `frame`
# directly instead of round-tripping BGR -> RGB -> BGR as before.
_, buf = cv2.imencode('.png', frame)
img_b64 = base64.b64encode(buf).decode('utf-8')

print(f'Image size: {img_w} x {img_h} pixels')
print('Click and drag on the image below to select the ROI, then click "Confirm ROI".\n')

# JavaScript/HTML interactive ROI selector.
# The canvas displays the first frame; a mouse drag draws a red rectangle,
# and "Confirm ROI" sends (x, y, w, h) back to Python via
# google.colab.kernel.invokeFunction('set_roi', ...). Doubled braces
# {{ }} are literal braces escaped inside this f-string.
html_code = f"""
<div id="roi-container" style="position:relative; display:inline-block; cursor:crosshair;">
  <canvas id="roi-canvas" width="{img_w}" height="{img_h}"
    style="max-width:100%; height:auto; border:1px solid #ccc;"></canvas>
</div>
<br>
<button id="roi-confirm-btn" style="font-size:16px; padding:8px 20px; background:#4CAF50;
  color:white; border:none; border-radius:4px; cursor:pointer; margin-top:8px;">
  Confirm ROI
</button>
<p id="roi-info" style="font-size:14px; color:#333;"></p>

<script>
(function() {{
  var canvas = document.getElementById('roi-canvas');
  var ctx = canvas.getContext('2d');
  var img = new window.Image();
  img.src = 'data:image/png;base64,{img_b64}';

  var drawing = false;
  var x0 = 0, y0 = 0, x1 = 0, y1 = 0;
  var hasROI = false;

  // Scale factor: canvas CSS size vs actual pixel size
  function getScale() {{
    var rect = canvas.getBoundingClientRect();
    return {{ sx: canvas.width / rect.width, sy: canvas.height / rect.height }};
  }}

  img.onload = function() {{
    ctx.drawImage(img, 0, 0);
  }};

  function redraw() {{
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    ctx.drawImage(img, 0, 0);
    if (hasROI) {{
      var rx = Math.min(x0, x1);
      var ry = Math.min(y0, y1);
      var rw = Math.abs(x1 - x0);
      var rh = Math.abs(y1 - y0);
      ctx.strokeStyle = 'red';
      ctx.lineWidth = 3;
      ctx.strokeRect(rx, ry, rw, rh);
      document.getElementById('roi-info').innerText =
        'ROI: x=' + Math.round(rx) + ', y=' + Math.round(ry) +
        ', w=' + Math.round(rw) + ', h=' + Math.round(rh);
    }}
  }}

  canvas.addEventListener('mousedown', function(e) {{
    var rect = canvas.getBoundingClientRect();
    var scale = getScale();
    x0 = (e.clientX - rect.left) * scale.sx;
    y0 = (e.clientY - rect.top) * scale.sy;
    drawing = true;
    hasROI = false;
  }});

  canvas.addEventListener('mousemove', function(e) {{
    if (!drawing) return;
    var rect = canvas.getBoundingClientRect();
    var scale = getScale();
    x1 = (e.clientX - rect.left) * scale.sx;
    y1 = (e.clientY - rect.top) * scale.sy;
    hasROI = true;
    redraw();
  }});

  canvas.addEventListener('mouseup', function(e) {{
    drawing = false;
  }});

  document.getElementById('roi-confirm-btn').addEventListener('click', function() {{
    if (!hasROI) {{
      document.getElementById('roi-info').innerText =
        'No ROI drawn yet! Click and drag on the image first.';
      return;
    }}
    var rx = Math.round(Math.min(x0, x1));
    var ry = Math.round(Math.min(y0, y1));
    var rw = Math.round(Math.abs(x1 - x0));
    var rh = Math.round(Math.abs(y1 - y0));

    // Send ROI back to Python via Colab's kernel eval
    google.colab.kernel.invokeFunction('set_roi', [rx, ry, rw, rh], {{}});
    document.getElementById('roi-info').innerText =
      'ROI CONFIRMED: x=' + rx + ', y=' + ry + ', w=' + rw + ', h=' + rh;
    document.getElementById('roi-confirm-btn').style.background = '#888';
    document.getElementById('roi-confirm-btn').innerText = 'ROI Confirmed!';
  }});
}})();
</script>
"""

# Callback plumbing: the JavaScript widget above invokes 'set_roi' with
# the confirmed rectangle, which we stash for the next cell to read.
roi_result = {'value': None}

def set_roi(x, y, w, h):
    """Record the ROI from the JS widget as an (x, y, w, h) int tuple."""
    box = (int(x), int(y), int(w), int(h))
    roi_result['value'] = box
    print(f'ROI received: x={box[0]}, y={box[1]}, w={box[2]}, h={box[3]}')

output.register_callback('set_roi', set_roi)

display(HTML(html_code))
# Run this cell AFTER you have confirmed your ROI above
roi = roi_result['value']  # (x, y, w, h) tuple filled in by the set_roi callback

if roi is None:
    raise ValueError("No ROI selected! Go back and draw a rectangle on the image, then click 'Confirm ROI'.")

print(f'Using ROI: x={roi[0]}, y={roi[1]}, width={roi[2]}, height={roi[3]}')

Track Fly in Video(s)#

This processes each video frame-by-frame:

  1. Creates a background image from 100 random frames

  2. Subtracts the background from each frame

  3. Finds the fly position using center-of-mass of dark pixels

  4. Applies teleportation filter and interpolation

  5. Converts pixel positions to centimeters

#@title Fly tracking engine — process_video() { display-mode: "form" }

def process_video(video_path, roi, diameter, frame_rate, search_size, per_pixel_threshold):
    """
    Process a single fly video and return the corrected position array.

    Steps: build a mean background from up to 100 random frames, divide
    each frame by the background, crop to the ROI, locate the fly via
    intensity centroid, then convert to cm/seconds and filter the track.

    Parameters
    ----------
    video_path : path to the video file.
    roi : (x, y, w, h) tuple in pixels.
    diameter : dish diameter in cm (sets the coordinate scale).
    frame_rate : fps override, or None to auto-detect from the file.
    search_size : side length of the detection window (pixels).
    per_pixel_threshold : minimum mean inverted intensity per pixel for
        a detection to count.

    Returns (Nx3 array [time_s, x_cm, y_cm], frame_rate).
    """
    height = diameter
    width = diameter
    roi_x, roi_y, roi_w, roi_h = roi

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    nfrm = total_frames - 1  # skip the final frame, which is often unreadable

    # Auto-detect frame rate from the file when the caller passes None.
    if frame_rate is None:
        if fps > 0:
            frame_rate = round(fps)
            print(f"\nProcessing: {os.path.basename(video_path)}")
            print(f"  Auto-detected frame rate: {frame_rate} fps ({total_frames} frames)")
        else:
            frame_rate = 30
            print(f"\nProcessing: {os.path.basename(video_path)}")
            print(f"  WARNING: Could not detect frame rate, using default: {frame_rate} fps")
            print(f"    Re-run with --frame-rate to set the correct value.")
    else:
        print(f"\nProcessing: {os.path.basename(video_path)}")
        print(f"  Video FPS: {fps}, using override: {frame_rate} fps ({total_frames} frames)")

    # Sanity check: warn if video duration seems unusual (~60 s expected).
    if frame_rate > 0:
        duration = total_frames / frame_rate
        print(f"  Estimated duration: {duration:.1f} seconds")
        if duration < 50 or duration > 70:
            print(f"  WARNING: Expected ~60s video but got {duration:.1f}s.")
            print(f"    If this looks wrong, re-run with --frame-rate to override.")

    # --- Create background from up to 100 random frames ---
    print("  Calculating background...")
    bg_number = min(100, nfrm)
    bg_indices = sorted(np.random.choice(nfrm, bg_number, replace=False))

    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    ret, sample = cap.read()
    if not ret:
        # FIX: the first read was previously unchecked, producing a cryptic
        # cvtColor error on an unreadable file; fail with a clear message.
        cap.release()
        raise ValueError(f'Could not read first frame of {video_path}')
    gray_sample = cv2.cvtColor(sample, cv2.COLOR_BGR2GRAY)
    bg_array = np.zeros((*gray_sample.shape, bg_number), dtype=np.uint8)

    for idx, frame_num in enumerate(bg_indices):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
        if ret:
            bg_array[:, :, idx] = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    background = np.mean(bg_array, axis=2).astype(np.uint8)

    # --- Process each frame ---
    print("  Tracking fly positions...")
    # Detection threshold: total inverted intensity required in the window.
    threshold = (search_size ** 2) * per_pixel_threshold
    half_search = round(search_size / 2)

    pos_array = np.zeros((nfrm, 3))  # columns: frame index, x_px, y_px

    for nofr in range(nfrm):
        cap.set(cv2.CAP_PROP_POS_FRAMES, nofr)
        ret, frame = cap.read()
        if not ret:
            pos_array[nofr] = [nofr, np.nan, np.nan]
            continue

        frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float64)

        # Background subtraction using GIMP division formula
        frame_div = np.clip(
            (256.0 * frame_gray) / (background.astype(np.float64) + 1), 0, 255
        ).astype(np.uint8)

        # Crop to ROI
        frame_crop = frame_div[roi_y:roi_y + roi_h, roi_x:roi_x + roi_w]

        # Find fly
        fx, fy = fly_finder(frame_crop, half_search, threshold, flip=True)
        pos_array[nofr] = [nofr, fx, fy]

        # Progress update every 10%
        if (nofr + 1) % max(1, nfrm // 10) == 0:
            pct = (nofr + 1) / nfrm * 100
            print(f"    {pct:.0f}% complete ({nofr + 1}/{nfrm} frames)")

    cap.release()

    # --- Convert to real coordinates (pixels -> cm, frame index -> seconds) ---
    xscale = width / roi_w
    yscale = height / roi_h

    corrected_array = np.column_stack([
        pos_array[:, 0] / frame_rate,
        pos_array[:, 1] * xscale,
        pos_array[:, 2] * yscale,
    ])

    skipped = np.sum(np.isnan(corrected_array[:, 1]))
    print(f"  {skipped} points skipped out of {nfrm}.")

    # Apply filters and interpolation.
    # FIX: a call to displacement_filter() was removed here — that function
    # is not defined anywhere in this notebook and raised NameError on
    # every run.
    corrected_array = boundary_filter(corrected_array, diameter)
    corrected_array = dist_filter(corrected_array, 2)
    corrected_array = interpolate_pos(corrected_array, 2)

    return corrected_array, frame_rate

# --- Process the current video ---
# FIX: process_video returns a (positions, frame_rate) tuple; the previous
# version appended the whole tuple to all_corrected, which broke every
# downstream cell that indexes entries as corrected[:, 1]. Unpack it.
corrected, frame_rate = process_video(current_video, roi, diameter, frame_rate,
                                      search_size, per_pixel_threshold)
all_corrected.append(corrected)

print(f'\nDone! Total videos processed: {len(all_corrected)}')
print('To process another video, go back to "Select Video File" cell and change `current_video`, then run this cell again.')

Prepare for Another Video (Optional)#

If you want to process another video with the same ROI:

  1. Go back to the Select Video File cell and update current_video to the new file path

  2. Re-run the video selection cell

  3. Skip the ROI selection cells to reuse the existing `roi`, or re-run them to draw a new one

  4. Run the Track Fly cell again — results automatically append to all_corrected

Or, to apply a different ROI for the next video, re-run the ROI selection cells before tracking.

Plot Fly Path(s)#

Color-coded by time (light blue = start, dark purple = end).

for idx, corrected in enumerate(all_corrected):
    # FIX: earlier runs may have stored the raw (positions, frame_rate)
    # tuple returned by process_video; unwrap it so indexing works.
    if isinstance(corrected, tuple):
        corrected = corrected[0]

    x = corrected[:, 1]
    y = corrected[:, 2]
    t = corrected[:, 0]

    fig, ax = plt.subplots(figsize=(6, 6))

    # Build segments between consecutive valid points, colored by time.
    valid = ~np.isnan(x) & ~np.isnan(y)
    points = np.array([x[valid], y[valid]]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    lc = LineCollection(segments, cmap='BuPu', linewidth=2)
    lc.set_array(t[valid][:-1])  # each segment takes its start time's color
    ax.add_collection(lc)

    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)  # Invert y to match video orientation
    ax.set_aspect('equal')
    ax.set_xlabel('X-coordinate (cm)', fontsize=11)
    ax.set_ylabel('Y-coordinate (cm)', fontsize=11)
    ax.set_title(f'Fly Path — Video {idx + 1}')

    cbar = fig.colorbar(lc, ax=ax, orientation='horizontal', pad=0.1)
    cbar.set_label('Time (s)')

    plt.tight_layout()
    plt.show()

Calculate & Plot Velocity#

# Calculate velocity for videos that haven't been processed yet
start_idx = len(all_velocity)  # Only calculate for newly added videos

for idx in range(start_idx, len(all_corrected)):
    corrected = all_corrected[idx]
    # FIX: unwrap (positions, frame_rate) tuples in case process_video's
    # return value was appended without unpacking.
    if isinstance(corrected, tuple):
        corrected = corrected[0]

    x = corrected[:, 1]
    y = corrected[:, 2]

    total_time = len(x) / frame_rate
    total_bins = int(np.floor(total_time / bin_size))

    # Frames per velocity bin. corrected[1, 0] is the timestamp of frame 1,
    # i.e. 1/frame_rate, so its reciprocal recovers the per-video frame rate.
    data_rate = round(1.0 / corrected[1, 0]) * bin_size if corrected[1, 0] > 0 else frame_rate * bin_size
    data_rate = int(data_rate)

    if data_rate < 1:
        raise ValueError('bin_size is smaller than the minimum data rate.')

    velocity = np.zeros(total_bins)
    for row in range(0, len(corrected) - data_rate, data_rate):
        bin_idx = row // data_rate
        if bin_idx >= total_bins:
            break
        p1 = corrected[row, 1:3]
        p2 = corrected[row + data_rate, 1:3]
        if np.any(np.isnan(p1)) or np.any(np.isnan(p2)):
            velocity[bin_idx] = np.nan
        else:
            # 10x converts cm to mm
            velocity[bin_idx] = 10.0 * euclidean(p1, p2)

    # Convert from mm/bin to mm/s
    velocity = velocity / bin_size

    all_velocity.append(velocity)

    # Warn about absurd velocities. FIX: guard against empty or all-NaN
    # arrays, where np.nanmax raises / warns.
    if velocity.size and not np.all(np.isnan(velocity)) and np.nanmax(velocity) > 30:
        print(f'WARNING (Video {idx + 1}): Absurdly high velocities detected.')
        print('  Consider changing the ROI or re-recording the video.')

    mean_vel = np.nanmean(velocity)
    std_vel = np.nanstd(velocity)
    # NOTE(review): current_video names the most recently selected file and
    # may not correspond to older entries in all_corrected.
    print(f'\n--- Video {idx + 1} (from: {os.path.basename(current_video)}) ---')
    print(f'  Mean velocity: {mean_vel:.2f} mm/s')
    print(f'  Std deviation: {std_vel:.2f} mm/s')

    # Plot velocity over time
    time_axis = np.arange(total_bins) * bin_size
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(time_axis, velocity, linewidth=1.5)
    ax.set_xlim(0, time_axis[-1] + bin_size if len(time_axis) > 0 else 1)
    ax.set_ylim(0, np.nanmax(velocity) * 1.5 if velocity.size and np.nanmax(velocity) > 0 else 1)
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title(f'Fly Velocity — Video {idx + 1}')
    plt.tight_layout()
    plt.show()

print(f'\nVelocity analysis complete. Total videos analyzed: {len(all_velocity)}')

Summary Across All Videos#

num_files = len(all_velocity)

if num_files == 0:
    # FIX: the previous version hit IndexError on all_velocity[0] when the
    # velocity cell hadn't been run yet.
    print('No velocity data yet — run the velocity analysis cell first.')
elif num_files > 1:
    # Pad velocity arrays to the same length for comparison
    max_len = max(len(v) for v in all_velocity)
    velocity_matrix = np.full((num_files, max_len), np.nan)
    for i, v in enumerate(all_velocity):
        velocity_matrix[i, :len(v)] = v

    # Plot all velocities together
    fig, ax = plt.subplots(figsize=(10, 5))
    time_axis = np.arange(max_len) * bin_size
    for i in range(num_files):
        ax.plot(time_axis, velocity_matrix[i], linewidth=2,
                label=f'Fly {i + 1}')
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title('All Fly Velocities')
    ax.legend()
    plt.tight_layout()
    plt.show()

    # Summary stats across videos
    per_video_means = [np.nanmean(v) for v in all_velocity]
    mean_across = np.mean(per_video_means)
    sd_across = np.std(per_video_means)
    print(f'\n=== Summary Across {num_files} Videos ===')
    print(f'Mean velocity across videos: {mean_across:.2f} mm/s')
    print(f'SD of mean velocity across videos: {sd_across:.2f} mm/s')
else:
    mean_vel = np.nanmean(all_velocity[0])
    sd_vel = np.nanstd(all_velocity[0])
    print(f'\n=== Summary (1 Video) ===')
    print(f'Mean velocity: {mean_vel:.2f} mm/s')
    print(f'SD of velocity: {sd_vel:.2f} mm/s')

Download Results (Optional)#

Download the tracking data as CSV files.

for idx, corrected in enumerate(all_corrected):
    # FIX: unwrap (positions, frame_rate) tuples from unpatched runs.
    if isinstance(corrected, tuple):
        corrected = corrected[0]
    # FIX: the previous version referenced an undefined `video_files` list
    # (NameError). Original filenames are not tracked per video, so name
    # the CSVs by their processing order instead.
    base = f'video{idx + 1}'
    csv_name = f'{base}_tracking.csv'
    np.savetxt(csv_name, corrected, delimiter=',',
               header='Time_s,X_cm,Y_cm', comments='')
    files.download(csv_name)
    print(f'Downloaded {csv_name}')

# Also save velocity data
for idx, vel in enumerate(all_velocity):
    base = f'video{idx + 1}'
    csv_name = f'{base}_velocity.csv'
    time_axis = np.arange(len(vel)) * bin_size
    vel_data = np.column_stack([time_axis, vel])
    np.savetxt(csv_name, vel_data, delimiter=',',
               header='Time_s,Velocity_mm_per_s', comments='')
    files.download(csv_name)
    print(f'Downloaded {csv_name}')