/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Emits an XLA computation that gathers slices of `input` along `axis` at the
// positions given by `indices`, lowering the gather as an explicit while loop
// (via XlaForEachIndex) with one iteration per index.
//
// Args:
//   input:          handle to the tensor being gathered from.
//   input_shape:    static shape of `input`.
//   indices:        handle to the index tensor.
//   indices_shape:  static shape of `indices`; taken by value because it is
//                   mutated below (the trailing index-vector dim is stripped
//                   in the N-d case).
//   axis:           first input dimension the indices apply to. Callers are
//                   expected to pass a value already validated to be in
//                   [0, input_shape.dims()) (see GatherOp::Compile); this
//                   function does not re-check it.
//   indices_are_nd: if true, the minor-most dimension of `indices` holds a
//                   full N-d index (GatherNd semantics); otherwise each
//                   scalar entry indexes dimension `axis`.
//   dtype:          element type of `input` / the output.
//   index_type:     element type of `indices` (int32 or int64 per callers).
//   builder:        computation builder to emit into.
//   gather_output:  out-param receiving the handle to the gathered result.
//
// Returns an error Status if any gathered-over input dimension has size zero;
// otherwise OK with *gather_output set.
Status XlaGather(const xla::ComputationDataHandle& input,
                 const TensorShape& input_shape,
                 const xla::ComputationDataHandle& indices,
                 TensorShape indices_shape, int64 axis, bool indices_are_nd,
                 DataType dtype, DataType index_type,
                 xla::ComputationBuilder* builder,
                 xla::ComputationDataHandle* gather_output) {
  // If the indices are N-dimensional, then the minor dimension of indices
  // should be of size N and correspond to the N indices.
  int64 num_index_dims = 1;
  if (indices_are_nd) {
    // NOTE(review): CHECK_GE aborts the process rather than returning an
    // error; the visible caller (GatherNdOp) pre-validates that indices is at
    // least a vector, so this is an internal invariant check.
    CHECK_GE(indices_shape.dims(), 1);
    num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
    // Strip the index-vector dimension: from here on, indices_shape describes
    // only the "batch" of lookups.
    indices_shape.RemoveLastDims(1);
  }

  // Although the indices Tensor is flattened into rank 1 during the lookup,
  // and each scalar entry is used as an index into the first dimension of the
  // input, the output is returned with shape:
  // input.shape[:axis] + indices.shape + input.shape[axis+1:]

  const int64 num_indices = indices_shape.num_elements();
  // input.shape[:axis] -- dimensions left of the gathered axis.
  TensorShape input_shape_pre_axis(input_shape);
  input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
  // input.shape[axis + num_index_dims:] -- dimensions right of the gathered
  // (possibly multi-dim, for GatherNd) axes.
  TensorShape input_shape_post_axis(input_shape);
  input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
  // Each slice of the input tensor has shape:
  // [<input_shape_pre_axis>, 1, ..., 1, <input shape_post_axis>]
  TensorShape slice_shape(input_shape);
  for (int64 i = 0; i < num_index_dims; ++i) {
    slice_shape.set_dim(axis + i, 1);
  }

  // Shape of the loop-carried output accumulator: the indices dimensions are
  // collapsed into a single dimension of size num_indices, and restored to
  // out_shape by the final Reshape after the loop.
  TensorShape loop_out_shape;
  loop_out_shape.AppendShape(input_shape_pre_axis);
  loop_out_shape.AddDim(num_indices);
  loop_out_shape.AppendShape(input_shape_post_axis);
  // Shape of the one-index slab written into the accumulator each iteration.
  TensorShape loop_out_slice_shape;
  loop_out_slice_shape.AppendShape(input_shape_pre_axis);
  loop_out_slice_shape.AddDim(1);
  loop_out_slice_shape.AppendShape(input_shape_post_axis);

  // Final result shape: input.shape[:axis] + indices.shape +
  // input.shape[axis+1:] (per the comment above).
  TensorShape out_shape;
  out_shape.AppendShape(input_shape_pre_axis);
  out_shape.AppendShape(indices_shape);
  out_shape.AppendShape(input_shape_post_axis);

  // Degenerate case: empty indices. Emit an all-zeros tensor of the correct
  // (empty) shape rather than building a zero-trip loop.
  if (num_indices == 0) {
    *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
                                        out_shape.dim_sizes());
    return Status::OK();
  }

  // With a non-empty index set, every gathered-over dimension must be
  // non-empty or the lookups below would be out of bounds.
  for (int64 i = 0; i < num_index_dims; ++i) {
    if (input_shape.dim_size(axis + i) == 0) {
      return errors::InvalidArgument("Gather dimension ", axis + i,
                                     " is of size zero in tensor with shape ",
                                     input_shape.DebugString());
    }
  }

  // Flatten the major dimensions of indices into a single dimension for ease of
  // iteration. If there is an axis dimension, we must leave it alone.
  std::vector<int64> flat_indices_shape = {num_indices};
  if (indices_are_nd) {
    flat_indices_shape.push_back(num_index_dims);
  }

  // Construct the initial values of the loop-carried Tensors:
  // (input, flattened indices, zero-initialized output accumulator).
  auto flat_indices = builder->Reshape(indices, flat_indices_shape);
  auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
                                     loop_out_shape.dim_sizes());
  auto init = {input, flat_indices, init_out};

  // Construct the while loop body's function. The implementation of gather is:
  // for i in range(num_indices):
  //   index = dynamic-slice(indices, i)
  //   xi = dynamic-slice(input, index)
  //   output = dynamic-update-slice(output, xi, i)
  auto body_fn = [&](xla::ComputationDataHandle i,
                     gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
                     xla::ComputationBuilder* bodyb) {
    // Unpack the loop-carried tuple in the order set up by `init`.
    auto input = loop_vars[0];
    auto indices = loop_vars[1];
    auto output = loop_vars[2];

    auto zero_index = XlaHelpers::Zero(bodyb, index_type);

    // Slice the i-th index from the indices array.
    xla::ComputationDataHandle index;
    auto indices_offset = bodyb->Reshape(i, {1});
    if (indices_are_nd) {
      // Slice out the entire nd index, if applicable: pad the start offset to
      // {i, 0}, slice a {1, num_index_dims} row, and collapse it to rank 1.
      indices_offset = bodyb->Pad(indices_offset, zero_index,
                                  xla::MakeEdgePaddingConfig({{0, 1}}));
      index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims});
      index = bodyb->Collapse(index, {0, 1});
    } else {
      index = bodyb->DynamicSlice(indices, indices_offset, {1});
    }

    // Slice the corresponding data from the input array: pad the gathered
    // index with zeros on the left (pre-axis dims) and right (post-axis dims)
    // to form a full-rank start-indices vector for DynamicSlice.
    auto start_indices = bodyb->Pad(
        index, zero_index,
        xla::MakeEdgePaddingConfig(
            {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}}));
    auto slice_i = bodyb->Reshape(
        bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()),
        loop_out_slice_shape.dim_sizes());

    // Construct the index into the output Tensor 0, ..., <index>, 0, ...
    // (only the collapsed-indices dimension, at position
    // input_shape_pre_axis.dims(), advances with the loop counter).
    std::vector<xla::ComputationDataHandle> out_index_vals(
        loop_out_shape.dims(), bodyb->Reshape(zero_index, {1}));
    out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1});
    auto out_index = bodyb->ConcatInDim(out_index_vals, 0);

    // Update the output Tensor with this iteration's slab.
    auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index);

    // Return the loop-carried tuple; input and indices pass through unchanged.
    return std::vector<xla::ComputationDataHandle>{input, indices,
                                                   updated_output};
  };

  // Construct the While loop, extract and reshape the output. The loop
  // counter's XLA type must match index_type so the DynamicSlice offsets above
  // are well-typed.
  xla::PrimitiveType ptype;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
  TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
                                                    init, "gather", builder));
  // outputs[2] is the accumulator; un-collapse the indices dimension(s).
  *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
  return Status::OK();
}

