dgx-spark-playbooks/nvidia/station-gr00t/assets/patches/001-pyav-get-frames-by-indices.patch
2026-05-26 18:25:53 +00:00

44 lines
1.7 KiB
Diff

diff --git a/gr00t/utils/video_utils.py b/gr00t/utils/video_utils.py
index b3571e5..1cbd256 100644
--- a/gr00t/utils/video_utils.py
+++ b/gr00t/utils/video_utils.py
@@ -366,6 +366,47 @@ def _extract_all_frames_ffmpeg(video_path: str) -> tuple[np.ndarray, np.ndarray]
         raise ValueError(f"Failed to extract frames from {video_path}: {e}")
 
 
+def _extract_frames_pyav_by_indices(video_path: str, frame_indices: list[int]) -> np.ndarray:
+    """Extract RGB frames at integer indices using PyAV (sequential decode).
+
+    Used when torchcodec/decord are unavailable. Short LIBERO / LeRobot clips are
+    typically small enough that scanning from the start beats per-frame ffmpeg.
+
+    Args:
+        video_path: Path passed to ``av.open``; frames come from its first video stream.
+        frame_indices: 0-based decode-order frame indices. May be unsorted and may
+            contain duplicates; each entry is converted with ``int()``.
+
+    Returns:
+        ``rgb24`` frames stacked along a new leading axis, in ``frame_indices``
+        order. An empty ``frame_indices`` returns an empty uint8 array without
+        opening the file.
+
+    Raises:
+        ValueError: If any requested index is not reached before the stream ends.
+    """
+    idx_list = [int(i) for i in frame_indices]
+    if not idx_list:
+        # max() and np.stack() both raise on an empty sequence.
+        return np.empty((0,), dtype=np.uint8)
+    need = set(idx_list)
+    max_idx = max(idx_list)
+    collected: dict[int, np.ndarray] = {}
+    with av.open(video_path) as container:
+        stream = container.streams.video[0]
+        for i, frame in enumerate(container.decode(stream)):
+            if i in need:
+                collected[i] = frame.to_ndarray(format="rgb24")
+            # Every requested index is <= max_idx, so once it is reached nothing
+            # more can be collected; stop instead of decoding the rest of the clip.
+            if i >= max_idx:
+                break
+    missing = need.difference(collected)
+    if missing:
+        raise ValueError(f"PyAV could not read frame indices {sorted(missing)} from {video_path}")
+    return np.stack([collected[i] for i in idx_list])
+
+
def get_frames_by_indices(
video_path: str,
indices: list[int] | np.ndarray,
@@ -398,6 +421,8 @@ def get_frames_by_indices(
cap.release()
frames = np.array(frames)
return frames
+ elif video_backend == "pyav":
+ return _extract_frames_pyav_by_indices(video_path, list(indices))
else:
raise NotImplementedError