1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cstdint> 17 #include <limits> 18 #include <memory> 19 #include <vector> 20 21 #include "tensorflow/compiler/xla/client/computation_builder.h" 22 #include "tensorflow/compiler/xla/client/local_client.h" 23 #include "tensorflow/compiler/xla/shape_util.h" 24 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 25 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 26 #include "tensorflow/compiler/xla/tests/test_macros.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 29 #include "tensorflow/core/platform/test.h" 30 #include "tensorflow/core/platform/types.h" 31 32 namespace xla { 33 namespace { 34 35 class ConvertTest : public ClientLibraryTestBase { 36 public: 37 explicit ConvertTest(perftools::gputools::Platform* platform = nullptr) 38 : ClientLibraryTestBase(platform) { 39 mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); 40 mutable_debug_options()->add_xla_disable_hlo_passes("inline"); 41 } 42 }; 43 44 TEST_F(ConvertTest, ConvertR1S32ToR1S32) { 45 ComputationBuilder builder(client_, TestName()); 46 auto a = builder.ConstantR1<int32>({42, 64}); 47 builder.ConvertElementType(a, S32); 48 49 std::vector<int32> expected = {42, 64}; 50 ComputeAndCompareR1<int32>(&builder, expected, {}); 51 } 52 53 TEST_F(ConvertTest, ConvertR1F32ToR1F32) { 54 ComputationBuilder builder(client_, TestName()); 55 auto a = builder.ConstantR1<float>({42.0f, 64.0f}); 56 builder.ConvertElementType(a, F32); 57 58 std::vector<float> expected = {42.0f, 64.0f}; 59 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 60 } 61 62 TEST_F(ConvertTest, ConvertR1S32ToR1F32) { 63 ComputationBuilder builder(client_, TestName()); 64 auto a = builder.ConstantR1<int32>({42, 64}); 65 builder.ConvertElementType(a, F32); 66 67 std::vector<float> expected = {42.0f, 64.0f}; 68 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 69 } 70 71 TEST_F(ConvertTest, ConvertR1PREDToR1S32) { 72 ComputationBuilder builder(client_, TestName()); 73 auto a = builder.ConstantR1<bool>({true, false, true}); 74 builder.ConvertElementType(a, S32); 75 76 std::vector<int32> expected = {1, 0, 1}; 77 ComputeAndCompareR1<int32>(&builder, expected, {}); 78 } 79 80 TEST_F(ConvertTest, ConvertR1PREDToR1F32) { 81 ComputationBuilder builder(client_, TestName()); 82 auto a = builder.ConstantR1<bool>({true, false, true}); 83 builder.ConvertElementType(a, F32); 84 85 std::vector<float> expected = {1., 0., 1.}; 86 ComputeAndCompareR1<float>(&builder, expected, {}); 87 } 88 89 XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { 90 ComputationBuilder builder(client_, TestName()); 91 auto a = builder.ConstantR1<int32>({}); 92 builder.ConvertElementType(a, F32); 93 94 std::vector<float> expected = {}; 95 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 96 } 97 98 TEST_F(ConvertTest, ConvertR1F32ToR1S32) { 99 ComputationBuilder builder(client_, TestName()); 100 auto a = builder.ConstantR1<float>({42.6, 64.4}); 101 builder.ConvertElementType(a, S32); 102 103 std::vector<int32> expected = {42, 64}; 104 ComputeAndCompareR1<int32>(&builder, expected, {}); 105 } 106 107 XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { 108 ComputationBuilder builder(client_, TestName()); 109 auto a = builder.ConstantR1<int64>({32, 64}); 110 builder.ConvertElementType(a, F32); 111 112 std::vector<float> expected = {32.0, 64.0}; 113 ComputeAndCompareR1<float>(&builder, expected, {}); 114 } 115 116 XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { 117 ComputationBuilder builder(client_, TestName()); 118 auto a = builder.ConstantR1<uint8_t>({32, 64}); 119 builder.ConvertElementType(a, F32); 120 121 std::vector<float> expected = {32.0, 64.0}; 122 ComputeAndCompareR1<float>(&builder, expected, {}); 123 } 124 125 XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { 126 ComputationBuilder builder(client_, TestName()); 127 auto a = builder.ConstantR1<uint8_t>({32, 64}); 128 builder.ConvertElementType(a, S32); 129 130 std::vector<int32_t> expected = {32, 64}; 131 ComputeAndCompareR1<int32_t>(&builder, expected, {}); 132 } 133 134 XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { 135 ComputationBuilder builder(client_, TestName()); 136 auto a = builder.ConstantR1<uint8_t>({32, 64}); 137 builder.ConvertElementType(a, U32); 138 139 std::vector<uint32_t> expected = {32, 64}; 140 ComputeAndCompareR1<uint32_t>(&builder, expected, {}); 141 } 142 143 XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { 144 ComputationBuilder builder(client_, TestName()); 145 auto a = builder.ConstantR1<float>({32.0f, 64.0f}); 146 builder.ConvertElementType(a, F64); 147 148 std::vector<double> expected = {32.0, 64.0}; 149 ComputeAndCompareR1<double>(&builder, expected, {}); 150 } 151 152 XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { 153 ComputationBuilder builder(client_, TestName()); 154 auto a = builder.ConstantR1<double>({32.0, 64.0}); 155 builder.ConvertElementType(a, F32); 156 157 std::vector<float> expected = {32.0f, 64.0f}; 158 ComputeAndCompareR1<float>(&builder, expected, {}); 159 } 160 161 TEST_F(ConvertTest, ConvertS32Extremes) { 162 ComputationBuilder builder(client_, TestName()); 163 auto a = builder.ConstantR1<int32>( 164 {std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()}); 165 builder.ConvertElementType(a, F32); 166 167 std::vector<float> expected = { 168 static_cast<float>(std::numeric_limits<int32>::min()), 169 static_cast<float>(std::numeric_limits<int32>::max())}; 170 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 171 } 172 173 TEST_F(ConvertTest, ConvertMapToS32) { 174 ComputationBuilder builder(client_, TestName()); 175 auto b = builder.CreateSubBuilder("convert"); 176 auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); 177 b->ConvertElementType(param, S32); 178 auto a = builder.ConstantR1<float>({42.0f, 64.0f}); 179 builder.Map({a}, b->BuildAndNoteError(), {0}); 180 181 std::vector<int32> expected = {42, 64}; 182 ComputeAndCompareR1<int32>(&builder, expected, {}); 183 } 184 185 TEST_F(ConvertTest, ConvertMapToF32) { 186 ComputationBuilder builder(client_, TestName()); 187 auto b = builder.CreateSubBuilder("convert"); 188 auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); 189 b->ConvertElementType(param, F32); 190 auto a = builder.ConstantR1<int32>({42, 64}); 191 builder.Map({a}, b->BuildAndNoteError(), {0}); 192 193 std::vector<float> expected = {42.0f, 64.0f}; 194 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); 195 } 196 197 // Regression test for b/31758660. When ReshapeMover transforms 198 // input -> reshape -> convert 199 // to 200 // input -> convert -> reshape 201 // the new convert should have the same element type as the old convert. 202 TEST_F(ConvertTest, ConvertReshape) { 203 ComputationBuilder builder(client_, TestName()); 204 auto input = builder.ConstantR1<int32>({42}); 205 auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); 206 builder.ConvertElementType(reshape, F32); 207 208 ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001)); 209 } 210 211 } // namespace 212 } // namespace xla 213