Visualized SAM2 Video Segmentation with Gradio

This blog explains a simple implementation of a Gradio-based UI for the SAM2 (Segment Anything 2) video segmentation model. The code processes input videos, performs object segmentation using SAM2, and visualizes results. Let's break down the key components:

Imports & Device Configuration

import gradio as gr
import cv2, os, numpy as np
import torch
# ... other imports ...

# Device selection
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Functionality:

  • Imports necessary libraries for computer vision, deep learning, and UI creation
  • Automatically selects the best available computation device (CUDA GPU > CPU)

Video Processing Utilities

def video2images(video, output_folder):
    # Video to frame extraction logic
    # ... (skipped for brevity) ...

def images2video(image_folder, output_video_path, fps=24):
    # Frame to video conversion logic
    # ... (skipped for brevity) ...

Key Features:

  • video2images(): Extracts frames from the video, skipping the first frame (assumed to be an overlay grid); a possible implementation is sketched below
  • images2video(): Reconstructs a video from the processed frames at the specified FPS

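The function bodies are skipped for brevity. Below is a minimal OpenCV sketch of what they might look like; the zero-padded JPEG naming, the returned (fps, video_name) tuple, and the folder layout are assumptions chosen to match how the helpers are used later in sam2_video(), and the original implementation (including the first-frame skipping mentioned above) may differ.

import cv2, os

def video2images(video, output_folder):
    # Possible implementation: read every frame with OpenCV and save it as a
    # zero-padded JPEG so the frame folder can be fed to SAM2.
    video_name = os.path.splitext(os.path.basename(video))[0]
    frame_dir = os.path.join(output_folder, video_name)
    os.makedirs(frame_dir, exist_ok=True)

    cap = cv2.VideoCapture(video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imwrite(os.path.join(frame_dir, f"{idx:05d}.jpg"), frame)
        idx += 1
    cap.release()
    return fps, video_name

def images2video(image_folder, output_video_path, fps=24):
    # Stitch the processed frames back into an MP4 in filename order.
    frames = sorted(f for f in os.listdir(image_folder) if f.endswith(".jpg"))
    first = cv2.imread(os.path.join(image_folder, frames[0]))
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(output_video_path,
                             cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for name in frames:
        writer.write(cv2.imread(os.path.join(image_folder, name)))
    writer.release()
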
SAM2 Inference Pipeline

Let's take a closer look at the core SAM2 inference implementation that powers the video segmentation. This section contains the most critical logic for processing user inputs and generating segmentation masks.

Model Initialization

def sam2_inference(images_path, point_x, point_y, save_dir_name):
    # Model configuration
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    sam2_checkpoint = "/path/to/checkpoint.pt"

    # Handle different compute architectures
    if device.type == "cuda":
        torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    # Initialize predictor with proper device handling
    from sam2.build_sam import build_sam2_video_predictor
    predictor = build_sam2_video_predictor(
        model_cfg,
        sam2_checkpoint,
        device=device
    )

Key Components:

  • Loads the model configuration from a YAML file
  • Initializes the SAM2 video predictor with pretrained weights
  • Enables bfloat16 autocast and TF32 matmul kernels on Ampere or newer GPUs (compute capability ≥ 8); see the note on autocast scoping below

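A side note on the autocast call: invoking torch.autocast("cuda", dtype=torch.bfloat16).__enter__() (as in the official SAM2 notebooks) leaves mixed precision enabled for the rest of the process. In a standalone script you may prefer to scope it explicitly; the helper below is a hypothetical sketch of that alternative, not part of the original code.

def run_with_autocast(fn, *args, **kwargs):
    # Hypothetical helper: run `fn` under bfloat16 autocast on CUDA,
    # or unchanged on CPU-only machines.
    if torch.cuda.is_available():
        with torch.autocast("cuda", dtype=torch.bfloat16):
            return fn(*args, **kwargs)
    return fn(*args, **kwargs)
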
Prompt Processing & State Initialization

# Initialize inference state with first video frame
inference_state = predictor.init_state(video_path=images_path)
predictor.reset_state(inference_state)

# Process user click coordinates
ann_frame_idx = 0  # First frame index
ann_obj_id = 1     # First object ID

# Convert user click to model input
points = np.array([[int(point_x), int(point_y)]], dtype=np.float32)
labels = np.array([1], np.int32)  # 1 = positive prompt

# Add prompts to predictor
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

Critical Operations:

  • Loads the extracted frames from the input video
  • Creates the initial inference state from those frames
  • Converts the user's click into a positive point prompt on the first frame

Temporal Mask Propagation

# Perform video-level mask inference
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

Core Algorithm:

  • Propagates the first-frame prompt through the remaining frames sequentially
  • Thresholds the mask logits at 0.0 to obtain binary masks
  • Stores per-object masks in a dictionary keyed by frame index (see the access example below)

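The resulting video_segments dictionary maps each frame index to a dictionary of per-object boolean masks. A quick illustration of how to read it back (the (1, H, W) mask shape follows from thresholding out_mask_logits[i] above; the frame and object IDs are arbitrary):

frame_masks = video_segments[0]   # all object masks on frame 0
obj_mask = frame_masks[1]         # boolean array of shape (1, H, W) for object ID 1
print(obj_mask.shape, obj_mask.dtype)
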
Mask Visualization & Output

def coloring_mask(mask, color):
    # Give a mask a color
    # ... (skipped for brevity) ...

# Iterate over the output masks
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    frame_idx = cv2.imread(...)
    frame_idx = torch.tensor(frame_idx, dtype=torch.float32).to(device)
    mask_idx = torch.zeros_like(frame_idx)
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        # Boolean (H, W) mask for this object (out_mask has shape (1, H, W))
        obj_mask = torch.tensor(out_mask[0], device=device)
        # Apply color blending to the masked region
        colored_mask = coloring_mask(out_mask, color[out_obj_id - 1])
        frame_copy = frame_idx.clone()
        frame_copy[obj_mask] = colored_mask[obj_mask]
        mask_idx[obj_mask] = 255
        alpha = 0.7
        beta = 1 - alpha
        frame_idx = torch.add(frame_copy * alpha, frame_idx * beta)
    # Save the blended frame and the raw mask to their output paths
    cv2.imwrite(..., frame_idx.cpu().numpy())
    cv2.imwrite(..., mask_idx.cpu().numpy())

Visualization Techniques:

  • Blends the colored masks onto the frame (70% mask opacity, 30% frame opacity)
  • Saves both the blended visualization and the raw binary mask for each frame (a possible coloring_mask() implementation is sketched below)

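Because coloring_mask() is skipped for brevity, here is one possible implementation consistent with how it is called above: it takes the (1, H, W) boolean mask and a BGR color triple and returns an (H, W, 3) tensor on device, painted with that color wherever the mask is true. The original may differ.

def coloring_mask(mask, color):
    # Hypothetical sketch: 'mask' is a (1, H, W) boolean array, 'color' a (B, G, R) triple
    m = torch.tensor(mask[0], device=device)  # (H, W) bool
    colored = torch.zeros((*m.shape, 3), dtype=torch.float32, device=device)
    colored[m] = torch.tensor(color, dtype=torch.float32, device=device)
    return colored
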
Main Function sam2_video()

def sam2_video(video, point_x, point_y):
    # Convert the video to images
    video2image_saved = ...
    output_video_folder = ...

    fps, video_name = video2images(video, video2image_saved)
    images_path = os.path.join(video2image_saved, video_name)

    # Perform SAM2 inference
    images_outputs_path, mask_dir = sam2_inference(images_path, point_x, point_y, video_name)

    # Convert the images back to videos
    output_video_path = os.path.join(output_video_folder, video_name + '.mp4')
    output_mask_path = os.path.join(output_video_folder, video_name + '_mask.mp4')
    images2video(images_outputs_path, output_video_path, fps)
    images2video(mask_dir, output_mask_path, fps)

    return output_video_path, output_mask_path

Explanation:

  • It first converts the video into individual images.
  • Then, it runs SAM2 segmentation on those images.
  • Finally, it converts the processed frames back into two videos, a blended visualization and a mask-only video, and returns both paths for download.

Gradio Interface

if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            input_video = gr.Video(format='mp4', label='Source Video')
            first_img = gr.Image(label="First Image")

        with gr.Row():
            input_x_cord = gr.Textbox(label="X Cord")
            input_y_cord = gr.Textbox(label="Y Cord")

        with gr.Row():
            text_button = gr.Button("Submit")

        with gr.Row():
            output_video = gr.Video(format='mp4', label="SAM2 Vis", show_download_button=True)
            output_mask = gr.Video(format='mp4', label="Mask", show_download_button=True)

        input_video.upload(load_first_img, [input_video], [first_img])
        first_img.select(get_select_coords, [first_img, input_video], [input_x_cord, input_y_cord])
        text_button.click(sam2_video, [input_video, input_x_cord, input_y_cord], [output_video, output_mask])

    demo.launch()

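Two helper callbacks wired above, load_first_img() and get_select_coords(), are not shown in the snippet. A minimal sketch of what they might look like (the originals may differ): the first reads the opening frame of the uploaded video for display, and the second converts a click on that image into (x, y) coordinates via Gradio's gr.SelectData.

def load_first_img(video):
    # Read the first frame of the uploaded video and return it as an RGB image
    cap = cv2.VideoCapture(video)
    ok, frame = cap.read()
    cap.release()
    if not ok:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

def get_select_coords(img, video, evt: gr.SelectData):
    # evt.index holds the (x, y) pixel position of the click on the displayed image
    x, y = evt.index[0], evt.index[1]
    return str(x), str(y)
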
Follow-up Development

This blog demonstrates SAM2 with a single point prompt. You can extend it to multiple prompts as needed (for example, by using gr.render to accept a user-defined number of prompts), as sketched below. In addition, event listeners combined with gr.State can let the user step through frames in the UI with buttons, making it easier to place prompts on multiple frames.

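For instance, extending the single click to several positive and negative clicks on one object only requires passing more rows to add_new_points_or_box(); the coordinates below are made up for illustration:

# Hypothetical: two positive clicks and one negative click on frame 0
points = np.array([[210, 350], [250, 220], [400, 100]], dtype=np.float32)
labels = np.array([1, 1, 0], np.int32)  # 1 = include, 0 = exclude
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
)
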
A more practical and stable implementation involves using a mask as the prompt. To achieve this, first use predictor.predict() to generate a mask from a point or box prompt with the SAM2 image predictor (see the official tutorial image_predictor_example.ipynb), then pass that mask to the video predictor as a mask prompt for further inference (for mask prompts, see the official tutorial automatic_mask_generator_example.ipynb). A rough sketch follows.
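
The sketch below assumes the SAM2 image predictor API (SAM2ImagePredictor.set_image() / predict()) and the video predictor's add_new_mask() from the official repository; the frame path and click coordinates are placeholders.

import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# 1. Generate a mask for the first frame from a point prompt with the image predictor
image_predictor = SAM2ImagePredictor(build_sam2(model_cfg, sam2_checkpoint, device=device))
first_frame = cv2.cvtColor(cv2.imread("frames/00000.jpg"), cv2.COLOR_BGR2RGB)
image_predictor.set_image(first_frame)
masks, scores, _ = image_predictor.predict(
    point_coords=np.array([[210, 350]], dtype=np.float32),
    point_labels=np.array([1], np.int32),
    multimask_output=False,
)

# 2. Feed that mask to the video predictor as a mask prompt, then propagate as before
_, out_obj_ids, out_mask_logits = predictor.add_new_mask(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    mask=masks[0].astype(bool),
)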