Visualized SAM2 Video Segmentation with Gradio
This blog explains a simple implementation of a Gradio-based UI for
the SAM2 (Segment Anything 2)
video segmentation model. The code processes input videos, performs
object segmentation using SAM2, and visualizes results. Let's break down
the key components:
Imports & Device Configuration
import gradio as gr
import cv2, os, numpy as np
import torch
# ... other imports ...
Imports the necessary libraries for computer vision, deep learning, and UI creation
Automatically selects the best available computation device (CUDA GPU > CPU)
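The device-selection line itself is elided in the snippet above; a minimal version consistent with that description (a sketch, not the original code) would be:

# Prefer a CUDA GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running SAM2 on: {device}")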
Video Processing Utilities
def video2images(video, output_folder):
    # Video to frame extraction logic
    # ... (skipped for brevity) ...

def images2video(image_folder, output_video_path, fps=24):
    # Frame to video conversion logic
    # ... (skipped for brevity) ...
Key Features:
video2images(): Extracts frames from the video, skipping the first frame (assumed to be an overlay grid)
images2video(): Reconstructs the video from the processed frames at the specified FPS
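Since the function bodies are elided in the post, here is one possible way these helpers could be written with OpenCV (the frame-naming scheme and exact skip logic are assumptions):

import cv2, os

def video2images(video, output_folder):
    # Extract frames from the input video as numbered JPEGs
    os.makedirs(output_folder, exist_ok=True)
    cap = cv2.VideoCapture(video)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx > 0:  # skip the first frame (assumed overlay grid)
            cv2.imwrite(os.path.join(output_folder, f"{idx:05d}.jpg"), frame)
        idx += 1
    cap.release()

def images2video(image_folder, output_video_path, fps=24):
    # Stitch the processed frames back into an mp4 at the given FPS
    frames = sorted(f for f in os.listdir(image_folder) if f.endswith(".jpg"))
    first = cv2.imread(os.path.join(image_folder, frames[0]))
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(output_video_path,
                             cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for name in frames:
        writer.write(cv2.imread(os.path.join(image_folder, name)))
    writer.release()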
SAM2 Inference Pipeline
Let's take a closer look at the core SAM2 inference implementation
that powers the video segmentation. This section contains the most
critical logic for processing user inputs and generating segmentation
masks.
Model Initialization
def sam2_inference(images_path, point_x, point_y, save_dir_name):
    # Model configuration
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    sam2_checkpoint = "/path/to/checkpoint.pt"

    # Handle different compute architectures
    if device.type == "cuda":
        torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    # Initialize predictor with proper device handling
    from sam2.build_sam import build_sam2_video_predictor
    predictor = build_sam2_video_predictor(
        model_cfg, sam2_checkpoint, device=device
    )
Key Components:
Loads model configuration from YAML file
Initializes SAM2 predictor with pretrained weights
    # Initialize inference state with the first video frame
    inference_state = predictor.init_state(video_path=images_path)
    predictor.reset_state(inference_state)

    # Process user click coordinates
    ann_frame_idx = 0  # First frame index
    ann_obj_id = 1     # First object ID

    # Convert user click to model input
    points = np.array([[int(point_x), int(point_y)]], dtype=np.float32)
    labels = np.array([1], np.int32)  # 1 = positive prompt
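The excerpt ends here, but in the SAM2 video predictor API the click prompt is typically registered on the annotation frame and then propagated through the rest of the video. A sketch of that continuation (the mask-saving step is an assumption, since that part of the original code is not shown):

    # Register the click prompt on the first frame
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

    # Propagate the prompt through the remaining frames and collect masks
    os.makedirs(save_dir_name, exist_ok=True)
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        mask = (out_mask_logits[0] > 0.0).cpu().numpy().squeeze()
        # ... overlay `mask` on the corresponding frame and save it into save_dir_name ...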
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            input_video = gr.Video(format='mp4', label='Source Video')
            first_img = gr.Image(label="First Image")
        with gr.Row():
            input_x_cord = gr.Textbox(label="X Cord")
            input_y_cord = gr.Textbox(label="Y Cord")
        with gr.Row():
            text_button = gr.Button("Submit")
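The snippet stops before the event wiring; one way the remaining pieces might be connected is sketched below (the output component and the process_video wrapper are assumptions, not part of the original code):

        # Hypothetical output component and click handler, inside the same gr.Blocks()
        output_video = gr.Video(label="Segmented Video")

        # process_video is an assumed wrapper that chains video2images,
        # sam2_inference, and images2video, returning the path of the result video
        text_button.click(
            fn=process_video,
            inputs=[input_video, input_x_cord, input_y_cord],
            outputs=output_video,
        )

    demo.launch()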
This blog demonstrates the implementation of SAM2 with a single point prompt. You can extend it to multiple prompts as needed (for example, by using gr.render to accept a custom number of prompts). In addition, event listeners combined with gr.State can control which frame is displayed in the UI via buttons, making it easier to place prompts on multiple frames.
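As a rough illustration of that idea, the components below (hypothetical names, assumed to be declared inside the gr.Blocks() context, with frames read from the folder produced by video2images()) keep the current frame index in gr.State and step through it with buttons:

frame_idx = gr.State(value=0)
prev_btn = gr.Button("Prev Frame")
next_btn = gr.Button("Next Frame")

def show_frame(idx, step):
    # Clamp at zero and load the corresponding extracted frame
    idx = max(0, idx + step)
    return idx, os.path.join("frames", f"{idx:05d}.jpg")

prev_btn.click(lambda i: show_frame(i, -1), inputs=frame_idx,
               outputs=[frame_idx, first_img])
next_btn.click(lambda i: show_frame(i, 1), inputs=frame_idx,
               outputs=[frame_idx, first_img])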