Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <functional>
     17 #include <memory>
     18 
     19 #include "tensorflow/core/framework/allocator.h"
     20 #include "tensorflow/core/framework/fake_input.h"
     21 #include "tensorflow/core/framework/node_def_builder.h"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/summary.pb.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/kernels/ops_testutil.h"
     27 #include "tensorflow/core/kernels/ops_util.h"
     28 #include "tensorflow/core/lib/core/status_test_util.h"
     29 #include "tensorflow/core/lib/histogram/histogram.h"
     30 #include "tensorflow/core/lib/strings/strcat.h"
     31 #include "tensorflow/core/platform/env.h"
     32 #include "tensorflow/core/platform/logging.h"
     33 #include "tensorflow/core/platform/protobuf.h"
     34 #include "tensorflow/core/platform/test.h"
     35 
     36 namespace tensorflow {
     37 namespace {
     38 
     39 static void EXPECT_SummaryMatches(const Summary& actual,
     40                                   const string& expected_str) {
     41   Summary expected;
     42   CHECK(protobuf::TextFormat::ParseFromString(expected_str, &expected));
     43   EXPECT_EQ(expected.DebugString(), actual.DebugString());
     44 }
     45 
     46 // --------------------------------------------------------------------------
     47 // SummaryAudioOp
     48 // --------------------------------------------------------------------------
     49 class SummaryAudioOpTest : public OpsTestBase {
     50  protected:
     51   void MakeOp(const int max_outputs) {
     52     TF_ASSERT_OK(NodeDefBuilder("myop", "AudioSummaryV2")
     53                      .Input(FakeInput())
     54                      .Input(FakeInput())
     55                      .Input(FakeInput())
     56                      .Attr("max_outputs", max_outputs)
     57                      .Finalize(node_def()));
     58     TF_ASSERT_OK(InitOp());
     59   }
     60 
     61   void CheckAndRemoveEncodedAudio(Summary* summary) {
     62     for (int i = 0; i < summary->value_size(); ++i) {
     63       Summary::Value* value = summary->mutable_value(i);
     64       ASSERT_TRUE(value->has_audio()) << "No audio for value: " << value->tag();
     65       ASSERT_FALSE(value->audio().encoded_audio_string().empty())
     66           << "No encoded_audio_string for value: " << value->tag();
     67       if (VLOG_IS_ON(2)) {
     68         // When LOGGING, output the audio to disk for manual inspection.
     69         TF_CHECK_OK(WriteStringToFile(
     70             Env::Default(), strings::StrCat("/tmp/", value->tag(), ".wav"),
     71             value->audio().encoded_audio_string()));
     72       }
     73       value->mutable_audio()->clear_encoded_audio_string();
     74     }
     75   }
     76 };
     77 
     78 TEST_F(SummaryAudioOpTest, Basic3D) {
     79   const float kSampleRate = 44100.0f;
     80   const int kMaxOutputs = 3;
     81   MakeOp(kMaxOutputs);
     82 
     83   // Feed and run
     84   AddInputFromArray<string>(TensorShape({}), {"tag"});
     85   AddInputFromArray<float>(TensorShape({4, 2, 2}),
     86                            {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     87                             0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
     88   AddInputFromArray<float>(TensorShape({}), {kSampleRate});
     89 
     90   TF_ASSERT_OK(RunOpKernel());
     91 
     92   // Check the output size.
     93   Tensor* out_tensor = GetOutput(0);
     94   ASSERT_EQ(0, out_tensor->dims());
     95   Summary summary;
     96   ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
     97 
     98   CheckAndRemoveEncodedAudio(&summary);
     99   EXPECT_SummaryMatches(summary, R"(
    100     value { tag: 'tag/audio/0'
    101             audio { content_type: "audio/wav" sample_rate: 44100 num_channels: 2
    102                     length_frames: 2 } }
    103     value { tag: 'tag/audio/1'
    104             audio { content_type: "audio/wav" sample_rate: 44100 num_channels: 2
    105                     length_frames: 2 } }
    106     value { tag: 'tag/audio/2'
    107             audio { content_type: "audio/wav" sample_rate: 44100 num_channels: 2
    108                     length_frames: 2 } }
    109   )");
    110 }
    111 
    112 TEST_F(SummaryAudioOpTest, Basic2D) {
    113   const float kSampleRate = 44100.0f;
    114   const int kMaxOutputs = 3;
    115   MakeOp(kMaxOutputs);
    116 
    117   // Feed and run
    118   AddInputFromArray<string>(TensorShape({}), {"tag"});
    119   AddInputFromArray<float>(TensorShape({4, 4}),
    120                            {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    121                             0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
    122   AddInputFromArray<float>(TensorShape({}), {kSampleRate});
    123 
    124   TF_ASSERT_OK(RunOpKernel());
    125 
    126   // Check the output size.
    127   Tensor* out_tensor = GetOutput(0);
    128   ASSERT_EQ(0, out_tensor->dims());
    129   Summary summary;
    130   ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
    131 
    132   CheckAndRemoveEncodedAudio(&summary);
    133   EXPECT_SummaryMatches(summary, R"(
    134     value { tag: 'tag/audio/0'
    135             audio { content_type: "audio/wav" sample_rate: 44100 num_channels: 1
    136                     length_frames: 4 } }
    137     value { tag: 'tag/audio/1'
    138             audio { content_type: "audio/wav" sample_rate: 44100 num_channels: 1
    139                     length_frames: 4 } }
    140     value { tag: 'tag/audio/2'
    141             audio { content_type: "audio/wav" sample_rate: 44100 num_channels: 1
    142                     length_frames: 4 } }
    143   )");
    144 }
    145 
    146 }  // namespace
    147 }  // namespace tensorflow
    148