Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include <functional>
     19 #include <memory>
     20 #include <vector>
     21 
     22 #include "tensorflow/cc/client/client_session.h"
     23 #include "tensorflow/cc/ops/audio_ops.h"
     24 #include "tensorflow/cc/ops/const_op.h"
     25 #include "tensorflow/cc/ops/math_ops.h"
     26 #include "tensorflow/core/framework/shape_inference_testutil.h"
     27 #include "tensorflow/core/framework/tensor_testutil.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/framework/types.pb.h"
     30 #include "tensorflow/core/kernels/ops_util.h"
     31 #include "tensorflow/core/lib/core/status_test_util.h"
     32 #include "tensorflow/core/platform/test.h"
     33 
     34 namespace tensorflow {
     35 namespace ops {
     36 namespace {
     37 
     38 TEST(DecodeWavOpTest, DecodeWavTest) {
     39   Scope root = Scope::NewRootScope();
     40 
     41   std::vector<uint8> wav_data = {
     42       'R',  'I',  'F', 'F', 44,  0,   0,   0,  // size of whole file - 8
     43       'W',  'A',  'V', 'E', 'f', 'm', 't', ' ', 16, 0, 0,
     44       0,                   // size of fmt block - 8: 24 - 8
     45       1,    0,             // format: PCM (1)
     46       1,    0,             // channels: 1
     47       0x13, 0x37, 0,   0,  // sample rate: 14099
     48       0x26, 0x6e, 0,   0,  // byte rate: 2 * 14099
     49       2,    0,             // block align: NumChannels * BytesPerSample
     50       16,   0,             // bits per sample: 2 * 8
     51       'd',  'a',  't', 'a', 8,   0,   0,   0,  // size of payload: 8
     52       0,    0,                                 // first sample: 0
     53       0xff, 0x3f,                              // second sample: 16383
     54       0xff, 0x7f,  // third sample: 32767 (saturated)
     55       0x00, 0x80,  // fourth sample: -32768 (saturated)
     56   };
     57   Tensor content_tensor =
     58       test::AsScalar<string>(string(wav_data.begin(), wav_data.end()));
     59   Output content_op =
     60       Const(root.WithOpName("content_op"), Input::Initializer(content_tensor));
     61 
     62   DecodeWav decode_wav_op =
     63       DecodeWav(root.WithOpName("decode_wav_op"), content_op);
     64 
     65   TF_ASSERT_OK(root.status());
     66 
     67   ClientSession session(root);
     68   std::vector<Tensor> outputs;
     69 
     70   TF_EXPECT_OK(session.Run(ClientSession::FeedType(),
     71                            {decode_wav_op.audio, decode_wav_op.sample_rate},
     72                            &outputs));
     73 
     74   const Tensor& audio = outputs[0];
     75   const int sample_rate = outputs[1].flat<int32>()(0);
     76 
     77   EXPECT_EQ(2, audio.dims());
     78   EXPECT_EQ(1, audio.dim_size(1));
     79   EXPECT_EQ(4, audio.dim_size(0));
     80   EXPECT_NEAR(0.0f, audio.flat<float>()(0), 1e-4f);
     81   EXPECT_NEAR(0.5f, audio.flat<float>()(1), 1e-4f);
     82   EXPECT_NEAR(1.0f, audio.flat<float>()(2), 1e-4f);
     83   EXPECT_NEAR(-1.0f, audio.flat<float>()(3), 1e-4f);
     84   EXPECT_EQ(14099, sample_rate);
     85 }
     86 
     87 TEST(DecodeWavOpTest, DecodeWav_ShapeFn) {
     88   ShapeInferenceTestOp op("DecodeWav");
     89   INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]");
     90 
     91   // audio shape is unknown when desired_{samples,channels} are default.
     92   TF_ASSERT_OK(NodeDefBuilder("test", "DecodeWav")
     93                    .Input({"a", 0, DT_STRING})
     94                    .Finalize(&op.node_def));
     95   INFER_OK(op, "[]", "[?,?];[]");
     96 
     97   TF_ASSERT_OK(NodeDefBuilder("test", "DecodeWav")
     98                    .Input({"a", 0, DT_STRING})
     99                    .Attr("desired_samples", 42)
    100                    .Finalize(&op.node_def));
    101   INFER_OK(op, "[]", "[42,?];[]");
    102 
    103   // Negative sample value is rejected.
    104   TF_ASSERT_OK(NodeDefBuilder("test", "DecodeWav")
    105                    .Input({"a", 0, DT_STRING})
    106                    .Attr("desired_samples", -2)
    107                    .Finalize(&op.node_def));
    108   INFER_ERROR("samples must be non-negative, got -2", op, "[]");
    109 
    110   TF_ASSERT_OK(NodeDefBuilder("test", "DecodeWav")
    111                    .Input({"a", 0, DT_STRING})
    112                    .Attr("desired_channels", 2)
    113                    .Finalize(&op.node_def));
    114   INFER_OK(op, "[]", "[?,2];[]");
    115 
    116   // Negative channel value is rejected.
    117   TF_ASSERT_OK(NodeDefBuilder("test", "DecodeWav")
    118                    .Input({"a", 0, DT_STRING})
    119                    .Attr("desired_channels", -2)
    120                    .Finalize(&op.node_def));
    121   INFER_ERROR("channels must be non-negative, got -2", op, "[]");
    122 }
    123 
    124 }  // namespace
    125 }  // namespace ops
    126 }  // namespace tensorflow
    127