1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cmath> 17 #include <memory> 18 #include <vector> 19 20 #include "tensorflow/compiler/xla/array2d.h" 21 #include "tensorflow/compiler/xla/array4d.h" 22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 23 #include "tensorflow/compiler/xla/client/local_client.h" 24 #include "tensorflow/compiler/xla/client/xla_builder.h" 25 #include "tensorflow/compiler/xla/literal.h" 26 #include "tensorflow/compiler/xla/reference_util.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/shape_util.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/compiler/xla/test.h" 33 #include "tensorflow/compiler/xla/test_helpers.h" 34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 36 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 37 #include "tensorflow/compiler/xla/tests/test_macros.h" 38 #include "tensorflow/compiler/xla/tests/test_utils.h" 39 #include "tensorflow/compiler/xla/util.h" 40 #include "tensorflow/core/platform/logging.h" 41 #include "tensorflow/core/platform/test.h" 42 #include "tensorflow/core/platform/types.h" 43 44 namespace xla { 45 namespace { 46 47 class Bfloat16Test : public ClientLibraryTestBase { 48 protected: 49 const ErrorSpec error_spec_{0.001, 0.001}; 50 }; 51 52 XLA_TEST_F(Bfloat16Test, ScalarOperation) { 53 XlaBuilder builder(TestName()); 54 auto x = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(2.0f)); 55 auto y = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(1.0f)); 56 Add(x, y); 57 58 ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {}, 59 error_spec_); 60 } 61 62 XLA_TEST_F(Bfloat16Test, LogOperation) { 63 XlaBuilder builder(TestName()); 64 auto x = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(4.0f)); 65 Log(x); 66 67 ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {}, 68 ErrorSpec(0.01, 0.01)); 69 } 70 71 XLA_TEST_F(Bfloat16Test, NegateScalarF16) { 72 XlaBuilder builder(TestName()); 73 Neg(ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(2.1f))); 74 75 ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {}, 76 error_spec_); 77 } 78 79 // Disabled on interpreter since BatchNormExanper is not run by default on the 80 // intepreter backend. 81 XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { 82 const int kFeatureIndex = 2; 83 XlaBuilder builder(TestName()); 84 85 auto operand = ConstantR4FromArray4D<bfloat16>( 86 &builder, 87 {{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}}, 88 {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}}, 89 {{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}}, 90 {{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}}); 91 92 auto scale = ConstantR1<bfloat16>( 93 &builder, {static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)}); 94 95 auto offset = ConstantR1<bfloat16>( 96 &builder, {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)}); 97 98 BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); 99 100 auto expected = LiteralUtil::MakeTupleFromSlices( 101 {LiteralUtil::CreateR4<bfloat16>( 102 {{{{static_cast<bfloat16>(-1.6875f)}, 103 {static_cast<bfloat16>(-2.04f)}}, 104 {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}}, 105 {{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}}, 106 {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}), 107 LiteralUtil::CreateR1<bfloat16>( 108 {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}), 109 LiteralUtil::CreateR1<bfloat16>( 110 {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})}); 111 112 ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); 113 } 114 115 // Disabled on interpreter since BatchNormExanper is not run by default on the 116 // intepreter backend. 117 XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) { 118 const int kFeatureIndex = 2; 119 XlaBuilder builder(TestName()); 120 121 auto operand = ConstantR4FromArray4D<bfloat16>( 122 &builder, Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f))); 123 124 auto scale = ConstantR1<bfloat16>( 125 &builder, {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)}); 126 127 auto mean = ConstantR1<bfloat16>( 128 &builder, {static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)}); 129 130 auto var = ConstantR1<bfloat16>( 131 &builder, {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)}); 132 133 auto grad_output = ConstantR4FromArray4D<bfloat16>( 134 &builder, 135 {{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}}, 136 {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}}, 137 {{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}}, 138 {{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}}); 139 140 BatchNormGrad(operand, scale, mean, var, grad_output, 141 /*epsilon=*/0.0, kFeatureIndex); 142 143 auto expected = LiteralUtil::MakeTupleFromSlices( 144 {LiteralUtil::CreateR4<bfloat16>( 145 {{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}}, 146 {{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}}, 147 {{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}}, 148 {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}), 149 LiteralUtil::CreateR1<bfloat16>( 150 {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}), 151 LiteralUtil::CreateR1<bfloat16>( 152 {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})}); 153 154 ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); 155 } 156 157 } // namespace 158 } // namespace xla 159