Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/allocator.h"
     17 #include "tensorflow/core/framework/fake_input.h"
     18 #include "tensorflow/core/framework/node_def_builder.h"
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/framework/types.h"
     23 #include "tensorflow/core/framework/types.pb.h"
     24 #include "tensorflow/core/kernels/ops_testutil.h"
     25 #include "tensorflow/core/kernels/ops_util.h"
     26 #include "tensorflow/core/platform/test.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 
     31 class RangeOpTest : public OpsTestBase {
     32  protected:
     33   void MakeOp(DataType input_type) {
     34     TF_ASSERT_OK(NodeDefBuilder("myop", "Range")
     35                      .Input(FakeInput(input_type))
     36                      .Input(FakeInput(input_type))
     37                      .Input(FakeInput(input_type))
     38                      .Finalize(node_def()));
     39     TF_ASSERT_OK(InitOp());
     40   }
     41 };
     42 
     43 class LinSpaceOpTest : public OpsTestBase {
     44  protected:
     45   void MakeOp(DataType input_type, DataType index_type) {
     46     TF_ASSERT_OK(NodeDefBuilder("myop", "LinSpace")
     47                      .Input(FakeInput(input_type))
     48                      .Input(FakeInput(input_type))
     49                      .Input(FakeInput(index_type))
     50                      .Finalize(node_def()));
     51     TF_ASSERT_OK(InitOp());
     52   }
     53 };
     54 
     55 TEST_F(RangeOpTest, Simple_D32) {
     56   MakeOp(DT_INT32);
     57 
     58   // Feed and run
     59   AddInputFromArray<int32>(TensorShape({}), {0});
     60   AddInputFromArray<int32>(TensorShape({}), {10});
     61   AddInputFromArray<int32>(TensorShape({}), {2});
     62   TF_ASSERT_OK(RunOpKernel());
     63 
     64   // Check the output
     65   Tensor expected(allocator(), DT_INT32, TensorShape({5}));
     66   test::FillValues<int32>(&expected, {0, 2, 4, 6, 8});
     67   test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
     68 }
     69 
     70 TEST_F(RangeOpTest, Simple_Float) {
     71   MakeOp(DT_FLOAT);
     72 
     73   // Feed and run
     74   AddInputFromArray<float>(TensorShape({}), {0.5});
     75   AddInputFromArray<float>(TensorShape({}), {2});
     76   AddInputFromArray<float>(TensorShape({}), {0.3});
     77   TF_ASSERT_OK(RunOpKernel());
     78 
     79   // Check the output
     80   Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
     81   test::FillValues<float>(&expected, {0.5, 0.8, 1.1, 1.4, 1.7});
     82   test::ExpectTensorEqual<float>(expected, *GetOutput(0));
     83 }
     84 
     85 TEST_F(RangeOpTest, Large_Double) {
     86   MakeOp(DT_DOUBLE);
     87 
     88   // Feed and run
     89   AddInputFromArray<double>(TensorShape({}), {0.0});
     90   AddInputFromArray<double>(TensorShape({}), {10000});
     91   AddInputFromArray<double>(TensorShape({}), {0.5});
     92   TF_ASSERT_OK(RunOpKernel());
     93 
     94   // Check the output
     95   Tensor expected(allocator(), DT_DOUBLE, TensorShape({20000}));
     96   std::vector<double> result;
     97   for (int32 i = 0; i < 20000; ++i) result.push_back(i * 0.5);
     98   test::FillValues<double>(&expected, gtl::ArraySlice<double>(result));
     99   test::ExpectTensorEqual<double>(expected, *GetOutput(0));
    100 }
    101 
    102 TEST_F(LinSpaceOpTest, Simple_D32) {
    103   MakeOp(DT_FLOAT, DT_INT32);
    104 
    105   // Feed and run
    106   AddInputFromArray<float>(TensorShape({}), {3.0});
    107   AddInputFromArray<float>(TensorShape({}), {7.0});
    108   AddInputFromArray<int32>(TensorShape({}), {3});
    109   TF_ASSERT_OK(RunOpKernel());
    110 
    111   // Check the output
    112   Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
    113   test::FillValues<float>(&expected, {3.0, 5.0, 7.0});
    114   test::ExpectTensorEqual<float>(expected, *GetOutput(0));
    115 }
    116 
    117 TEST_F(LinSpaceOpTest, Single_D64) {
    118   MakeOp(DT_FLOAT, DT_INT64);
    119 
    120   // Feed and run
    121   AddInputFromArray<float>(TensorShape({}), {9.0});
    122   AddInputFromArray<float>(TensorShape({}), {100.0});
    123   AddInputFromArray<int64>(TensorShape({}), {1});
    124   TF_ASSERT_OK(RunOpKernel());
    125 
    126   // Check the output
    127   Tensor expected(allocator(), DT_FLOAT, TensorShape({1}));
    128   test::FillValues<float>(&expected, {9.0});
    129   test::ExpectTensorEqual<float>(expected, *GetOutput(0));
    130 }
    131 
    132 TEST_F(LinSpaceOpTest, Simple_Double) {
    133   MakeOp(DT_DOUBLE, DT_INT32);
    134 
    135   // Feed and run
    136   AddInputFromArray<double>(TensorShape({}), {5.0});
    137   AddInputFromArray<double>(TensorShape({}), {6.0});
    138   AddInputFromArray<int32>(TensorShape({}), {6});
    139   TF_ASSERT_OK(RunOpKernel());
    140 
    141   // Check the output
    142   Tensor expected(allocator(), DT_DOUBLE, TensorShape({6}));
    143   test::FillValues<double>(&expected, {5.0, 5.2, 5.4, 5.6, 5.8, 6.0});
    144   test::ExpectTensorEqual<double>(expected, *GetOutput(0));
    145 }
    146 
    147 }  // namespace
    148 }  // namespace tensorflow
    149