Skip to content

Commit bc2ca14

Browse files
authored
Benchmark encoding against ffmpeg cli (#1074)
1 parent 4eb36b4 commit bc2ca14

File tree

1 file changed

+227
-0
lines changed

1 file changed

+227
-0
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
#!/usr/bin/env python3
2+
import shutil
3+
import subprocess
4+
import tempfile
5+
from argparse import ArgumentParser
6+
from pathlib import Path
7+
from time import perf_counter_ns
8+
9+
import pynvml
10+
import torch
11+
from torchcodec.decoders import VideoDecoder
12+
from torchcodec.encoders import VideoEncoder
13+
14+
# NVML is only needed to sample GPU metrics in the GPU benchmarks. Initialise
# it defensively so that importing/running this script on a machine without an
# NVIDIA driver still allows the CPU benchmarks to run (main() skips the GPU
# benchmarks when CUDA is unavailable, so `handle` is never used in that case).
try:
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
except pynvml.NVMLError:
    handle = None

# Frame rate used for every encode, both through torchcodec and the FFmpeg CLI.
FRAME_RATE = 30
DEFAULT_VIDEO_PATH = "test/resources/nasa_13013.mp4"
# Alternatively, run this command to generate a longer test video:
# ffmpeg -f lavfi -i testsrc2=duration=600:size=1280x720:rate=30 -c:v libx264 -pix_fmt yuv420p test/resources/testsrc2_10min.mp4
21+
22+
23+
def bench(f, average_over=50, warmup=2, gpu_monitoring=False, **f_kwargs):
    """Time repeated calls of ``f(**f_kwargs)`` and optionally sample NVML stats.

    Args:
        f: Callable to benchmark; invoked as ``f(**f_kwargs)``.
        average_over: Number of timed calls.
        warmup: Number of untimed calls made before timing starts.
        gpu_monitoring: When true, after each timed call the encoder
            utilization and used device memory are read through ``pynvml``
            using the module-level ``handle``.
        **f_kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        A ``(times, metrics)`` tuple. ``times`` is a float tensor holding one
        duration per timed call, in nanoseconds. ``metrics`` is a dict with
        keys ``"utilization"`` and ``"memory_used"`` (MB); both values are
        float tensors when ``gpu_monitoring`` is true and ``None`` otherwise.
    """
    # Untimed warmup calls.
    for _ in range(warmup):
        f(**f_kwargs)

    elapsed_ns = []
    encoder_util = []
    mem_used_mb = []

    for _ in range(average_over):
        t0 = perf_counter_ns()
        f(**f_kwargs)
        t1 = perf_counter_ns()
        elapsed_ns.append(t1 - t0)

        if not gpu_monitoring:
            continue

        # The metrics are sampled outside the timed region.
        encoder_util.append(pynvml.nvmlDeviceGetEncoderUtilization(handle)[0])
        used_bytes = pynvml.nvmlDeviceGetMemoryInfo(handle).used
        mem_used_mb.append(used_bytes / 1_000_000)  # bytes -> MB

    times_tensor = torch.tensor(elapsed_ns).float()

    if gpu_monitoring:
        metrics = {
            "utilization": torch.tensor(encoder_util).float(),
            "memory_used": torch.tensor(mem_used_mb).float(),
        }
    else:
        metrics = {"utilization": None, "memory_used": None}

    return times_tensor, metrics
49+
50+
51+
def report_stats(times, num_frames, nvenc_metrics=None, prefix="", unit="ms"):
    """Print median/max run time, throughput and, optionally, GPU metrics.

    Args:
        times: Tensor of per-run durations in nanoseconds.
        num_frames: Number of frames encoded per run; used to derive fps from
            the median run time.
        nvenc_metrics: Optional dict with ``"memory_used"`` and
            ``"utilization"`` tensors. If it is ``None``, or either value is
            ``None`` (which is what ``bench`` returns when GPU monitoring is
            off), the GPU lines are not printed.
        prefix: Label printed before the timing statistics.
        unit: Display unit for the timings: ``"ns"``, ``"µs"``, ``"ms"`` or
            ``"s"``.

    Raises:
        KeyError: If ``unit`` is not one of the supported units.
    """
    fps = (num_frames * 1e9 / times.median()).item()

    # Multiplier converting nanoseconds to the requested display unit.
    ns_to_unit = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    unit_times = times * ns_to_unit
    # Named *_time rather than `max` to avoid shadowing the builtin.
    median_time = unit_times.median().item()
    max_time = unit_times.max().item()
    print(
        f"\n{prefix} med = {median_time:.2f} {unit}, "
        f"max = {max_time:.2f} {unit}, fps = {fps:.1f}"
    )

    if nvenc_metrics is None:
        return

    mem_used = nvenc_metrics["memory_used"]
    utilization = nvenc_metrics["utilization"]
    if mem_used is None or utilization is None:
        # Metrics dict from a run without GPU monitoring: nothing to report.
        return

    mem_used_median = mem_used.median().item()
    mem_used_max = mem_used.max().item()
    util_median = utilization.median().item()
    util_max = utilization.max().item()

    print(
        f"GPU memory used: med = {mem_used_median:.1f} MB, max = {mem_used_max:.1f} MB"
    )
    # Values are pre-computed so the f-string needs no nested quotes, which
    # are a SyntaxError on Python < 3.12.
    print(f"NVENC utilization: med = {util_median:.1f}%, max = {util_max:.1f}%")
76+
77+
78+
def encode_torchcodec(frames, output_path, device="cpu"):
    """Encode ``frames`` to ``output_path`` with torchcodec's ``VideoEncoder``.

    Args:
        frames: Frames handed to ``VideoEncoder``.
        output_path: Destination file passed as ``dest``.
        device: ``"cuda"`` selects the ``h264_nvenc`` codec with ``qp=0``;
            any other value selects ``libx264`` with ``crf=0``.
    """
    encoder = VideoEncoder(frames=frames, frame_rate=FRAME_RATE)

    # Pick the codec-specific settings first, then issue a single to_file call.
    if device == "cuda":
        codec_settings = {"codec": "h264_nvenc", "extra_options": {"qp": 0}}
    else:
        codec_settings = {"codec": "libx264", "crf": 0}

    encoder.to_file(dest=output_path, **codec_settings)
84+
85+
86+
def write_raw_frames(frames, raw_path):
    """Dump ``frames`` to ``raw_path`` as one contiguous blob of raw bytes.

    Args:
        frames: 4-D tensor; dimensions are reordered with
            ``permute(0, 2, 3, 1)`` before writing (NCHW -> NHWC, assuming
            the input is laid out as NCHW).
        raw_path: Path of the file to create or overwrite.
    """
    with open(raw_path, "wb") as sink:
        # Channels-last layout is what a raw video stream expects.
        channels_last = frames.permute(0, 2, 3, 1)
        payload = channels_last.cpu().numpy().tobytes()
        sink.write(payload)
91+
92+
93+
def encode_ffmpeg_cli(
    frames, raw_path, output_path, device="cpu", skip_write_frames=False
):
    """Encode raw RGB frames to ``output_path`` by shelling out to ``ffmpeg``.

    Args:
        frames: Frame tensor; ``shape[3]`` and ``shape[2]`` are used as the
            width and height given to ffmpeg.
        raw_path: Path of the raw ``rgb24`` input file read by ffmpeg.
        output_path: Destination video file.
        device: ``"cuda"`` encodes with ``h264_nvenc -qp 0``; anything else
            encodes with ``libx264 -crf 0``.
        skip_write_frames: When false (the default) the frames are written to
            ``raw_path`` as part of this call; when true the file is expected
            to exist already.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # By default the raw dump is part of the benchmarked work.
    if not skip_write_frames:
        write_raw_frames(frames, raw_path)

    if device == "cuda":
        codec, quality_args = "h264_nvenc", ["-qp", "0"]
    else:
        codec, quality_args = "libx264", ["-crf", "0"]

    height, width = frames.shape[2], frames.shape[3]

    # Options describing the raw input stream.
    input_args = [
        "-f", "rawvideo",
        "-pix_fmt", "rgb24",
        "-s", f"{width}x{height}",
        "-r", str(FRAME_RATE),
        "-i", raw_path,
    ]
    # Options describing the encoded output.
    output_args = [
        "-c:v", codec,
        "-pix_fmt", "yuv420p",
        *quality_args,
        str(output_path),
    ]

    subprocess.run(
        ["ffmpeg", "-y", *input_args, *output_args],
        check=True,
        capture_output=True,
    )
121+
122+
123+
def main():
    """Benchmark torchcodec's ``VideoEncoder`` against the FFmpeg CLI.

    Decodes the input video, then times encoding through both paths on the
    CPU and, when CUDA is available, on the GPU, printing statistics for each.
    All outputs go to a temporary directory that is removed before returning,
    including when a benchmark raises.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--path", type=str, help="Path to input video file", default=DEFAULT_VIDEO_PATH
    )
    parser.add_argument(
        "--average-over",
        type=int,
        default=30,
        help="Number of runs to average over",
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=None,
        help="Maximum number of frames to decode for benchmarking. By default, all frames will be decoded.",
    )
    parser.add_argument(
        "--skip-write-frames",
        action="store_true",
        help="Do not write raw frames in FFmpeg CLI benchmarks",
    )
    args = parser.parse_args()
    decoder = VideoDecoder(str(args.path))
    frames = decoder.get_frames_in_range(start=0, stop=args.max_frames).data

    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        print("CUDA not available. GPU benchmarks will be skipped.")

    print(
        f"Benchmarking {len(frames)} frames from {Path(args.path).name} over {args.average_over} runs:"
    )
    gpu_frames = frames.cuda() if cuda_available else None
    print(
        f"Decoded {frames.shape[0]} frames of size {frames.shape[2]}x{frames.shape[3]}"
    )

    temp_dir = Path(tempfile.mkdtemp())
    # try/finally guarantees the temporary directory is removed even if one of
    # the benchmarks fails (e.g. ffmpeg exits with an error).
    try:
        raw_frames_path = temp_dir / "input_frames.raw"

        # If skip_write_frames is True, we will not benchmark the time it takes to write the frames.
        # Here, we still write the frames for FFmpeg to use!
        if args.skip_write_frames:
            write_raw_frames(frames, str(raw_frames_path))

        if cuda_available:
            # Benchmark torchcodec on GPU
            gpu_output = temp_dir / "torchcodec_gpu.mp4"
            times, nvenc_metrics = bench(
                encode_torchcodec,
                frames=gpu_frames,
                output_path=str(gpu_output),
                device="cuda",
                gpu_monitoring=True,
                average_over=args.average_over,
            )
            report_stats(
                times, frames.shape[0], nvenc_metrics, prefix="VideoEncoder on GPU"
            )
            # Benchmark FFmpeg CLI on GPU
            ffmpeg_gpu_output = temp_dir / "ffmpeg_gpu.mp4"
            times, nvenc_metrics = bench(
                encode_ffmpeg_cli,
                frames=gpu_frames,
                raw_path=str(raw_frames_path),
                output_path=str(ffmpeg_gpu_output),
                device="cuda",
                gpu_monitoring=True,
                skip_write_frames=args.skip_write_frames,
                average_over=args.average_over,
            )
            prefix = "FFmpeg CLI on GPU  "
            report_stats(times, frames.shape[0], nvenc_metrics, prefix=prefix)

        # Benchmark torchcodec on CPU
        cpu_output = temp_dir / "torchcodec_cpu.mp4"
        times, _nvenc_metrics = bench(
            encode_torchcodec,
            frames=frames,
            output_path=str(cpu_output),
            device="cpu",
            average_over=args.average_over,
        )
        report_stats(times, frames.shape[0], prefix="VideoEncoder on CPU")

        # Benchmark FFmpeg CLI on CPU
        ffmpeg_cpu_output = temp_dir / "ffmpeg_cpu.mp4"
        times, _nvenc_metrics = bench(
            encode_ffmpeg_cli,
            frames=frames,
            raw_path=str(raw_frames_path),
            output_path=str(ffmpeg_cpu_output),
            device="cpu",
            skip_write_frames=args.skip_write_frames,
            average_over=args.average_over,
        )
        prefix = "FFmpeg CLI on CPU  "
        report_stats(times, frames.shape[0], prefix=prefix)
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
224+
225+
226+
# Script entry point: run the benchmarks only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)