Commit 943dc6e

Lazily init filtergraph so it can respect raw decoded resolution (#432)
1 parent 84cef50 commit 943dc6e

File tree

2 files changed (+49, -48 lines)

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 43 additions & 43 deletions
@@ -101,7 +101,7 @@ std::vector<std::string> splitStringWithDelimiters(
   return result;
 }
 
-VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
+VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
     int width) {
   VideoDecoder::ColorConversionLibrary library =
       VideoDecoder::ColorConversionLibrary::SWSCALE;
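For context, a minimal standalone sketch of a width-based default of this kind. The actual width check sits in unchanged lines below this hunk, so the 32-pixel alignment cutoff and the local enum here are assumptions for illustration only:

#include <iostream>

// Stand-in for VideoDecoder::ColorConversionLibrary (illustrative only).
enum class ColorConversionLibrary { SWSCALE, FILTERGRAPH };

// Hypothetical default: prefer swscale (faster), but fall back to
// filtergraph for widths that violate swscale's assumed alignment needs.
ColorConversionLibrary getDefaultColorConversionLibrary(int width) {
  return (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
                           : ColorConversionLibrary::FILTERGRAPH;
}

int main() {
  // 1920 is 32-aligned, 1918 is not.
  std::cout << (getDefaultColorConversionLibrary(1920) ==
                ColorConversionLibrary::SWSCALE)
            << (getDefaultColorConversionLibrary(1918) ==
                ColorConversionLibrary::FILTERGRAPH)
            << "\n"; // prints 11
}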
@@ -121,7 +121,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
 // or 4D.
 // Calling permute() is guaranteed to return a view as per the docs:
 // https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
+torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
     int streamIndex,
     torch::Tensor& hwcTensor) {
   if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
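As the surrounding comment notes, permute() returns a view, so this rename changes no behavior. A sketch of the permutation itself; the free-function name is hypothetical, and the real method also consults the stream's dimensionOrder option:

#include <torch/torch.h>
#include <iostream>

// Hypothetical helper: handles both a single HWC frame and an NHWC batch.
// permute() returns a view, so no pixel data is copied.
torch::Tensor hwcToCHW(const torch::Tensor& hwc) {
  return hwc.dim() == 3 ? hwc.permute({2, 0, 1}) // HWC -> CHW
                        : hwc.permute({0, 3, 1, 2}); // NHWC -> NCHW
}

int main() {
  auto frame = torch::zeros({480, 640, 3}, torch::kUInt8);
  std::cout << hwcToCHW(frame).sizes() << "\n"; // [3, 480, 640]
}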
@@ -299,31 +299,32 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
   return decoder;
 }
 
-void VideoDecoder::initializeFilterGraphForStream(
-    int streamIndex,
-    const VideoStreamDecoderOptions& options) {
-  FilterState& filterState = streams_[streamIndex].filterState;
+void VideoDecoder::initializeFilterGraph(
+    StreamInfo& streamInfo,
+    int expectedOutputHeight,
+    int expectedOutputWidth) {
+  FilterState& filterState = streamInfo.filterState;
   if (filterState.filterGraph) {
     return;
   }
 
   filterState.filterGraph.reset(avfilter_graph_alloc());
   TORCH_CHECK(filterState.filterGraph.get() != nullptr);
-  if (options.ffmpegThreadCount.has_value()) {
-    filterState.filterGraph->nb_threads = options.ffmpegThreadCount.value();
+  if (streamInfo.options.ffmpegThreadCount.has_value()) {
+    filterState.filterGraph->nb_threads =
+        streamInfo.options.ffmpegThreadCount.value();
   }
 
   const AVFilter* buffersrc = avfilter_get_by_name("buffer");
   const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-  const StreamInfo& activeStream = streams_[streamIndex];
-  AVCodecContext* codecContext = activeStream.codecContext.get();
+  AVCodecContext* codecContext = streamInfo.codecContext.get();
 
   std::stringstream filterArgs;
   filterArgs << "video_size=" << codecContext->width << "x"
              << codecContext->height;
   filterArgs << ":pix_fmt=" << codecContext->pix_fmt;
-  filterArgs << ":time_base=" << activeStream.stream->time_base.num << "/"
-             << activeStream.stream->time_base.den;
+  filterArgs << ":time_base=" << streamInfo.stream->time_base.num << "/"
+             << streamInfo.stream->time_base.den;
   filterArgs << ":pixel_aspect=" << codecContext->sample_aspect_ratio.num << "/"
              << codecContext->sample_aspect_ratio.den;
 
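To make the args string concrete: for a hypothetical 1920x1080 yuv420p stream (AV_PIX_FMT_YUV420P serializes as its enum value, 0) with a 1/30000 time base and square pixels, the code above produces the string shown in the comment below, sketched standalone:

#include <iostream>
#include <sstream>

int main() {
  // Hypothetical stream properties, mirroring the fields read above.
  int width = 1920, height = 1080;
  int pixFmt = 0; // AV_PIX_FMT_YUV420P streams as its integer enum value
  int tbNum = 1, tbDen = 30000;
  int sarNum = 1, sarDen = 1;

  std::stringstream filterArgs;
  filterArgs << "video_size=" << width << "x" << height;
  filterArgs << ":pix_fmt=" << pixFmt;
  filterArgs << ":time_base=" << tbNum << "/" << tbDen;
  filterArgs << ":pixel_aspect=" << sarNum << "/" << sarDen;

  // video_size=1920x1080:pix_fmt=0:time_base=1/30000:pixel_aspect=1/1
  std::cout << filterArgs.str() << "\n";
}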
@@ -378,10 +379,8 @@ void VideoDecoder::initializeFilterGraphForStream(
   inputs->pad_idx = 0;
   inputs->next = nullptr;
 
-  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
-      options, containerMetadata_.streams[streamIndex]);
   std::stringstream description;
-  description << "scale=" << frameDims.width << ":" << frameDims.height;
+  description << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
   description << ":sws_flags=bilinear";
 
   AVFilterInOut* outputsTmp = outputs.release();
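The description string built above (e.g. "scale=640:480:sws_flags=bilinear" for hypothetical 640x480 output) is the textual form of the graph. A minimal sketch, assuming the usual libavfilter parse entry point; the surrounding code's exact wiring of the buffer source and sink pads is not shown in this hunk:

extern "C" {
#include <libavfilter/avfilter.h>
}
#include <iostream>

int main() {
  // Parse a bilinear scale filter into an empty graph. Real use links the
  // open input/output pads to the buffer source and buffersink filters.
  AVFilterGraph* graph = avfilter_graph_alloc();
  AVFilterInOut* inputs = nullptr;
  AVFilterInOut* outputs = nullptr;
  int ret = avfilter_graph_parse_ptr(
      graph, "scale=640:480:sws_flags=bilinear", &inputs, &outputs, nullptr);
  std::cout << "parse returned " << ret << "\n"; // >= 0 on success
  avfilter_inout_free(&inputs);
  avfilter_inout_free(&outputs);
  avfilter_graph_free(&graph);
}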
@@ -469,26 +468,16 @@ void VideoDecoder::addVideoStreamDecoder(
   streamInfo.options = options;
   int width = options.width.value_or(codecContext->width);
 
-  // Use swscale for color conversion by default because it is faster.
-  VideoDecoder::ColorConversionLibrary defaultColorConversionLibrary =
-      getDefaultColorConversionLibraryForWidth(width);
-  // If the user specifies the color conversion library (example in
-  // benchmarks), we use that instead.
-  auto colorConversionLibrary =
-      options.colorConversionLibrary.value_or(defaultColorConversionLibrary);
-
-  if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    initializeFilterGraphForStream(streamNumber, options);
-    streamInfo.colorConversionLibrary = ColorConversionLibrary::FILTERGRAPH;
-  } else if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-    streamInfo.colorConversionLibrary = ColorConversionLibrary::SWSCALE;
-  } else {
-    throw std::invalid_argument(
-        "Invalid colorConversionLibrary=" +
-        std::to_string(static_cast<int>(colorConversionLibrary)) +
-        ". colorConversionLibrary must be either "
-        "filtergraph or swscale.");
-  }
+  // By default, we use swscale for color conversion because it is faster.
+  // However, swscale has width requirements, so we may need to fall back to
+  // filtergraph. We also honor the library requested via the options, and we
+  // honor it unconditionally, so it's possible for swscale's width
+  // requirements to be violated. We don't expose the ability to choose the
+  // color conversion library publicly; we only use it internally.
+  auto defaultLibrary = getDefaultColorConversionLibrary(width);
+  streamInfo.colorConversionLibrary =
+      options.colorConversionLibrary.value_or(defaultLibrary);
 }
 
 void VideoDecoder::updateMetadataWithCodecContext(
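The value_or line is now the entire selection policy: an explicit request always wins, otherwise the width-based default applies. A tiny standalone sketch with stand-in names, not the decoder's API:

#include <iostream>
#include <optional>

// Stand-in types; the decoder's real option plumbing differs.
enum class Library { SWSCALE, FILTERGRAPH };

// An explicitly requested library always wins; otherwise the
// width-based default applies.
Library resolveLibrary(std::optional<Library> requested, Library byWidth) {
  return requested.value_or(byWidth);
}

int main() {
  // Unset option: fall through to the default.
  std::cout << (resolveLibrary(std::nullopt, Library::SWSCALE) ==
                Library::SWSCALE);
  // Explicit request: honored even where swscale's width assumptions fail.
  std::cout << (resolveLibrary(Library::FILTERGRAPH, Library::SWSCALE) ==
                Library::FILTERGRAPH)
            << "\n"; // prints 11
}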
@@ -938,6 +927,17 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   } else if (
       streamInfo.colorConversionLibrary ==
       ColorConversionLibrary::FILTERGRAPH) {
+    // Note that this is a lazy init; we initialize the filtergraph the first
+    // time we have a raw decoded frame. We do this lazily because up until
+    // this point, we really don't know what the resolution of the frames is
+    // without modification. In theory, we should be able to get that from the
+    // stream metadata, but in practice, we have encountered videos where the
+    // stream metadata had a different resolution from the actual resolution
+    // of the raw decoded frames.
+    if (!streamInfo.filterState.filterGraph) {
+      initializeFilterGraph(
+          streamInfo, expectedOutputHeight, expectedOutputWidth);
+    }
     outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
     // Similarly to above, if this check fails it means the frame wasn't
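The lazy init works because, once a frame has actually been decoded, its true dimensions are available on the AVFrame itself, regardless of what the stream metadata claimed. A sketch with a hypothetical mismatch; the fields are real FFmpeg ones, the scenario is illustrative:

#include <iostream>
extern "C" {
#include <libavutil/frame.h>
}

int main() {
  AVFrame* frame = av_frame_alloc();
  // Pretend the decoder produced a 1920x800 frame while the stream
  // metadata claimed 1920x1080 (hypothetical mismatch). The AVFrame
  // fields reflect the raw decoded picture, not the container metadata.
  frame->width = 1920;
  frame->height = 800;
  std::cout << "actual: " << frame->width << "x" << frame->height << "\n";
  av_frame_free(&frame);
}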
@@ -952,6 +952,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
         expectedOutputWidth,
         "x3, got ",
         shape);
+
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
       // outputTensor have the same shape.
@@ -965,7 +966,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
         "Invalid color conversion library: " +
         std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
     }
-
   } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
     // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
     // audio decoding.
@@ -1007,7 +1007,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
   });
   // Convert the frame to tensor.
   auto output = convertAVFrameToDecodedOutput(rawOutput);
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }
 
@@ -1045,7 +1045,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex) {
   auto output = getFrameAtIndexInternal(streamIndex, frameIndex);
-  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }
 
@@ -1118,7 +1118,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     }
     previousIndexInVideo = indexInVideo;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
 
@@ -1193,7 +1193,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
 
@@ -1246,7 +1246,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
-    output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+    output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
 
@@ -1287,7 +1287,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
 
   return output;
 }
@@ -1303,7 +1303,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
 
 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
   auto output = getNextFrameOutputNoDemuxInternal();
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }
 
src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 6 additions & 5 deletions
@@ -157,7 +157,7 @@ class VideoDecoder {
       int streamIndex,
       const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
 
-  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
 
   // ---- SINGLE FRAME SEEK AND DECODING API ----
   // Places the cursor at the first frame on or after the position in seconds.
@@ -376,9 +376,10 @@ class VideoDecoder {
   void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex);
   // Creates and initializes a filter graph for a stream. The filter graph can
   // do rescaling and color conversion.
-  void initializeFilterGraphForStream(
-      int streamIndex,
-      const VideoStreamDecoderOptions& options);
+  void initializeFilterGraph(
+      StreamInfo& streamInfo,
+      int expectedOutputHeight,
+      int expectedOutputWidth);
   void maybeSeekToBeforeDesiredPts();
   RawDecodedOutput getDecodedOutputWithFilter(
       std::function<bool(int, AVFrame*)>);
@@ -436,7 +437,7 @@
 // We always allocate [N]HWC tensors. The low-level decoding functions all
 // assume HWC tensors, since this is what FFmpeg natively handles. It's up to
 // the high-level decoding entry-points to permute that back to CHW, by calling
-// MaybePermuteHWC2CHW().
+// maybePermuteHWC2CHW().
 //
 // Also, importantly, the way we figure out the height and width of the
 // output frame tensor varies, and depends on the decoding entry-point. In
