/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA implementations of the StridedSlice, StridedSliceGrad and
// ResourceStridedSliceAssign ops. All three require begin/end/strides (and,
// for the grad op, the input shape) to be compile-time constants, since XLA
// slice/pad dimensions must be known at compile time.

#include "tensorflow/core/util/strided_slice_op.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {
namespace {

// Forward op: lowers StridedSlice to (optionally) Rev + Slice + Reshape.
class StridedSliceOp : public XlaOpKernel {
 public:
  explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    // Bit masks that modify how begin/end/strides are interpreted per
    // dimension; decoded by ValidateStridedSliceOp below.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    // Integer type (e.g. int32/int64) of the begin/end/strides inputs.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    TensorShape final_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    // begin/end/strides are registered as compile-time constants, so they can
    // be read here as literals.
    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    // Convert the literals to host Tensors of index_type_ so the shared CPU
    // validation helper can consume them.
    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    // ValidateStridedSliceOp (shared with the non-XLA kernel) resolves the
    // masks into dense, per-dimension begin/end/strides and computes the
    // output shape. The boolean outputs (is_identity etc.) are unused here.
    TensorShape dummy_processing_shape;
    bool dummy = false;
    OP_REQUIRES_OK(ctx,
                   ValidateStridedSliceOp(
                       &begin_tensor, &end_tensor, strides_tensor, input_shape,
                       begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
                       shrink_axis_mask_, &dummy_processing_shape, &final_shape,
                       &dummy, &dummy, &dummy, &begin, &end, &strides));

    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
    gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;

    for (int i = 0; i < begin.size(); ++i) {
      if (strides[i] > 0) {
        // Positive stride: use begin/end/stride directly.
        slice_begin.push_back(begin[i]);
        slice_end.push_back(end[i]);
        slice_strides.push_back(strides[i]);
      } else {
        // Negative stride: swap begin and end, add 1 because the interval
        // is semi-open, and mark the dimension to be reversed.
        // After Rev, index j in dim i maps to dim_size(i) - 1 - j, hence the
        // mirrored bounds below combined with a positive stride -strides[i].
        slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
        slice_end.push_back(input_shape.dim_size(i) - end[i] - 1);
        slice_strides.push_back(-strides[i]);
        dimensions_to_reverse.push_back(i);
      }
    }

    xla::ComputationDataHandle slice = ctx->Input(0);
    if (!dimensions_to_reverse.empty()) {
      slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
    }

    slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);

    // Reshape to apply any new_axis/shrink_axis adjustments captured in
    // final_shape.
    slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
    ctx->SetOutput(0, slice);
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
};

REGISTER_XLA_OP(Name("StridedSlice")
                    .CompileTimeConstInput("begin")
                    .CompileTimeConstInput("end")
                    .CompileTimeConstInput("strides"),
                StridedSliceOp);

// Gradient op: scatters dy back into the (zero-filled) input shape using Pad
// with interior padding to undo the stride, reversing dims that had negative
// strides.
class StridedSliceGradOp : public XlaOpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape processing_shape, final_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    // Input 0 is the (constant) shape of the original, un-sliced input.
    TensorShape input_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));

    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    // Re-run the same validation the forward op performed to recover the
    // dense begin/end/strides plus the processing/final shapes.
    bool dummy = false;
    OP_REQUIRES_OK(
        ctx, ValidateStridedSliceOp(
                 &begin_tensor, &end_tensor, strides_tensor, input_shape,
                 begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
                 shrink_axis_mask_, &processing_shape, &final_shape, &dummy,
                 &dummy, &dummy, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice
    const TensorShape dy_shape = ctx->InputShape(4);
    OP_REQUIRES(
        ctx, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    OP_REQUIRES(
        ctx, input_shape.dims() == processing_shape.dims(),
        errors::Internal(
            "input shape and processing shape must have same number of dims"));

    // Padding value: gradients outside the slice are zero.
    auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));

    xla::ComputationDataHandle grad = ctx->Input(4);

    // Undo any new/shrink axes.
    grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());

    // Pad the input gradients.
    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
    xla::PaddingConfig padding_config;

    for (int i = 0; i < processing_shape.dims(); ++i) {
      auto* dims = padding_config.add_dimensions();
      if (strides[i] > 0) {
        // Low padding places the first gradient element at begin[i];
        // interior padding of strides[i]-1 zeros recreates the stride gaps.
        dims->set_edge_padding_low(begin[i]);
        dims->set_interior_padding(strides[i] - 1);

        // Pad the upper dimension up to the expected input shape. (It's
        // not sufficient simply to use "end[i]" to compute the padding in
        // cases where the stride does not divide evenly into the interval
        // between begin[i] and end[i].)
        int64 size =
            dims->edge_padding_low() + processing_shape.dim_size(i) +
            (processing_shape.dim_size(i) - 1) * dims->interior_padding();
        dims->set_edge_padding_high(input_shape.dim_size(i) - size);
      } else {
        // Negative stride: the gradient must be reversed in this dimension,
        // and the roles of low/high padding swap (begin[i] counts from the
        // upper end of the dimension here).
        dimensions_to_reverse.push_back(i);
        dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1);
        dims->set_interior_padding(-strides[i] - 1);

        // Pad the lower dimension up to the expected input shape.
        int64 size =
            dims->edge_padding_high() + processing_shape.dim_size(i) +
            (processing_shape.dim_size(i) - 1) * dims->interior_padding();
        dims->set_edge_padding_low(input_shape.dim_size(i) - size);
      }
    }
    if (!dimensions_to_reverse.empty()) {
      grad = ctx->builder()->Rev(grad, dimensions_to_reverse);
    }
    grad = ctx->builder()->Pad(grad, zero, padding_config);
    ctx->SetOutput(0, grad);
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
};

