1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines helper routines for XLA compilation. 17 18 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 19 #include "tensorflow/compiler/tf2xla/lib/util.h" 20 21 #include "tensorflow/compiler/tf2xla/literal_util.h" 22 #include "tensorflow/compiler/tf2xla/type_util.h" 23 #include "tensorflow/compiler/tf2xla/xla_context.h" 24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 25 #include "tensorflow/compiler/xla/client/computation_builder.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 30 namespace tensorflow { 31 32 namespace { 33 34 Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, 35 const xla::ComputationDataHandle& input, 36 const TensorShape& input_shape, DataType input_type, 37 DataType output_type, int axis, bool is_min, 38 xla::ComputationDataHandle* argminmax) { 39 xla::ComputationDataHandle init_value; 40 const xla::Computation* reducer; 41 if (is_min) { 42 init_value = XlaHelpers::MaxValue(builder, input_type); 43 reducer = ctx->GetOrCreateMin(input_type); 44 } else { 45 init_value = XlaHelpers::MinValue(builder, input_type); 46 reducer = ctx->GetOrCreateMax(input_type); 47 } 48 49 xla::PrimitiveType xla_output_type; 50 TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); 51 52 xla::ComputationDataHandle input_max = builder->Reduce( 53 input, init_value, *reducer, /*dimensions_to_reduce=*/{axis}); 54 std::vector<int64> broadcast_dims(input_shape.dims() - 1); 55 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); 56 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); 57 // Compute a mask that has 1s for elements equal to the maximum. 58 xla::ComputationDataHandle partial_mask = builder->ConvertElementType( 59 builder->Eq(input, input_max, broadcast_dims), xla_output_type); 60 61 // In order to make identity elements for a bitwise And, we: 62 // Left shift the 1 to the leftmost bit, yielding 0x10...0 63 // Arithmetic right shift the 1 back to the rightmost bit, yielding 64 // 0xFF...F 65 int32 bits_in_type = 66 xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; 67 xla::ComputationDataHandle shift_amount = 68 XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); 69 xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic( 70 builder->ShiftLeft(partial_mask, shift_amount), shift_amount); 71 72 // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its 73 // index. 74 xla::ComputationDataHandle iota; 75 76 const int64 axis_size = input_shape.dim_size(axis); 77 TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); 78 xla::ComputationDataHandle product = 79 builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); 80 81 // If there are multiple maximum elements, choose the one with the highest 82 // index. 83 xla::ComputationDataHandle output = 84 builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), 85 *ctx->GetOrCreateMax(output_type), 86 /*dimensions_to_reduce=*/{axis}); 87 *argminmax = output; 88 return Status::OK(); 89 } 90 91 } // namespace 92 93 xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, 94 DataType data_type) { 95 xla::PrimitiveType type; 96 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); 97 return b->ConstantLiteral(xla::Literal::MinValue(type)); 98 } 99 100 xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b, 101 DataType data_type) { 102 xla::PrimitiveType type; 103 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); 104 return b->ConstantLiteral(xla::Literal::MaxValue(type)); 105 } 106 107 xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b, 108 DataType data_type) { 109 xla::PrimitiveType type; 110 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); 111 return b->ConstantLiteral(xla::Literal::Zero(type)); 112 } 113 114 xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, 115 DataType data_type) { 116 xla::PrimitiveType type; 117 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); 118 return b->ConstantLiteral(xla::Literal::One(type)); 119 } 120 121 xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, 122 DataType data_type) { 123 switch (data_type) { 124 case DT_BFLOAT16: 125 return b->ConstantR0<bfloat16>(bfloat16::epsilon()); 126 case DT_FLOAT: 127 return b->ConstantR0<float>(std::numeric_limits<float>::epsilon()); 128 case DT_DOUBLE: 129 return b->ConstantR0<double>(std::numeric_limits<double>::epsilon()); 130 default: 131 LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " 132 << DataTypeString(data_type); 133 } 134 } 135 136 xla::ComputationDataHandle XlaHelpers::IntegerLiteral( 137 xla::ComputationBuilder* b, DataType data_type, int64 value) { 138 xla::PrimitiveType type; 139 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); 140 return ::tensorflow::IntegerLiteral(b, type, value); 141 } 142 143 xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, 144 DataType data_type, 145 double value) { 146 xla::PrimitiveType type; 147 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); 148 return ::tensorflow::FloatLiteral(b, type, value); 149 } 150 151 /* static */ Status XlaHelpers::ReshapeLiteral( 152 const xla::Literal& input, gtl::ArraySlice<int64> dimensions, 153 xla::Literal* output) { 154 if (xla::ShapeUtil::IsTuple(input.shape())) { 155 return errors::InvalidArgument("ReshapeLiteral does not support tuples."); 156 } 157 xla::Shape shape = 158 xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions); 159 int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape()); 160 int64 elements_after = xla::ShapeUtil::ElementsIn(shape); 161 if (elements_before != elements_after) { 162 return errors::InvalidArgument( 163 "Shapes before and after ReshapeLiteral have different numbers of " 164 "elements."); 165 } 166 167 *output = input.Clone(); 168 output->mutable_shape_do_not_use()->Swap(&shape); 169 return Status::OK(); 170 } 171 172 template <typename T> 173 static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { 174 Tensor linspace(DataTypeToEnum<T>::v(), shape); 175 auto linspace_flat = linspace.flat<T>(); 176 for (int64 i = 0; i < depth; ++i) { 177 linspace_flat(i) = i; 178 } 179 return linspace; 180 } 181 182 Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder, 183 XlaOpKernelContext* ctx, 184 const xla::ComputationDataHandle& input, 185 const TensorShape& input_shape, DataType input_type, 186 DataType output_type, int axis, 187 xla::ComputationDataHandle* argmax) { 188 return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, 189 axis, /*is_min=*/false, argmax); 190 } 191 192 Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder, 193 XlaOpKernelContext* ctx, 194 const xla::ComputationDataHandle& input, 195 const TensorShape& input_shape, DataType input_type, 196 DataType output_type, int axis, 197 xla::ComputationDataHandle* argmin) { 198 return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, 199 axis, /*is_min=*/true, argmin); 200 } 201 202 Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, 203 int64 size, xla::ComputationDataHandle* iota) { 204 TensorShape linspace_shape({size}); 205 Tensor linspace; 206 switch (dtype) { 207 case DT_UINT8: 208 linspace = MakeLinspaceTensor<uint8>(linspace_shape, size); 209 break; 210 case DT_INT32: 211 linspace = MakeLinspaceTensor<int32>(linspace_shape, size); 212 break; 213 case DT_INT64: 214 linspace = MakeLinspaceTensor<int64>(linspace_shape, size); 215 break; 216 default: 217 return errors::InvalidArgument("Invalid argument type ", 218 DataTypeString(dtype)); 219 } 220 xla::Literal linspace_literal; 221 TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); 222 *iota = builder->ConstantLiteral(linspace_literal); 223 return Status::OK(); 224 } 225 226 Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, 227 int axis, DataType index_type, 228 const TensorShape& indices_shape, 229 const xla::ComputationDataHandle& indices, 230 const xla::ComputationDataHandle& on_value, 231 const xla::ComputationDataHandle& off_value, 232 xla::ComputationDataHandle* one_hot) { 233 const int indices_dims = indices_shape.dims(); 234 const int output_dims = indices_dims + 1; 235 236 TensorShape output_shape = indices_shape; 237 output_shape.InsertDim(axis, depth); 238 239 // Build a Tensor populated with values 0, 1, 2, ... depth. 240 std::vector<int64> linspace_dims(output_dims, 1); 241 linspace_dims[axis] = depth; 242 TensorShape linspace_shape(linspace_dims); 243 Tensor linspace; 244 switch (index_type) { 245 case DT_UINT8: 246 linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth); 247 break; 248 case DT_INT32: 249 linspace = MakeLinspaceTensor<int32>(linspace_shape, depth); 250 break; 251 case DT_INT64: 252 linspace = MakeLinspaceTensor<int64>(linspace_shape, depth); 253 break; 254 default: 255 return errors::InvalidArgument("Invalid argument type ", 256 DataTypeString(index_type)); 257 } 258 xla::Literal linspace_literal; 259 TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); 260 261 // Broadcast the linspace constant across the indices along the new axis, 262 // and test equality at each position. 263 std::vector<int64> broadcast_dims(indices_shape.dims()); 264 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); 265 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); 266 xla::ComputationDataHandle one_hot_bool = builder->Eq( 267 indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); 268 269 // Selects the user-provided off_value and on_value values. 270 *one_hot = builder->Select( 271 one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()), 272 builder->Broadcast(off_value, output_shape.dim_sizes())); 273 return Status::OK(); 274 } 275 276 } // end namespace tensorflow 277