Home | History | Annotate | Download | only in ffmpeg
      1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 
     16 #include <limits>
     17 
     18 #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
     19 #include "tensorflow/core/framework/common_shape_fns.h"
     20 #include "tensorflow/core/framework/op.h"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 
     23 namespace tensorflow {
     24 namespace ffmpeg {
     25 namespace {
     26 
     27 /*
     28  * Encoding implementation, shared across V1 and V2 ops. Creates a new
     29  * output in the context.
     30  */
     31 void Encode(OpKernelContext* context, const Tensor& contents,
     32             const string& file_format, const int32 bits_per_second,
     33             const int32 samples_per_second) {
     34   std::vector<float> samples;
     35   samples.reserve(contents.NumElements());
     36   for (int32 i = 0; i < contents.NumElements(); ++i) {
     37     samples.push_back(contents.flat<float>()(i));
     38   }
     39   const int32 channel_count = contents.dim_size(1);
     40   string encoded_audio;
     41   OP_REQUIRES_OK(
     42       context, CreateAudioFile(file_format, bits_per_second, samples_per_second,
     43                                channel_count, samples, &encoded_audio));
     44 
     45   // Copy the encoded audio file to the output tensor.
     46   Tensor* output = nullptr;
     47   OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output));
     48   output->scalar<string>()() = encoded_audio;
     49 }
     50 
     51 }  // namespace
     52 
     53 /*
     54  * Supersedes `EncodeAudioOp`. Allows all parameters to be inputs
     55  * instead of attributes, so that the sample rate (and, probably less
     56  * usefully, the output file format) can be given as tensors rather than
     57  * constants only.
     58  */
     59 class EncodeAudioOpV2 : public OpKernel {
     60  public:
     61   explicit EncodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {}
     62 
     63   void Compute(OpKernelContext* context) override {
     64     OP_REQUIRES(
     65         context, context->num_inputs() == 4,
     66         errors::InvalidArgument("EncodeAudio requires exactly four inputs."));
     67 
     68     const Tensor& contents = context->input(0);
     69     const Tensor& file_format_tensor = context->input(1);
     70     const Tensor& samples_per_second_tensor = context->input(2);
     71     const Tensor& bits_per_second_tensor = context->input(3);
     72 
     73     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
     74                 errors::InvalidArgument(
     75                     "sampled_audio must be a rank-2 tensor but got shape ",
     76                     contents.shape().DebugString()));
     77     OP_REQUIRES(
     78         context, contents.NumElements() <= std::numeric_limits<int32>::max(),
     79         errors::InvalidArgument(
     80             "sampled_audio cannot have more than 2^31 entries. Shape = ",
     81             contents.shape().DebugString()));
     82     OP_REQUIRES(context, TensorShapeUtils::IsScalar(file_format_tensor.shape()),
     83                 errors::InvalidArgument(
     84                     "file_format must be a rank-0 tensor but got shape ",
     85                     file_format_tensor.shape().DebugString()));
     86     OP_REQUIRES(context,
     87                 TensorShapeUtils::IsScalar(samples_per_second_tensor.shape()),
     88                 errors::InvalidArgument(
     89                     "samples_per_second must be a rank-0 tensor but got shape ",
     90                     samples_per_second_tensor.shape().DebugString()));
     91     OP_REQUIRES(context,
     92                 TensorShapeUtils::IsScalar(bits_per_second_tensor.shape()),
     93                 errors::InvalidArgument(
     94                     "bits_per_second must be a rank-0 tensor but got shape ",
     95                     bits_per_second_tensor.shape().DebugString()));
     96 
     97     const string file_format =
     98         str_util::Lowercase(file_format_tensor.scalar<string>()());
     99     const int32 samples_per_second =
    100         samples_per_second_tensor.scalar<int32>()();
    101     const int32 bits_per_second = bits_per_second_tensor.scalar<int32>()();
    102 
    103     OP_REQUIRES(context, file_format == "wav",
    104                 errors::InvalidArgument(
    105                     "file_format must be \"wav\", but got: ", file_format));
    106     OP_REQUIRES(context, samples_per_second > 0,
    107                 errors::InvalidArgument(
    108                     "samples_per_second must be positive, but got: ",
    109                     samples_per_second));
    110     OP_REQUIRES(
    111         context, bits_per_second > 0,
    112         errors::InvalidArgument("bits_per_second must be positive, but got: ",
    113                                 bits_per_second));
    114 
    115     Encode(context, contents, file_format, bits_per_second, samples_per_second);
    116   }
    117 };
    118 
    119 REGISTER_KERNEL_BUILDER(Name("EncodeAudioV2").Device(DEVICE_CPU),
    120                         EncodeAudioOpV2);
    121 
    122 REGISTER_OP("EncodeAudioV2")
    123     .Input("sampled_audio: float")
    124     .Input("file_format: string")
    125     .Input("samples_per_second: int32")
    126     .Input("bits_per_second: int32")
    127     .Output("contents: string")
    128     .SetShapeFn(shape_inference::ScalarShape)
    129     .Doc(R"doc(
    130 Processes a `Tensor` containing sampled audio with the number of channels
    131 and length of the audio specified by the dimensions of the `Tensor`. The
    132 audio is converted into a string that, when saved to disk, will be equivalent
    133 to the audio in the specified audio format.
    134 
    135 The input audio has one row of the tensor for each channel in the audio file.
    136 Each channel contains audio samples starting at the beginning of the audio and
    137 having `1/samples_per_second` time between them. The output file will contain
    138 all of the audio channels contained in the tensor.
    139 
    140 sampled_audio: A rank-2 float tensor containing all tracks of the audio.
    141     Dimension 0 is time and dimension 1 is the channel.
    142 file_format: A string or rank-0 string tensor describing the audio file
    143     format. This value must be `"wav"`.
    144 samples_per_second: The number of samples per second that the audio should
    145     have, as an int or rank-0 `int32` tensor. This value must be
    146     positive.
    147 bits_per_second: The approximate bitrate of the encoded audio file, as
    148     an int or rank-0 `int32` tensor. This is ignored by the "wav" file
    149     format.
    150 contents: The binary audio file contents, as a rank-0 string tensor.
    151 )doc");
    152 
    153 /*
    154  * Deprecated in favor of EncodeAudioOpV2.
    155  */
    156 class EncodeAudioOp : public OpKernel {
    157  public:
    158   explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
    159     OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
    160     file_format_ = str_util::Lowercase(file_format_);
    161     OP_REQUIRES(context, file_format_ == "wav",
    162                 errors::InvalidArgument("file_format arg must be \"wav\"."));
    163 
    164     OP_REQUIRES_OK(
    165         context, context->GetAttr("samples_per_second", &samples_per_second_));
    166     OP_REQUIRES(context, samples_per_second_ > 0,
    167                 errors::InvalidArgument("samples_per_second must be > 0."));
    168     OP_REQUIRES_OK(context,
    169                    context->GetAttr("bits_per_second", &bits_per_second_));
    170   }
    171 
    172   void Compute(OpKernelContext* context) override {
    173     // Get and verify the input data.
    174     OP_REQUIRES(
    175         context, context->num_inputs() == 1,
    176         errors::InvalidArgument("EncodeAudio requires exactly one input."));
    177     const Tensor& contents = context->input(0);
    178     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
    179                 errors::InvalidArgument(
    180                     "sampled_audio must be a rank 2 tensor but got shape ",
    181                     contents.shape().DebugString()));
    182     OP_REQUIRES(
    183         context, contents.NumElements() <= std::numeric_limits<int32>::max(),
    184         errors::InvalidArgument(
    185             "sampled_audio cannot have more than 2^31 entries. Shape = ",
    186             contents.shape().DebugString()));
    187 
    188     Encode(context, contents, file_format_, bits_per_second_,
    189            samples_per_second_);
    190   }
    191 
    192  private:
    193   string file_format_;
    194   int32 samples_per_second_;
    195   int32 bits_per_second_;
    196 };
    197 
    198 REGISTER_KERNEL_BUILDER(Name("EncodeAudio").Device(DEVICE_CPU), EncodeAudioOp);
    199 
    200 REGISTER_OP("EncodeAudio")
    201     .Input("sampled_audio: float")
    202     .Output("contents: string")
    203     .Attr("file_format: string")
    204     .Attr("samples_per_second: int")
    205     .Attr("bits_per_second: int = 192000")
    206     .SetShapeFn(shape_inference::ScalarShape)
    207     .Doc(R"doc(
    208 Processes a `Tensor` containing sampled audio with the number of channels
    209 and length of the audio specified by the dimensions of the `Tensor`. The
    210 audio is converted into a string that, when saved to disk, will be equivalent
    211 to the audio in the specified audio format.
    212 
    213 The input audio has one row of the tensor for each channel in the audio file.
    214 Each channel contains audio samples starting at the beginning of the audio and
    215 having `1/samples_per_second` time between them. The output file will contain
    216 all of the audio channels contained in the tensor.
    217 
    218 sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
    219     is time and dimension 1 is the channel.
    220 contents: The binary audio file contents.
    221 file_format: A string describing the audio file format. This must be "wav".
    222 samples_per_second: The number of samples per second that the audio should have.
    223 bits_per_second: The approximate bitrate of the encoded audio file. This is
    224     ignored by the "wav" file format.
    225 )doc");
    226 
    227 }  // namespace ffmpeg
    228 }  // namespace tensorflow
    229