// XLA kernel for Gather and GatherV2. Validates the (optional, GatherV2-only)
// `axis` input and delegates the actual lowering to XlaGather with
// indices_are_nd=false.
class GatherOp : public XlaOpKernel {
 public:
  explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    xla::ComputationBuilder* builder = context->builder();
    auto input = context->Input(0);
    auto input_shape = context->InputShape(0);
    auto indices = context->Input(1);
    auto indices_shape = context->InputShape(1);
    // Gather has two inputs and implicitly gathers along axis 0; GatherV2
    // carries an explicit third `axis` input, which must be a compile-time
    // constant (see the CompileTimeConstInput registration below).
    int64 axis = 0;
    if (context->num_inputs() == 3) {
      const TensorShape axis_shape = context->InputShape(2);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
                  errors::InvalidArgument("axis must be scalar"));
      DataType axis_type = input_type(2);
      OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
                  errors::InvalidArgument("axis must be int32 or int64"));

      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
      const auto params_dims = input_shape.dims();
      // Normalize a negative axis (Python-style) into [0, params_dims).
      if (axis < 0) {
        axis += params_dims;
      }
      OP_REQUIRES(
          context, 0 <= axis && axis < params_dims,
          errors::InvalidArgument("Expected axis in the range [", -params_dims,
                                  ", ", params_dims, "), but got ", axis));
    }

    DataType index_type = input_type(1);
    OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
                errors::InvalidArgument("indices must be int32 or int64"));

    xla::ComputationDataHandle gather;
    OP_REQUIRES_OK(
        context, XlaGather(input, input_shape, indices, indices_shape, axis,
                           /*indices_are_nd=*/false, input_type(0), index_type,
                           builder, &gather));
    context->SetOutput(0, gather);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
};

REGISTER_XLA_OP(Name("Gather"), GatherOp);
// GatherV2's axis must be known at compile time so the output shape is static.
REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), GatherOp);

// XLA kernel for GatherNd. Validates the ranks of params/indices and delegates
// to XlaGather with indices_are_nd=true (the minor-most indices dimension is
// interpreted as a multi-dimensional index starting at axis 0).
class GatherNdOp : public XlaOpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    DataType params_type = context->input_type(0);
    DataType indices_type = context->input_type(1);

    TensorShape params_shape = context->InputShape(0);
    TensorShape indices_shape = context->InputShape(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
                errors::InvalidArgument("params must be at least a vector"));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
                errors::InvalidArgument("indices must be at least a vector"));
    // The minor-most dimension of indices holds one full N-d index; it cannot
    // address more dimensions than params has.
    const int64 num_index_dims =
        indices_shape.dim_size(indices_shape.dims() - 1);
    OP_REQUIRES(
        context, num_index_dims <= params_shape.dims(),
        errors::InvalidArgument(
            "index innermost dimension length must be <= params rank; saw: ",
            indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
            params_shape.dims()));

    xla::ComputationBuilder* builder = context->builder();
    auto params = context->Input(0);
    auto indices = context->Input(1);
    xla::ComputationDataHandle gather;
    OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
                                      indices_shape, /*axis=*/0,
                                      /*indices_are_nd=*/true, params_type,
                                      indices_type, builder, &gather));
    context->SetOutput(0, gather);
  }
};

REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);

}  // namespace tensorflow