Fly Tracker (Colab)#

This notebook tracks a fruit fly against a light background in a video, calculates its position and velocity over time, and produces path and velocity plots.

Notebook steps#

  1. Run the Setup cell to import packages.

  2. Mount your Google Drive and copy the path to your video file.

  3. Run the Helper Functions cell (it also holds the tracking parameters).

  4. Run the ROI cell, draw your ROI on the first frame, and click “Confirm ROI”; then run the cell below it to load the ROI.

  5. Run the Track Fly cell and wait for it to finish.

  6. Run the Plot and Velocity cells to view results.

  7. To track an additional video, update the video path (step 2) and re-run steps 4–6.

Setup#

#@title Install & import packages { display-mode: "form" }
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from scipy.spatial.distance import euclidean
from google.colab import files, output
from IPython.display import display, clear_output, HTML, Image as IPImage
import ipywidgets as widgets
import os
import base64

# Result containers persist across re-runs of this cell: each one is created
# only if it does not exist yet, so previously processed videos are kept, and
# containers added in the future can be picked up without a session restart.
# To discard all accumulated results, uncomment the lines below:
#all_corrected  = []
#all_velocity   = []
#all_fps        = []
#all_video_paths = []
#print('Results cleared — ready to process new videos.')

globals().setdefault('all_corrected', [])
globals().setdefault('all_velocity', [])
globals().setdefault('all_fps', [])
globals().setdefault('all_video_paths', [])

print(f'Ready. {len(all_corrected)} video(s) processed so far.')
print('All packages loaded successfully!')
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 3
      1 #@title Install & import packages { display-mode: "form" }
      2 import numpy as np
----> 3 import cv2
      4 import matplotlib.pyplot as plt
      5 import matplotlib.colors as mcolors

ModuleNotFoundError: No module named 'cv2'

Mount your drive#

Run the cell below to mount your drive, which should contain your video files.

If you’d prefer to upload your videos instead, you can skip the cell below. See instructions here.

# Mount Google Drive at /content/drive so video files stored there can be
# opened by path in the cells below.
from google.colab import drive
drive.mount('/content/drive')

Select Video File#

Update the path below to your video file. You’ll process one video at a time, and results will accumulate in all_corrected and all_velocity.

Tip: In the file browser on the left, click the three vertical dots next to your video file and choose “Copy path”, then paste it below.

# Set the current video file to process
current_video = '/content/drive/MyDrive/video1.avi'
print(f'Current video: {current_video}')

# Catch a mistyped path (or an unmounted drive) here, rather than later when
# the ROI cell fails to read the first frame.
if not os.path.isfile(current_video):
    print('WARNING: File not found. Check the path, and that your drive is mounted.')
#@title Define helper functions (fly_finder, dist_filter, interpolate_pos) { display-mode: "form" }

# --- Parameters ---
# These are read as globals by the Track Fly and velocity cells; re-run this
# cell after changing any of them.
frame_rate = None  # auto-detect from video (set e.g. frame_rate = 30 to override)
# Dish diameter; the ROI's width and height are each mapped onto this length
# when converting pixel positions to centimeters.
diameter = 4       # dish diameter in centimeters
search_size = 20   # fly search window size in pixels
# Detection threshold = search_size**2 * per_pixel_threshold, compared with the
# summed inverted (255 - value) intensity of fly_finder's search window.
per_pixel_threshold = 1.5
bin_size = 1       # velocity bin size in seconds
# -------------------


def fly_finder(roi_image, half_search, threshold, flip=True):
    """
    Locate a fly in a grayscale image by intensity-weighted centroid.

    The seed is the darkest pixel (or the brightest when ``flip`` is False);
    ties are averaged into one seed point. A square window of up to
    ``2 * half_search + 1`` pixels per side, clipped to the image, is taken
    around the seed. When ``flip`` is True the window is inverted
    (255 - value), so dark pixels carry the most weight.

    Returns the (x, y) centroid in image pixel coordinates, or (NaN, NaN)
    when the window's total weight is below ``threshold``.
    """
    extreme = np.nanmin(roi_image) if flip else np.nanmax(roi_image)
    rows, cols = np.nonzero(roi_image == extreme)
    seed_x = np.mean(cols)
    seed_y = np.mean(rows)

    n_rows, n_cols = roi_image.shape
    x_lo = max(int(round(seed_x) - half_search), 0)
    x_hi = min(int(round(seed_x) + half_search), n_cols - 1)
    y_lo = max(int(round(seed_y) - half_search), 0)
    y_hi = min(int(round(seed_y) + half_search), n_rows - 1)

    window = roi_image[y_lo:y_hi + 1, x_lo:x_hi + 1].astype(np.float64)
    weights = 255.0 - window if flip else window
    mass = np.sum(weights)

    # Written as "not >=" so a NaN comparison also yields "not found".
    if not mass >= threshold:
        return np.nan, np.nan

    col_positions = np.arange(weights.shape[1])
    row_positions = np.arange(weights.shape[0])
    centroid_x = np.sum(weights @ col_positions) / mass + x_lo
    centroid_y = np.sum(weights.T @ row_positions) / mass + y_lo
    return centroid_x, centroid_y


def dist_filter(array, tele_dist_threshold, num_avg=5):
    """
    Teleport filter: removes spurious points where fly position
    jumps far from the mean of surrounding frames.

    array: Nx3 array [time, x, y]
    tele_dist_threshold: maximum allowed distance, in the same units as the
        x/y columns, between a point and the local mean on either side.
    num_avg: number of neighbouring frames on each side whose valid positions
        are averaged to form the local reference position.

    Returns a filtered copy of ``array`` (the input is not modified) in which
    rejected points have x and y set to NaN. Prints how many were removed.

    Two passes are applied:
      1. Interior pass: a valid point is rejected if it lies more than
         tele_dist_threshold from the mean of the valid positions among the
         num_avg preceding frames OR among the num_avg following frames.
         A side with no valid positions never triggers a rejection. The
         first and last num_avg points are exempt from this pass.
      2. Edge pass: for the first 5 and last 5 consecutive pairs
         (idx, idx + 1) -- hard-coded 5, independent of num_avg -- if both
         points are still valid and more than tele_dist_threshold / 2 apart,
         the earlier point (idx) is rejected.

    NOTE(review): because the edge pass always removes the earlier point of a
    pair, the final point (index n - 1) can never be rejected by either pass.
    Confirm whether the trailing edge should drop idx + 1 instead.
    """
    filtered = array.copy()
    n = len(filtered)
    # Views into `filtered`; the distance computations below finish before
    # `filtered` is modified.
    x = filtered[:, 1]
    y = filtered[:, 2]

    def rolling_nanmean(arr, window, forward=False):
        """
        Mean of the non-NaN values among the `window` elements just before
        each index i (indices i-window .. i-1), or just after it when
        forward=True (indices i+1 .. i+window), clipped at the array ends.
        NaN where that span contains no valid values. Uses cumulative sums,
        so the cost is linear in len(arr). Expects len(arr) == n.
        """
        valid = (~np.isnan(arr)).astype(np.float64)
        filled = np.where(np.isnan(arr), 0.0, arr)
        # Prefix sums with a leading 0 so sum(arr[a:b]) == cs[b] - cs[a].
        cs_val = np.concatenate([[0.0], np.cumsum(filled)])
        cs_cnt = np.concatenate([[0.0], np.cumsum(valid)])
        idx = np.arange(n)
        if forward:
            starts = idx + 1
            ends = np.minimum(n, idx + 1 + window)
        else:
            starts = np.maximum(0, idx - window)
            ends = idx
        sum_v = cs_val[ends] - cs_val[starts]
        cnt_v = cs_cnt[ends] - cs_cnt[starts]
        with np.errstate(invalid='ignore', divide='ignore'):
            return np.where(cnt_v > 0, sum_v / cnt_v, np.nan)

    last_x = rolling_nanmean(x, num_avg)
    last_y = rolling_nanmean(y, num_avg)
    next_x = rolling_nanmean(x, num_avg, forward=True)
    next_y = rolling_nanmean(y, num_avg, forward=True)

    # NaN distances (missing point or empty neighbourhood) compare False below.
    dist_last = np.sqrt((x - last_x) ** 2 + (y - last_y) ** 2)
    dist_next = np.sqrt((x - next_x) ** 2 + (y - next_y) ** 2)

    valid_pts = ~np.isnan(x)
    tele_mask = valid_pts & (
        (dist_last > tele_dist_threshold) | (dist_next > tele_dist_threshold)
    )
    # Points near either end lack a full neighbourhood; leave them to the
    # edge pass below.
    tele_mask[:num_avg] = False
    tele_mask[n - num_avg:] = False

    filtered[tele_mask, 1:3] = np.nan
    tele_count = int(tele_mask.sum())

    # Edge pass: compare each of the first/last few points with its successor.
    # Processed in order, so a point removed here is skipped (NaN) when it
    # appears again as part of a later pair.
    edge_indices = list(range(0, min(5, n - 1))) + list(range(max(0, n - 6), n - 1))
    for idx in edge_indices:
        if np.any(np.isnan(filtered[idx, 1:3])) or np.any(np.isnan(filtered[idx + 1, 1:3])):
            continue
        if euclidean(filtered[idx, 1:3], filtered[idx + 1, 1:3]) > tele_dist_threshold / 2:
            filtered[idx, 1:3] = np.nan
            tele_count += 1

    print(f'{tele_count} points removed by the teleportation filter.')
    return filtered


def interpolate_pos(array, inter_dist_threshold):
    """
    Fill runs of missing (NaN) positions by straight-line interpolation.

    A run is filled only if it has a valid anchor frame on both sides and
    those two anchors lie within ``inter_dist_threshold`` of each other.
    Runs that touch the start or end of the recording are left as NaN.

    array: Nx3 array [time, x, y]; missing frames are detected from x.
    Returns a new array (the input is not modified) and prints how many
    points were filled.
    """
    out = array.copy()
    missing = np.isnan(out[:, 1])

    if not missing.any():
        print('0 points recovered through interpolation.')
        return out

    # +1 marks the first frame of a NaN run, -1 the frame just after it.
    edges = np.diff(np.r_[0, missing.astype(np.int8), 0])
    run_first = np.flatnonzero(edges == 1)
    run_last = np.flatnonzero(edges == -1) - 1

    n_rows = len(out)
    recovered = 0
    for first, last in zip(run_first, run_last):
        before = first - 1
        after = last + 1
        if before < 0 or after >= n_rows:
            continue
        anchor_a = out[before, 1:3]
        anchor_b = out[after, 1:3]
        if np.isnan(anchor_a).any() or np.isnan(anchor_b).any():
            continue
        if euclidean(anchor_a, anchor_b) > inter_dist_threshold:
            continue
        span = after - before
        steps = np.arange(1, span) / span
        out[first:last + 1, 1] = anchor_a[0] + steps * (anchor_b[0] - anchor_a[0])
        out[first:last + 1, 2] = anchor_a[1] + steps * (anchor_b[1] - anchor_a[1])
        recovered += last - first + 1

    print(f'{recovered} points recovered through interpolation.')
    return out


def boundary_filter(array, diameter):
    """
    Blank out positions that fall outside the circular dish.

    The dish is modelled as a circle of the given diameter inscribed in a
    diameter x diameter square, centred at (diameter / 2, diameter / 2), in
    the same units as the x/y columns of ``array`` (Nx3 [time, x, y]).
    Returns a copy with outside points' x and y set to NaN, and prints the
    count when any were removed.
    """
    out = array.copy()
    radius = diameter / 2
    centre = diameter / 2
    dx = out[:, 1] - centre
    dy = out[:, 2] - centre
    beyond_rim = (np.sqrt(dx ** 2 + dy ** 2) > radius) & ~np.isnan(out[:, 1])
    out[beyond_rim, 1:3] = np.nan
    n_outside = np.sum(beyond_rim)
    if n_outside > 0:
        print(f'  {n_outside} points removed by the boundary filter (outside dish).')
    return out


def displacement_filter(array, multiplier=5):
    """
    Blank out frames reached by an unusually large frame-to-frame jump.

    A step from frame i to frame i + 1 longer than ``multiplier`` times the
    median step length marks frame i + 1 as bad. All steps are measured on
    the input positions before anything is removed. Nothing is removed when
    there are no measurable steps or when the median step is zero.

    array: Nx3 array [time, x, y]. Returns a copy with bad frames' x and y set
    to NaN, and prints the count when any were removed.
    """
    out = array.copy()
    steps = np.sqrt(np.diff(out[:, 1]) ** 2 + np.diff(out[:, 2]) ** 2)

    measurable = steps[~np.isnan(steps)]
    if len(measurable) == 0:
        return out
    typical_step = np.nanmedian(measurable)
    if typical_step == 0:
        return out

    # NaN steps compare False, so they never flag a frame.
    is_jump = steps > typical_step * multiplier
    # out[1:] is a view, so this writes into `out` at frame i + 1 for each jump i.
    out[1:][is_jump, 1:3] = np.nan

    n_removed = int(np.count_nonzero(is_jump))
    if n_removed > 0:
        print(f'  {n_removed} points removed by the displacement filter (>{multiplier}x median jump).')
    return out


def _compute_background(cap, nfrm, max_samples=100):
    """
    Estimate the static background as the mean of random grayscale frames.

    cap: an opened cv2.VideoCapture.
    nfrm: number of leading frames eligible for sampling (must be >= 1).
    max_samples: upper bound on how many distinct frames are averaged.

    Returns a uint8 image the size of one full frame. Frames that fail to
    decode are skipped instead of being averaged in as all-black images.
    Raises ValueError if none of the sampled frames can be read.
    """
    n_samples = min(max_samples, nfrm)
    sample_indices = sorted(np.random.choice(nfrm, n_samples, replace=False))

    frame_sum = None
    n_read = 0
    for frame_num in sample_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_num))
        ret, frame = cap.read()
        if not ret:
            continue
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float64)
        if frame_sum is None:
            frame_sum = gray
        else:
            frame_sum += gray
        n_read += 1

    if n_read == 0:
        raise ValueError('Could not read any frames to build the background image.')
    return (frame_sum / n_read).astype(np.uint8)


def _track_positions(cap, nfrm, roi, bg_scale_roi, half_search, threshold):
    """
    Locate the fly in each of the first ``nfrm`` frames of ``cap``.

    Each frame is converted to grayscale, cropped to ``roi`` (x, y, w, h),
    multiplied by ``bg_scale_roi`` (background normalisation), then searched
    with fly_finder. Returns an (nfrm, 3) array [frame_index, x_px, y_px]
    in ROI pixel coordinates; frames that fail to decode get NaN positions.
    """
    roi_x, roi_y, roi_w, roi_h = roi
    pos_array = np.zeros((nfrm, 3))
    progress_step = max(1, nfrm // 10)

    # Seek once; sequential cap.read() is far faster than per-frame seeking
    # (compressed formats must decode from the last keyframe on every seek)
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    for nofr in range(nfrm):
        ret, frame = cap.read()
        if not ret:
            pos_array[nofr] = [nofr, np.nan, np.nan]
            continue

        # Crop to ROI before float conversion so background normalisation
        # operates on ROI pixels only, not the full frame
        frame_roi = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)[
            roi_y:roi_y + roi_h, roi_x:roi_x + roi_w
        ].astype(np.float64)
        frame_div = np.clip(frame_roi * bg_scale_roi, 0, 255).astype(np.uint8)

        fx, fy = fly_finder(frame_div, half_search, threshold, flip=True)
        pos_array[nofr] = [nofr, fx, fy]

        # Progress update every 10%
        if (nofr + 1) % progress_step == 0:
            pct = (nofr + 1) / nfrm * 100
            print(f"    {pct:.0f}% complete ({nofr + 1}/{nfrm} frames)")

    return pos_array


def process_video(video_path, roi, diameter, frame_rate, search_size, per_pixel_threshold):
    """
    Process a single fly video and return the corrected position array.
    Returns (Nx3 array [time_s, x_cm, y_cm], frame_rate).

    video_path: path to a video file readable by OpenCV.
    roi: (x, y, width, height) of the dish region, in frame pixels.
    diameter: dish diameter in cm; the ROI width and height are each mapped
        onto this length to convert pixels to cm.
    frame_rate: frames per second, or None to use the video's reported FPS
        rounded to an integer (30 if the video reports none).
    search_size: side length, in pixels, of fly_finder's search window.
    per_pixel_threshold: fly_finder's threshold is
        search_size**2 * per_pixel_threshold, compared against the summed
        inverted intensity of the search window.

    Raises ValueError for an empty ROI, a non-positive frame rate, or a video
    with too few readable frames, and IOError if the file cannot be opened.
    The video capture is always released, even on error.
    """
    roi_x, roi_y, roi_w, roi_h = roi
    if roi_w <= 0 or roi_h <= 0:
        raise ValueError(f'ROI must have positive width and height, got {roi_w} x {roi_h}.')

    override_hint = ("    Override by setting frame_rate = 30 (or your value) in the "
                     "Parameters section of the helper functions cell.")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f'Could not open video file: {video_path}')

        print(f"\nProcessing: {os.path.basename(video_path)}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # The last reported frame is not tracked (kept from the original
        # pipeline; presumably because the reported count can overshoot the
        # decodable frames -- TODO confirm).
        nfrm = total_frames - 1
        if nfrm < 1:
            raise ValueError(f'Video reports too few frames to track ({total_frames}): {video_path}')

        # Auto-detect frame rate from video if not specified
        if frame_rate is None:
            if fps > 0:
                frame_rate = round(fps)
                print(f"  Auto-detected frame rate: {frame_rate} fps ({total_frames} frames)")
            else:
                frame_rate = 30
                print(f"  WARNING: Could not detect frame rate, using default: {frame_rate} fps")
                print(override_hint)
        else:
            print(f"  Video FPS: {fps}, using override: {frame_rate} fps ({total_frames} frames)")
        if frame_rate <= 0:
            raise ValueError(f'frame_rate must be positive, got {frame_rate}.')

        # Sanity check: warn if video duration seems unusual
        duration = total_frames / frame_rate
        print(f"  Estimated duration: {duration:.1f} seconds")
        if duration < 50 or duration > 70:
            print(f"  WARNING: Expected ~60s video but got {duration:.1f}s.")
            print(override_hint)

        # --- Create background from up to 100 random frames ---
        print("  Calculating background...")
        background = _compute_background(cap, nfrm)
        # Multiplying a frame by this scale maps background pixels to ~255,
        # leaving darker objects (the fly) as low values.
        bg_scale = 256.0 / (background.astype(np.float64) + 1)
        # Pre-crop bg_scale to ROI so per-frame math runs on ROI pixels only
        bg_scale_roi = bg_scale[roi_y:roi_y + roi_h, roi_x:roi_x + roi_w]

        # --- Process each frame ---
        print("  Tracking fly positions...")
        threshold = (search_size ** 2) * per_pixel_threshold
        half_search = round(search_size / 2)
        pos_array = _track_positions(cap, nfrm, roi, bg_scale_roi, half_search, threshold)
    finally:
        cap.release()

    # --- Convert to real coordinates ---
    xscale = diameter / roi_w
    yscale = diameter / roi_h

    corrected_array = np.column_stack([
        pos_array[:, 0] / frame_rate,
        pos_array[:, 1] * xscale,
        pos_array[:, 2] * yscale,
    ])

    skipped = np.sum(np.isnan(corrected_array[:, 1]))
    print(f"  {skipped} points skipped out of {nfrm}.")

    # Apply filters and interpolation (distance thresholds are in cm)
    corrected_array = boundary_filter(corrected_array, diameter)
    corrected_array = displacement_filter(corrected_array)
    corrected_array = dist_filter(corrected_array, 2)
    corrected_array = interpolate_pos(corrected_array, 2)

    return corrected_array, frame_rate

print('Helper functions defined.')
#@title Interactive ROI selector — click and drag on the image { display-mode: "form" }

# Shows the first frame of current_video on an HTML canvas. The user drags a
# rectangle, and "Confirm ROI" sends (x, y, w, h) in image-pixel coordinates
# back to Python through the registered 'set_roi' callback. The ROI is stored
# in roi_result['value'] and read by the next cell.

# Read first frame and encode as base64 for display
cap = cv2.VideoCapture(current_video)
ret, frame = cap.read()
cap.release()
if not ret:
    raise ValueError(f'Could not read first frame of {current_video}')

frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img_h, img_w = frame_rgb.shape[:2]

# Encode frame as PNG base64
# NOTE(review): the RGB->BGR conversion undoes the BGR->RGB above, so this is
# equivalent to encoding `frame` directly.
_, buf = cv2.imencode('.png', cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
img_b64 = base64.b64encode(buf).decode('utf-8')

print(f'Image size: {img_w} x {img_h} pixels')
print('Click and drag on the image below to select the ROI, then click "Confirm ROI".\n')

# JavaScript/HTML interactive ROI selector.
# html_code is an f-string: literal JS braces are doubled ({{ }}), and only
# {img_w}, {img_h} and {img_b64} are substituted. The canvas has the image's
# true pixel size but may be CSS-scaled (max-width:100%), so getScale() maps
# mouse coordinates back to image pixels.
html_code = f"""
<div id="roi-container" style="position:relative; display:inline-block; cursor:crosshair;">
  <canvas id="roi-canvas" width="{img_w}" height="{img_h}"
    style="max-width:100%; height:auto; border:1px solid #ccc;"></canvas>
</div>
<br>
<button id="roi-confirm-btn" style="font-size:16px; padding:8px 20px; background:#4CAF50;
  color:white; border:none; border-radius:4px; cursor:pointer; margin-top:8px;">
  Confirm ROI
</button>
<p id="roi-info" style="font-size:14px; color:#333;"></p>

<script>
(function() {{
  var canvas = document.getElementById('roi-canvas');
  var ctx = canvas.getContext('2d');
  var img = new window.Image();
  img.src = 'data:image/png;base64,{img_b64}';

  var drawing = false;
  var x0 = 0, y0 = 0, x1 = 0, y1 = 0;
  var hasROI = false;

  // Scale factor: canvas CSS size vs actual pixel size
  function getScale() {{
    var rect = canvas.getBoundingClientRect();
    return {{ sx: canvas.width / rect.width, sy: canvas.height / rect.height }};
  }}

  img.onload = function() {{
    ctx.drawImage(img, 0, 0);
  }};

  function redraw() {{
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    ctx.drawImage(img, 0, 0);
    if (hasROI) {{
      var rx = Math.min(x0, x1);
      var ry = Math.min(y0, y1);
      var rw = Math.abs(x1 - x0);
      var rh = Math.abs(y1 - y0);
      ctx.strokeStyle = 'red';
      ctx.lineWidth = 3;
      ctx.strokeRect(rx, ry, rw, rh);
      document.getElementById('roi-info').innerText =
        'ROI: x=' + Math.round(rx) + ', y=' + Math.round(ry) +
        ', w=' + Math.round(rw) + ', h=' + Math.round(rh);
    }}
  }}

  canvas.addEventListener('mousedown', function(e) {{
    var rect = canvas.getBoundingClientRect();
    var scale = getScale();
    x0 = (e.clientX - rect.left) * scale.sx;
    y0 = (e.clientY - rect.top) * scale.sy;
    drawing = true;
    hasROI = false;
  }});

  canvas.addEventListener('mousemove', function(e) {{
    if (!drawing) return;
    var rect = canvas.getBoundingClientRect();
    var scale = getScale();
    x1 = (e.clientX - rect.left) * scale.sx;
    y1 = (e.clientY - rect.top) * scale.sy;
    hasROI = true;
    redraw();
  }});

  canvas.addEventListener('mouseup', function(e) {{
    drawing = false;
  }});

  document.getElementById('roi-confirm-btn').addEventListener('click', function() {{
    if (!hasROI) {{
      document.getElementById('roi-info').innerText =
        'No ROI drawn yet! Click and drag on the image first.';
      return;
    }}
    var rx = Math.round(Math.min(x0, x1));
    var ry = Math.round(Math.min(y0, y1));
    var rw = Math.round(Math.abs(x1 - x0));
    var rh = Math.round(Math.abs(y1 - y0));

    // Send ROI back to Python via Colab's kernel eval
    google.colab.kernel.invokeFunction('set_roi', [rx, ry, rw, rh], {{}});
    document.getElementById('roi-info').innerText =
      'ROI CONFIRMED: x=' + rx + ', y=' + ry + ', w=' + rw + ', h=' + rh;
    document.getElementById('roi-confirm-btn').style.background = '#888';
    document.getElementById('roi-confirm-btn').innerText = 'ROI Confirmed!';
  }});
}})();
</script>
"""

# Register a callback so JavaScript can send the ROI to Python.
# A dict is used so the callback can update the value without `global`.
roi_result = {'value': None}

def set_roi(x, y, w, h):
    """Store the confirmed ROI as an (x, y, w, h) tuple of ints in roi_result."""
    roi_result['value'] = (int(x), int(y), int(w), int(h))
    print(f'ROI received: x={int(x)}, y={int(y)}, w={int(w)}, h={int(h)}')

# The name must match the one passed to google.colab.kernel.invokeFunction in the JS.
output.register_callback('set_roi', set_roi)

display(HTML(html_code))
# Run this cell AFTER you have confirmed your ROI above
roi = roi_result['value']

if roi is None:
    raise ValueError("No ROI selected! Go back and draw a rectangle on the image, then click 'Confirm ROI'.")

# A click without a real drag can yield a zero-sized rectangle, which would
# later divide by zero when converting pixels to centimeters.
if roi[2] <= 0 or roi[3] <= 0:
    raise ValueError("The ROI has zero width or height! Go back, click and drag a larger rectangle, then click 'Confirm ROI'.")

print(f'Using ROI: x={roi[0]}, y={roi[1]}, width={roi[2]}, height={roi[3]}')

Track Fly in Video(s)#

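This processes the current video frame-by-frame:

  1. Creates a background image from up to 100 random frames

  2. Normalizes each frame by dividing it by the background, so the fly stands out as a dark spot

  3. Finds the fly position using the center of mass of dark pixels around the darkest point

  4. Converts pixel positions to centimeters

  5. Applies boundary, displacement, and teleportation filters, then interpolates across gaps whose endpoints are close together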

# --- Process the current video ---
# Re-running this cell on the same file appends a second copy of its results,
# which would then be counted twice in the summary statistics.
if current_video in all_video_paths:
    print(f'WARNING: {current_video} was already processed; its results will be appended again.')

corrected, detected_fps = process_video(current_video, roi, diameter, frame_rate,
                                        search_size, per_pixel_threshold)
# These lists stay index-aligned: entry i of each belongs to the same video.
all_corrected.append(corrected)
all_fps.append(detected_fps)
all_video_paths.append(current_video)

print(f'\nDone! Total videos processed: {len(all_corrected)}')
print('To process another video, go back to "Select Video File" cell and change `current_video`, then run this cell again.')

Plot Fly Path(s)#

Color-coded by time (light blue = start, dark purple = end).

# Plot each video's trajectory as a line whose color encodes time.
for idx, corrected in enumerate(all_corrected):
    x = corrected[:, 1]
    y = corrected[:, 2]
    t = corrected[:, 0]

    # Keep only frames with a position. Consecutive kept frames are joined
    # directly, so any remaining NaN gap is drawn as a straight segment.
    valid = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(valid) < 2:
        print(f'Video {idx + 1}: fewer than 2 tracked positions, nothing to plot.')
        continue

    fig, ax = plt.subplots(figsize=(6, 6))

    # One segment per pair of consecutive valid points, colored by its start time
    points = np.column_stack([x[valid], y[valid]]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    lc = LineCollection(segments, cmap='BuPu', linewidth=2)
    lc.set_array(t[valid][:-1])
    ax.add_collection(lc)

    ax.set_xlim(0, diameter)
    ax.set_ylim(diameter, 0)  # Invert y to match video
    ax.set_aspect('equal')
    ax.set_xlabel('X-coordinate (cm)', fontsize=11)
    ax.set_ylabel('Y-coordinate (cm)', fontsize=11)
    ax.set_title(f'Fly Path — Video {idx + 1}')

    cbar = fig.colorbar(lc, ax=ax, orientation='horizontal', pad=0.1)
    cbar.set_label('Time (s)')

    plt.tight_layout()
    plt.show()

Calculate & Plot Velocity#

# Calculate velocity only for videos added since this cell last ran, so
# all_velocity stays index-aligned with all_corrected.
start_idx = len(all_velocity)

for idx in range(start_idx, len(all_corrected)):
    corrected = all_corrected[idx]
    t = corrected[:, 0]

    # Frame rate: prefer the value recorded by the tracking cell; otherwise
    # recover it from the time column (time = frame_index / frame_rate).
    if idx < len(all_fps) and all_fps[idx] > 0:
        fps = all_fps[idx]
    elif len(t) > 1 and t[1] > 0:
        fps = round(1.0 / t[1])
    else:
        fps = 30

    # Whole frames per bin, and the real time those frames span (this equals
    # bin_size whenever fps * bin_size is an integer).
    frames_per_bin = int(fps * bin_size)
    if frames_per_bin < 1:
        raise ValueError('bin_size is smaller than the minimum data rate.')
    bin_duration = frames_per_bin / fps

    # Velocity per bin = net displacement from the first frame of the bin to the
    # first frame of the next bin. Only complete bins are produced, so a final
    # bin is never left at a spurious 0.
    bin_starts = np.arange(0, len(corrected) - frames_per_bin, frames_per_bin)
    p1 = corrected[bin_starts, 1:3]
    p2 = corrected[bin_starts + frames_per_bin, 1:3]
    # 10x converts cm to mm; a NaN endpoint gives that bin a NaN velocity
    velocity = 10.0 * np.hypot(p2[:, 0] - p1[:, 0], p2[:, 1] - p1[:, 1]) / bin_duration

    all_velocity.append(velocity)

    valid_vel = velocity[~np.isnan(velocity)]
    peak = valid_vel.max() if valid_vel.size else np.nan

    # Warn about absurd velocities
    if peak > 30:
        print(f'WARNING (Video {idx + 1}): Absurdly high velocities detected.')
        print('  Consider changing the ROI or re-recording the video.')

    mean_vel = valid_vel.mean() if valid_vel.size else np.nan
    std_vel = valid_vel.std() if valid_vel.size else np.nan
    video_path = all_video_paths[idx] if idx < len(all_video_paths) else current_video
    print(f'\n--- Video {idx + 1} (from: {os.path.basename(video_path)}) ---')
    print(f'  Mean velocity: {mean_vel:.2f} mm/s')
    print(f'  Std deviation: {std_vel:.2f} mm/s')
    if not valid_vel.size:
        print('  (no bins with valid positions at both ends)')

    # Plot velocity over time
    time_axis = np.arange(len(velocity)) * bin_duration
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(time_axis, velocity, linewidth=1.5)
    ax.set_xlim(0, time_axis[-1] + bin_duration if len(time_axis) > 0 else 1)
    ax.set_ylim(0, peak * 1.5 if peak > 0 else 1)
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title(f'Fly Velocity — Video {idx + 1}')
    plt.tight_layout()
    plt.show()

print(f'\nVelocity analysis complete. Total videos analyzed: {len(all_velocity)}')

Summary Across All Videos#

num_files = len(all_velocity)

if num_files == 0:
    # Nothing to summarize yet; indexing all_velocity[0] would fail.
    print('No velocity results yet — run the "Calculate & Plot Velocity" cell first.')
elif num_files > 1:
    # Pad velocity arrays to the same length for comparison
    max_len = max(len(v) for v in all_velocity)
    velocity_matrix = np.full((num_files, max_len), np.nan)
    for i, v in enumerate(all_velocity):
        velocity_matrix[i, :len(v)] = v

    # Plot all velocities together
    fig, ax = plt.subplots(figsize=(10, 5))
    time_axis = np.arange(max_len) * bin_size
    for i in range(num_files):
        ax.plot(time_axis, velocity_matrix[i], linewidth=2,
                label=f'Fly {i + 1}')
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title('All Fly Velocities')
    ax.legend()
    plt.tight_layout()
    plt.show()

    # Summary stats across videos: mean and SD of the per-video mean velocities
    per_video_means = [np.nanmean(v) for v in all_velocity]
    mean_across = np.mean(per_video_means)
    sd_across = np.std(per_video_means)
    print(f'\n=== Summary Across {num_files} Videos ===')
    print(f'Mean velocity across videos: {mean_across:.2f} mm/s')
    print(f'SD of mean velocity across videos: {sd_across:.2f} mm/s')
else:
    mean_vel = np.nanmean(all_velocity[0])
    sd_vel = np.nanstd(all_velocity[0])
    print(f'\n=== Summary (1 Video) ===')
    print(f'Mean velocity: {mean_vel:.2f} mm/s')
    print(f'SD of velocity: {sd_vel:.2f} mm/s')

Prepare for Another Video (Optional)#

If you want to process another video with the same ROI:

  1. Go back to the Select Video File cell and update current_video to the new file path

  2. Re-run the video selection cell

  3. Skip the ROI selection cells to reuse the current ROI (the roi variable keeps its value), or re-run them to draw a new one

  4. Run the Track Fly cell again — results automatically append to all_corrected

Or, to apply a different ROI for the next video, re-run the ROI selection cells before tracking.

Download Results (Optional)#

Download the tracking data as CSV files.

# Save and download one tracking CSV and one velocity CSV per analyzed video.
# zip() stops at the shorter list, so tracked videos without velocity results
# would otherwise be skipped silently.
missing_velocity = len(all_corrected) - len(all_velocity)
if missing_velocity > 0:
    print(f'NOTE: {missing_velocity} tracked video(s) have no velocity results yet and will be '
          'skipped — run the "Calculate & Plot Velocity" cell first.')

for idx, (corrected, vel) in enumerate(zip(all_corrected, all_velocity)):
    # Name files after the source video when its path was recorded.
    if idx < len(all_video_paths):
        base = os.path.splitext(os.path.basename(all_video_paths[idx]))[0]
    else:
        base = f'video{idx + 1}'

    csv_name = f'{base}_tracking.csv'
    np.savetxt(csv_name, corrected, delimiter=',',
               header='Time_s,X_cm,Y_cm', comments='')
    files.download(csv_name)
    print(f'Downloaded {csv_name}')

    time_axis = np.arange(len(vel)) * bin_size
    vel_csv = f'{base}_velocity.csv'
    np.savetxt(vel_csv, np.column_stack([time_axis, vel]), delimiter=',',
               header='Time_s,Velocity_mm_per_s', comments='')
    files.download(vel_csv)
    print(f'Downloaded {vel_csv}')