Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cstdint>
     17 #include <limits>
     18 #include <memory>
     19 #include <vector>
     20 
     21 #include "tensorflow/compiler/xla/client/computation_builder.h"
     22 #include "tensorflow/compiler/xla/client/local_client.h"
     23 #include "tensorflow/compiler/xla/shape_util.h"
     24 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     25 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     26 #include "tensorflow/compiler/xla/tests/test_macros.h"
     27 #include "tensorflow/compiler/xla/xla_data.pb.h"
     28 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
     29 #include "tensorflow/core/platform/test.h"
     30 #include "tensorflow/core/platform/types.h"
     31 
     32 namespace xla {
     33 namespace {
     34 
     35 class ConvertTest : public ClientLibraryTestBase {
     36  public:
     37   explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
     38       : ClientLibraryTestBase(platform) {
     39     mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
     40     mutable_debug_options()->add_xla_disable_hlo_passes("inline");
     41   }
     42 };
     43 
     44 TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
     45   ComputationBuilder builder(client_, TestName());
     46   auto a = builder.ConstantR1<int32>({42, 64});
     47   builder.ConvertElementType(a, S32);
     48 
     49   std::vector<int32> expected = {42, 64};
     50   ComputeAndCompareR1<int32>(&builder, expected, {});
     51 }
     52 
     53 TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
     54   ComputationBuilder builder(client_, TestName());
     55   auto a = builder.ConstantR1<float>({42.0f, 64.0f});
     56   builder.ConvertElementType(a, F32);
     57 
     58   std::vector<float> expected = {42.0f, 64.0f};
     59   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
     60 }
     61 
     62 TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
     63   ComputationBuilder builder(client_, TestName());
     64   auto a = builder.ConstantR1<int32>({42, 64});
     65   builder.ConvertElementType(a, F32);
     66 
     67   std::vector<float> expected = {42.0f, 64.0f};
     68   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
     69 }
     70 
     71 TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
     72   ComputationBuilder builder(client_, TestName());
     73   auto a = builder.ConstantR1<bool>({true, false, true});
     74   builder.ConvertElementType(a, S32);
     75 
     76   std::vector<int32> expected = {1, 0, 1};
     77   ComputeAndCompareR1<int32>(&builder, expected, {});
     78 }
     79 
     80 TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
     81   ComputationBuilder builder(client_, TestName());
     82   auto a = builder.ConstantR1<bool>({true, false, true});
     83   builder.ConvertElementType(a, F32);
     84 
     85   std::vector<float> expected = {1., 0., 1.};
     86   ComputeAndCompareR1<float>(&builder, expected, {});
     87 }
     88 
     89 XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
     90   ComputationBuilder builder(client_, TestName());
     91   auto a = builder.ConstantR1<int32>({});
     92   builder.ConvertElementType(a, F32);
     93 
     94   std::vector<float> expected = {};
     95   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
     96 }
     97 
     98 TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
     99   ComputationBuilder builder(client_, TestName());
    100   auto a = builder.ConstantR1<float>({42.6, 64.4});
    101   builder.ConvertElementType(a, S32);
    102 
    103   std::vector<int32> expected = {42, 64};
    104   ComputeAndCompareR1<int32>(&builder, expected, {});
    105 }
    106 
    107 XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
    108   ComputationBuilder builder(client_, TestName());
    109   auto a = builder.ConstantR1<int64>({32, 64});
    110   builder.ConvertElementType(a, F32);
    111 
    112   std::vector<float> expected = {32.0, 64.0};
    113   ComputeAndCompareR1<float>(&builder, expected, {});
    114 }
    115 
    116 XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
    117   ComputationBuilder builder(client_, TestName());
    118   auto a = builder.ConstantR1<uint8_t>({32, 64});
    119   builder.ConvertElementType(a, F32);
    120 
    121   std::vector<float> expected = {32.0, 64.0};
    122   ComputeAndCompareR1<float>(&builder, expected, {});
    123 }
    124 
    125 XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
    126   ComputationBuilder builder(client_, TestName());
    127   auto a = builder.ConstantR1<uint8_t>({32, 64});
    128   builder.ConvertElementType(a, S32);
    129 
    130   std::vector<int32_t> expected = {32, 64};
    131   ComputeAndCompareR1<int32_t>(&builder, expected, {});
    132 }
    133 
    134 XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
    135   ComputationBuilder builder(client_, TestName());
    136   auto a = builder.ConstantR1<uint8_t>({32, 64});
    137   builder.ConvertElementType(a, U32);
    138 
    139   std::vector<uint32_t> expected = {32, 64};
    140   ComputeAndCompareR1<uint32_t>(&builder, expected, {});
    141 }
    142 
    143 XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
    144   ComputationBuilder builder(client_, TestName());
    145   auto a = builder.ConstantR1<float>({32.0f, 64.0f});
    146   builder.ConvertElementType(a, F64);
    147 
    148   std::vector<double> expected = {32.0, 64.0};
    149   ComputeAndCompareR1<double>(&builder, expected, {});
    150 }
    151 
    152 XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
    153   ComputationBuilder builder(client_, TestName());
    154   auto a = builder.ConstantR1<double>({32.0, 64.0});
    155   builder.ConvertElementType(a, F32);
    156 
    157   std::vector<float> expected = {32.0f, 64.0f};
    158   ComputeAndCompareR1<float>(&builder, expected, {});
    159 }
    160 
    161 TEST_F(ConvertTest, ConvertS32Extremes) {
    162   ComputationBuilder builder(client_, TestName());
    163   auto a = builder.ConstantR1<int32>(
    164       {std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()});
    165   builder.ConvertElementType(a, F32);
    166 
    167   std::vector<float> expected = {
    168       static_cast<float>(std::numeric_limits<int32>::min()),
    169       static_cast<float>(std::numeric_limits<int32>::max())};
    170   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    171 }
    172 
    173 TEST_F(ConvertTest, ConvertMapToS32) {
    174   ComputationBuilder builder(client_, TestName());
    175   auto b = builder.CreateSubBuilder("convert");
    176   auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
    177   b->ConvertElementType(param, S32);
    178   auto a = builder.ConstantR1<float>({42.0f, 64.0f});
    179   builder.Map({a}, b->BuildAndNoteError(), {0});
    180 
    181   std::vector<int32> expected = {42, 64};
    182   ComputeAndCompareR1<int32>(&builder, expected, {});
    183 }
    184 
    185 TEST_F(ConvertTest, ConvertMapToF32) {
    186   ComputationBuilder builder(client_, TestName());
    187   auto b = builder.CreateSubBuilder("convert");
    188   auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
    189   b->ConvertElementType(param, F32);
    190   auto a = builder.ConstantR1<int32>({42, 64});
    191   builder.Map({a}, b->BuildAndNoteError(), {0});
    192 
    193   std::vector<float> expected = {42.0f, 64.0f};
    194   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
    195 }
    196 
    197 // Regression test for b/31758660. When ReshapeMover transforms
    198 //   input -> reshape -> convert
    199 // to
    200 //   input -> convert -> reshape
    201 // the new convert should have the same element type as the old convert.
    202 TEST_F(ConvertTest, ConvertReshape) {
    203   ComputationBuilder builder(client_, TestName());
    204   auto input = builder.ConstantR1<int32>({42});
    205   auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
    206   builder.ConvertElementType(reshape, F32);
    207 
    208   ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
    209 }
    210 
    211 }  // namespace
    212 }  // namespace xla
    213