@@ -101,7 +101,7 @@ std::vector<std::string> splitStringWithDelimiters(
   return result;
 }

-VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
+VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
     int width) {
   VideoDecoder::ColorConversionLibrary library =
       VideoDecoder::ColorConversionLibrary::SWSCALE;
@@ -121,7 +121,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
 // or 4D.
 // Calling permute() is guaranteed to return a view as per the docs:
 // https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
+torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
     int streamIndex,
     torch::Tensor& hwcTensor) {
   if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
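For reference, a minimal sketch of the permute logic this method applies for the 3D and 4D cases, assuming the libtorch C++ API; the standalone function name and check message below are hypothetical, not the file's actual code:

  // Hypothetical sketch of HWC -> CHW conversion; permute() returns a view.
  torch::Tensor permuteHWC2CHW(const torch::Tensor& hwcTensor) {
    if (hwcTensor.dim() == 3) {
      // (H, W, C) -> (C, H, W)
      return hwcTensor.permute({2, 0, 1});
    }
    // (N, H, W, C) -> (N, C, H, W)
    TORCH_CHECK(hwcTensor.dim() == 4, "Expected a 3D or 4D tensor");
    return hwcTensor.permute({0, 3, 1, 2});
  }
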
@@ -299,31 +299,32 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
   return decoder;
 }

-void VideoDecoder::initializeFilterGraphForStream(
-    int streamIndex,
-    const VideoStreamDecoderOptions& options) {
-  FilterState& filterState = streams_[streamIndex].filterState;
+void VideoDecoder::initializeFilterGraph(
+    StreamInfo& streamInfo,
+    int expectedOutputHeight,
+    int expectedOutputWidth) {
+  FilterState& filterState = streamInfo.filterState;
   if (filterState.filterGraph) {
     return;
   }

   filterState.filterGraph.reset(avfilter_graph_alloc());
   TORCH_CHECK(filterState.filterGraph.get() != nullptr);
-  if (options.ffmpegThreadCount.has_value()) {
-    filterState.filterGraph->nb_threads = options.ffmpegThreadCount.value();
+  if (streamInfo.options.ffmpegThreadCount.has_value()) {
+    filterState.filterGraph->nb_threads =
+        streamInfo.options.ffmpegThreadCount.value();
   }

   const AVFilter* buffersrc = avfilter_get_by_name("buffer");
   const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-  const StreamInfo& activeStream = streams_[streamIndex];
-  AVCodecContext* codecContext = activeStream.codecContext.get();
+  AVCodecContext* codecContext = streamInfo.codecContext.get();

   std::stringstream filterArgs;
   filterArgs << "video_size=" << codecContext->width << "x"
              << codecContext->height;
   filterArgs << ":pix_fmt=" << codecContext->pix_fmt;
-  filterArgs << ":time_base=" << activeStream.stream->time_base.num << "/"
-             << activeStream.stream->time_base.den;
+  filterArgs << ":time_base=" << streamInfo.stream->time_base.num << "/"
+             << streamInfo.stream->time_base.den;
   filterArgs << ":pixel_aspect=" << codecContext->sample_aspect_ratio.num << "/"
              << codecContext->sample_aspect_ratio.den;

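The filterArgs string initializes the "buffer" source filter; note that pix_fmt streams as the integer value of the AVPixelFormat enum. For a hypothetical 1920x1080 yuv420p stream with a 1/30000 time base and square pixels, filterArgs.str() would read "video_size=1920x1080:pix_fmt=0:time_base=1/30000:pixel_aspect=1/1". A hedged sketch of the step this hunk truncates before, using the standard libavfilter call (the "in" label and error message are assumptions):

  AVFilterContext* sourceContext = nullptr;
  // Instantiate the buffer source in the graph, configured by filterArgs.
  int status = avfilter_graph_create_filter(
      &sourceContext,
      buffersrc,
      "in",
      filterArgs.str().c_str(),
      nullptr,
      filterState.filterGraph.get());
  TORCH_CHECK(status >= 0, "Failed to create buffer source filter");
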
@@ -378,10 +379,8 @@ void VideoDecoder::initializeFilterGraphForStream(
   inputs->pad_idx = 0;
   inputs->next = nullptr;

-  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
-      options, containerMetadata_.streams[streamIndex]);
   std::stringstream description;
-  description << "scale=" << frameDims.width << ":" << frameDims.height;
+  description << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
   description << ":sws_flags=bilinear";

   AVFilterInOut* outputsTmp = outputs.release();
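With the caller now supplying the output dimensions, description.str() renders as, e.g., "scale=1920:1080:sws_flags=bilinear". A hedged sketch of how such a description is typically parsed into the graph and the graph validated, using the standard libavfilter calls (error messages are assumptions):

  // Parse the scale filter chain between the in/out endpoints, then verify
  // the fully connected graph.
  int status = avfilter_graph_parse_ptr(
      filterState.filterGraph.get(),
      description.str().c_str(),
      &inputsTmp,
      &outputsTmp,
      nullptr);
  TORCH_CHECK(status >= 0, "Failed to parse filter description");

  status = avfilter_graph_config(filterState.filterGraph.get(), nullptr);
  TORCH_CHECK(status >= 0, "Failed to configure filter graph");
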
@@ -469,26 +468,16 @@ void VideoDecoder::addVideoStreamDecoder(
   streamInfo.options = options;
   int width = options.width.value_or(codecContext->width);

-  // Use swscale for color conversion by default because it is faster.
-  VideoDecoder::ColorConversionLibrary defaultColorConversionLibrary =
-      getDefaultColorConversionLibraryForWidth(width);
-  // If the user specifies the color conversion library (example in
-  // benchmarks), we use that instead.
-  auto colorConversionLibrary =
-      options.colorConversionLibrary.value_or(defaultColorConversionLibrary);
-
-  if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    initializeFilterGraphForStream(streamNumber, options);
-    streamInfo.colorConversionLibrary = ColorConversionLibrary::FILTERGRAPH;
-  } else if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-    streamInfo.colorConversionLibrary = ColorConversionLibrary::SWSCALE;
-  } else {
-    throw std::invalid_argument(
-        "Invalid colorConversionLibrary=" +
-        std::to_string(static_cast<int>(colorConversionLibrary)) +
-        ". colorConversionLibrary must be either "
-        "filtergraph or swscale.");
-  }
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested via the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose the color conversion library publicly; we only use this
+  // ability internally.
+  auto defaultLibrary = getDefaultColorConversionLibrary(width);
+  streamInfo.colorConversionLibrary =
+      options.colorConversionLibrary.value_or(defaultLibrary);
 }

 void VideoDecoder::updateMetadataWithCodecContext(
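For context, a hedged sketch of the check getDefaultColorConversionLibrary presumably performs: the exact alignment constant is an assumption (swscale is known to impose width alignment requirements), but the fall-back to filtergraph matches the comment above and the function head shown in the first hunk:

  VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
      int width) {
    // Assumption: swscale requires the output width to satisfy an alignment
    // constraint (a multiple of 32 here); otherwise fall back to filtergraph.
    if (width % 32 == 0) {
      return VideoDecoder::ColorConversionLibrary::SWSCALE;
    }
    return VideoDecoder::ColorConversionLibrary::FILTERGRAPH;
  }
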
@@ -938,6 +927,17 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   } else if (
       streamInfo.colorConversionLibrary ==
       ColorConversionLibrary::FILTERGRAPH) {
+    // Note that this is a lazy init; we initialize the filtergraph the first
+    // time we have a raw decoded frame. We do this lazily because up until
+    // this point, we really don't know what the resolution of the unmodified
+    // frames is. In theory, we should be able to get that from the stream
+    // metadata, but in practice, we have encountered videos where the stream
+    // metadata had a different resolution from the actual resolution of the
+    // raw decoded frames.
+    if (!streamInfo.filterState.filterGraph) {
+      initializeFilterGraph(
+          streamInfo, expectedOutputHeight, expectedOutputWidth);
+    }
     outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

     // Similarly to above, if this check fails it means the frame wasn't
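For reference, a minimal sketch of the push/pull pattern that convertFrameToTensorUsingFilterGraph presumably implements; the sourceContext/sinkContext member names and error messages are assumptions, and the real code may differ in error handling:

  // Push the raw decoded frame into the graph, pull the converted RGB24
  // frame from the sink, then wrap its data in an HWC uint8 tensor.
  TORCH_CHECK(
      av_buffersrc_write_frame(filterState.sourceContext, frame) >= 0,
      "Failed to push frame into filtergraph");
  AVFrame* filteredFrame = av_frame_alloc();
  TORCH_CHECK(
      av_buffersink_get_frame(filterState.sinkContext, filteredFrame) >= 0,
      "Failed to pull frame from filtergraph");
  std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
  std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
  // from_blob avoids a copy; the deleter frees the frame when the tensor dies.
  torch::Tensor hwcTensor = torch::from_blob(
      filteredFrame->data[0],
      shape,
      strides,
      [filteredFrame](void*) mutable { av_frame_free(&filteredFrame); },
      torch::kUInt8);
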
@@ -952,6 +952,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
         expectedOutputWidth,
         "x3, got ",
         shape);
+
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
       // outputTensor have the same shape.
@@ -965,7 +966,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
965966 " Invalid color conversion library: " +
966967 std::to_string (static_cast <int >(streamInfo.colorConversionLibrary )));
967968 }
968-
969969 } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
970970 // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
971971 // audio decoding.
@@ -1007,7 +1007,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
       });
   // Convert the frame to tensor.
   auto output = convertAVFrameToDecodedOutput(rawOutput);
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }

@@ -1045,7 +1045,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex) {
   auto output = getFrameAtIndexInternal(streamIndex, frameIndex);
-  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }

@@ -1118,7 +1118,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     }
     previousIndexInVideo = indexInVideo;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }

@@ -1193,7 +1193,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }

@@ -1246,7 +1246,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
-    output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+    output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }

@@ -1287,7 +1287,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);

   return output;
 }
@@ -1303,7 +1303,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {

 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
   auto output = getNextFrameOutputNoDemuxInternal();
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }
