1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <complex> 16 17 #include <gtest/gtest.h> 18 #include "tensorflow/lite/interpreter.h" 19 #include "tensorflow/lite/kernels/register.h" 20 #include "tensorflow/lite/kernels/test_util.h" 21 #include "tensorflow/lite/model.h" 22 23 namespace tflite { 24 namespace { 25 26 using ::testing::ElementsAreArray; 27 28 class CastOpModel : public SingleOpModel { 29 public: 30 CastOpModel(const TensorData& input, const TensorData& output) { 31 input_ = AddInput(input); 32 output_ = AddOutput(output); 33 SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions, 34 CreateCastOptions(builder_).Union()); 35 BuildInterpreter({GetShape(input_)}); 36 } 37 38 int input() const { return input_; } 39 int output() const { return output_; } 40 41 protected: 42 int input_; 43 int output_; 44 }; 45 46 TEST(CastOpModel, CastIntToFloat) { 47 CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); 48 m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600}); 49 m.Invoke(); 50 EXPECT_THAT(m.ExtractVector<float>(m.output()), 51 ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f})); 52 } 53 54 TEST(CastOpModel, CastFloatToInt) { 55 CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}}); 56 m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f}); 57 m.Invoke(); 58 EXPECT_THAT(m.ExtractVector<int>(m.output()), 59 ElementsAreArray({100, 20, 3, 0, 0, 1})); 60 } 61 62 TEST(CastOpModel, CastFloatToBool) { 63 CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}}); 64 m.PopulateTensor<float>(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f}); 65 m.Invoke(); 66 EXPECT_THAT(m.ExtractVector<bool>(m.output()), 67 ElementsAreArray({true, true, false, true, true, true})); 68 } 69 70 TEST(CastOpModel, CastBoolToFloat) { 71 CastOpModel m({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}}); 72 m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, true}); 73 m.Invoke(); 74 EXPECT_THAT(m.ExtractVector<float>(m.output()), 75 ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); 76 } 77 78 TEST(CastOpModel, CastComplex64ToFloat) { 79 CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); 80 m.PopulateTensor<std::complex<float>>( 81 m.input(), 82 {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), 83 std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), 84 std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); 85 m.Invoke(); 86 EXPECT_THAT(m.ExtractVector<float>(m.output()), 87 ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); 88 } 89 90 TEST(CastOpModel, CastFloatToComplex64) { 91 CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); 92 m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); 93 m.Invoke(); 94 EXPECT_THAT( 95 m.ExtractVector<std::complex<float>>(m.output()), 96 ElementsAreArray( 97 {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f), 98 std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f), 99 std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)})); 100 } 101 102 TEST(CastOpModel, CastComplex64ToInt) { 103 CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}}); 104 m.PopulateTensor<std::complex<float>>( 105 m.input(), 106 {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), 107 std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), 108 std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); 109 m.Invoke(); 110 EXPECT_THAT(m.ExtractVector<int>(m.output()), 111 ElementsAreArray({1, 2, 3, 4, 5, 6})); 112 } 113 114 TEST(CastOpModel, CastIntToComplex64) { 115 CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); 116 m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6}); 117 m.Invoke(); 118 EXPECT_THAT( 119 m.ExtractVector<std::complex<float>>(m.output()), 120 ElementsAreArray( 121 {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f), 122 std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f), 123 std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)})); 124 } 125 126 TEST(CastOpModel, CastComplex64ToComplex64) { 127 CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); 128 m.PopulateTensor<std::complex<float>>( 129 m.input(), 130 {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), 131 std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), 132 std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); 133 m.Invoke(); 134 EXPECT_THAT( 135 m.ExtractVector<std::complex<float>>(m.output()), 136 ElementsAreArray( 137 {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), 138 std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), 139 std::complex<float>(5.0f, 15.0f), 140 std::complex<float>(6.0f, 16.0f)})); 141 } 142 143 } // namespace 144 } // namespace tflite 145 int main(int argc, char** argv) { 146 ::tflite::LogToStderr(); 147 ::testing::InitGoogleTest(&argc, argv); 148 return RUN_ALL_TESTS(); 149 } 150