REGISTER_XLA_OP(Name("StridedSliceGrad")
                    .CompileTimeConstInput("shape")
                    .CompileTimeConstInput("begin")
                    .CompileTimeConstInput("end")
                    .CompileTimeConstInput("strides"),
                StridedSliceGradOp);

// In-place assignment to a sliced view of a resource variable, lowered to
// DynamicUpdateSlice. Only unit-magnitude strides are supported.
class StridedSliceAssignOp : public XlaOpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
    // Element type of the variable being assigned to.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape final_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    // Input 0 is the resource variable (the l-value being assigned into).
    TensorShape lhs_shape;
    xla::ComputationDataHandle lhs;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));

    const TensorShape rhs_shape = ctx->InputShape(4);

    TensorShape dummy_processing_shape;
    bool dummy = false;
    OP_REQUIRES_OK(ctx,
                   ValidateStridedSliceOp(
                       &begin_tensor, &end_tensor, strides_tensor, lhs_shape,
                       begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
                       shrink_axis_mask_, &dummy_processing_shape, &final_shape,
                       &dummy, &dummy, &dummy, &begin, &end, &strides));

    if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
      // DynamicUpdateSlice does not allow 0-element updates. We should probably
      // check that rhs_shape can be broadcast to final_shape, but that is
      // probably better handled when implementing broadcasting more generally.
      return;
    }

    // TODO(aselle): This check is too strong, we only should need
    // input_shape to be broadcastable to final_shape
    OP_REQUIRES(ctx, final_shape == rhs_shape,
                errors::Unimplemented(
                    "sliced l-value shape ", final_shape.DebugString(),
                    " does not match r-value shape ", rhs_shape.DebugString(),
                    ". Automatic broadcasting not yet implemented."));

    xla::ComputationDataHandle rhs = ctx->Input(4);

    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
    gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
    for (int i = 0; i < begin.size(); ++i) {
      // TODO(phawkins): implement strides != 1
      OP_REQUIRES(
          ctx, strides[i] == 1 || strides[i] == -1,
          errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
      if (strides[i] > 0) {
        slice_begin.push_back(begin[i]);
        slice_dims.push_back(end[i] - begin[i]);
      } else {
        // Negative stride: swap begin and end, add 1 because the interval
        // is semi-open, and mark the dimension to be reversed.
        // With stride -1 the touched indices are begin[i] down to end[i]+1,
        // i.e. the interval [end[i]+1, begin[i]] of length begin[i]-end[i].
        slice_begin.push_back(end[i] + 1);
        slice_dims.push_back(begin[i] - end[i]);
        dimensions_to_reverse.push_back(i);
      }
    }

    // rhs arrives in sliced (possibly reversed) order; put it back into the
    // variable's orientation, then reshape to the dense update-window shape.
    if (!dimensions_to_reverse.empty()) {
      rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
    }
    rhs = ctx->builder()->Reshape(rhs, slice_dims);

    if (lhs_shape.dims() == 0) {
      // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
      // and remove this workaround.
      lhs = rhs;
    } else {
      lhs = ctx->builder()->DynamicUpdateSlice(
          lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
    }

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
  DataType dtype_;
};

REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
                    .CompileTimeConstInput("begin")
                    .CompileTimeConstInput("end")
                    .CompileTimeConstInput("strides"),
                StridedSliceAssignOp);

}  // namespace
}  // namespace tensorflow