Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
     17 // inputs or outputs in various ways.
     18 
     19 // See docs in ../ops/summary_ops.cc.
     20 
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/summary.pb.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/lib/wav/wav_io.h"
     25 #include "tensorflow/core/platform/types.h"
     26 
     27 namespace tensorflow {
     28 
     29 class SummaryAudioOp : public OpKernel {
     30  public:
     31   explicit SummaryAudioOp(OpKernelConstruction* context) : OpKernel(context) {
     32     OP_REQUIRES_OK(context, context->GetAttr("max_outputs", &max_outputs_));
     33     OP_REQUIRES(context, max_outputs_ > 0,
     34                 errors::InvalidArgument("max_outputs must be > 0"));
     35     has_sample_rate_attr_ =
     36         context->GetAttr("sample_rate", &sample_rate_attr_).ok();
     37   }
     38 
     39   void Compute(OpKernelContext* c) override {
     40     const Tensor& tag = c->input(0);
     41     const Tensor& tensor = c->input(1);
     42     OP_REQUIRES(c, IsLegacyScalar(tag.shape()),
     43                 errors::InvalidArgument("Tag must be a scalar"));
     44     OP_REQUIRES(c, tensor.dims() >= 2 && tensor.dims() <= 3,
     45                 errors::InvalidArgument("Tensor must be 3-D or 2-D, got: ",
     46                                         tensor.shape().DebugString()));
     47     const string& base_tag = tag.scalar<string>()();
     48 
     49     float sample_rate = sample_rate_attr_;
     50     if (!has_sample_rate_attr_) {
     51       const Tensor& sample_rate_tensor = c->input(2);
     52       sample_rate = sample_rate_tensor.scalar<float>()();
     53     }
     54     OP_REQUIRES(c, sample_rate > 0.0f,
     55                 errors::InvalidArgument("sample_rate must be > 0"));
     56 
     57     const int batch_size = tensor.dim_size(0);
     58     const int64 length_frames = tensor.dim_size(1);
     59     const int64 num_channels =
     60         tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1);
     61 
     62     Summary s;
     63     const int N = std::min<int>(max_outputs_, batch_size);
     64     for (int i = 0; i < N; ++i) {
     65       Summary::Value* v = s.add_value();
     66       if (max_outputs_ > 1) {
     67         v->set_tag(strings::StrCat(base_tag, "/audio/", i));
     68       } else {
     69         v->set_tag(strings::StrCat(base_tag, "/audio"));
     70       }
     71 
     72       Summary::Audio* sa = v->mutable_audio();
     73       sa->set_sample_rate(sample_rate);
     74       sa->set_num_channels(num_channels);
     75       sa->set_length_frames(length_frames);
     76       sa->set_content_type("audio/wav");
     77 
     78       auto values =
     79           tensor.shaped<float, 3>({batch_size, length_frames, num_channels});
     80       auto channels_by_frames = typename TTypes<float>::ConstMatrix(
     81           &values(i, 0, 0),
     82           Eigen::DSizes<Eigen::DenseIndex, 2>(length_frames, num_channels));
     83       size_t sample_rate_truncated = lrintf(sample_rate);
     84       if (sample_rate_truncated == 0) {
     85         sample_rate_truncated = 1;
     86       }
     87       OP_REQUIRES_OK(
     88           c, wav::EncodeAudioAsS16LEWav(
     89                  channels_by_frames.data(), sample_rate_truncated, num_channels,
     90                  length_frames, sa->mutable_encoded_audio_string()));
     91     }
     92 
     93     Tensor* summary_tensor = nullptr;
     94     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
     95     CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
     96   }
     97 
     98  private:
     99   int max_outputs_;
    100   bool has_sample_rate_attr_;
    101   float sample_rate_attr_;
    102 };
    103 
    104 REGISTER_KERNEL_BUILDER(Name("AudioSummaryV2").Device(DEVICE_CPU),
    105                         SummaryAudioOp);
    106 
    107 // Deprecated -- this op is registered with sample_rate as an attribute for
    108 // backwards compatibility.
    109 REGISTER_KERNEL_BUILDER(Name("AudioSummary").Device(DEVICE_CPU),
    110                         SummaryAudioOp);
    111 
    112 }  // namespace tensorflow
    113