1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <numeric> 17 #include <vector> 18 19 #include "tensorflow/compiler/tf2xla/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/type_util.h" 21 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 24 #include "tensorflow/compiler/xla/array4d.h" 25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 26 #include "tensorflow/compiler/xla/client/lib/constants.h" 27 #include "tensorflow/compiler/xla/client/xla_builder.h" 28 #include "tensorflow/compiler/xla/literal.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 #include "tensorflow/core/framework/kernel_def_builder.h" 32 #include "tensorflow/core/framework/op_kernel.h" 33 #include "tensorflow/core/framework/register_types.h" 34 #include "tensorflow/core/framework/tensor_shape.h" 35 #include "tensorflow/core/framework/types.pb.h" 36 #include "tensorflow/core/lib/core/errors.h" 37 #include "tensorflow/core/lib/gtl/inlined_vector.h" 38 #include "tensorflow/core/lib/math/math_util.h" 39 #include "tensorflow/core/platform/types.h" 40 41 namespace tensorflow { 42 namespace { 43 44 using xla::XlaOp; 45 46 // Calculates the bilinear weight tensor, given basis ratio (px, py) of the 47 // sampling position: 48 // W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] 49 // 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2]. 50 // 51 // The returned tensor has dimensions [batch, dim_0, ... dim_n, 4]. 52 XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, 53 const TensorShape warp_shape, 54 xla::PrimitiveType xla_type) { 55 auto first_term = xla::ConstantR2<float>( 56 ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}}); 57 first_term = xla::ConvertElementType(first_term, xla_type); 58 59 auto warp_dims = warp_shape.dim_sizes(); 60 std::vector<int64> broadcast_dims(warp_dims.begin(), warp_dims.end() - 1); 61 broadcast_dims.push_back(4); 62 broadcast_dims.push_back(2); 63 64 const int64 broadcast_dims_size = broadcast_dims.size(); 65 66 std::vector<int64> last_two_dims_indices = {(broadcast_dims_size - 2), 67 (broadcast_dims_size - 1)}; 68 69 auto broadcast_first_term = 70 xla::BroadcastInDim(first_term, broadcast_dims, last_two_dims_indices); 71 72 // Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n, 73 // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the 74 // [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last 75 // dimension. 76 std::vector<int64> ratio_broadcast_indices(broadcast_dims.size()); 77 std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0); 78 ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2); 79 80 auto broadcast_ratio = 81 xla::BroadcastInDim(ratio, broadcast_dims, ratio_broadcast_indices); 82 83 auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio; 84 85 // Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (px, py)], need to 86 // flip the signs of the second and the third term. 87 auto sign_change = xla::ConstantR2<float>( 88 ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}}); 89 sign_change = xla::ConvertElementType(sign_change, xla_type); 90 91 auto broadcast_sign_change = 92 xla::BroadcastInDim(sign_change, broadcast_dims, last_two_dims_indices); 93 94 auto flipped = first_term_subtract_weights * broadcast_sign_change; 95 96 // Build up the final bilinear weight tensor by multiply reduction, which 97 // gives: 98 // [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] 99 // for each 4 neighboring pixels where px and py are the weight of the target 100 // pixel we are sampling from. 101 return xla::Reduce( 102 flipped, xla::One(ctx->builder(), xla_type), 103 xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()), 104 {broadcast_dims_size - 1}); 105 } 106 107 // Concatenates the batch indices to the (x, y) coordinate indices. 108 // This is done by first creating an Iota tensor that represents the current 109 // batch it is in, then concatenate with the givin (coordinate) indices. 110 // 111 // The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where 112 // the last dimension of size 3 in turn is [batch_number, x, y]. 113 // The [batch_number, x, y] dimension is needed because the indices 114 // [x,y] alone cannot allow the xla::Gather operation to gather from the input 115 // data, which is of dimension [batch, height(y), width(x), channel] with 116 // 'batch' being the first dimension. 117 XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, 118 const TensorShape& warp_shape) { 119 // We need to create an iota tensor with the same batch dimension. 120 std::vector<int64> dimensions; 121 for (auto dim : warp_shape) { 122 dimensions.push_back(dim.size); 123 } 124 // Except the last dimension, which is of size 1. 125 dimensions.back() = 1; 126 127 auto batch_indices = 128 xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions), 129 /*iota_dimension=*/0); 130 131 return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); 132 } 133 134 // Gathers the 2x2 neighbors of the input starting_indices, and return a 135 // tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels]. 136 // 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last 137 // dimension of size 3 is (batch_no, x, y). 138 XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices, 139 int64 data_channels, int warp_dims) { 140 xla::GatherDimensionNumbers gather_dim_numbers; 141 const int64 neighbor_data_dimensions = warp_dims + 2; 142 // Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2, 143 // data_channels], the offset dimensions for Gather is the last 3 dimensions. 144 gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3); 145 gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2); 146 gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1); 147 // The last dimension of 'gather_indices' is the starting indices for gather. 148 gather_dim_numbers.set_index_vector_dim(warp_dims - 1); 149 gather_dim_numbers.add_collapsed_slice_dims(0); 150 gather_dim_numbers.add_start_index_map(0); 151 // Since input is of dimension [batch, height(y), width(x), channel], and warp 152 // is of dimension [batch, x, y], the ordering of x, y here needs to be 153 // swapped when gathering. 154 gather_dim_numbers.add_start_index_map(2); 155 gather_dim_numbers.add_start_index_map(1); 156 // Data dimensions are [batch, x, y, channel]. 157 // Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels]. 158 auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers, 159 /*slice_sizes=*/{1, 2, 2, data_channels}); 160 // Collapse the ...,2,2,... dimensions into ...,4,... 161 return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims}); 162 } 163 164 // Scatter 'updates' tensor to 'grad_data' based on 'indices'. Returns the 165 // resulting tensor of dimension: [batch, dim_0, ...dim_n, 2, 2, data_channels]. 166 // This function can also be seen as the inverse of 'Gather2by2Neighbors'. 167 XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, 168 XlaOp updates, int64 warp_dims, 169 xla::PrimitiveType xla_type) { 170 xla::ScatterDimensionNumbers scatter_dim_numbers; 171 const int64 neighbor_data_dimensions = warp_dims + 2; 172 // Since the Scatter output dimensions are [batch, dim_0, ... dim_n, 2, 2, 173 // data_channels], the update window dimensions is the last 3 dimensions. 174 scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3); 175 scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2); 176 scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1); 177 scatter_dim_numbers.set_index_vector_dim(warp_dims - 1); 178 179 scatter_dim_numbers.add_inserted_window_dims(0); 180 scatter_dim_numbers.add_scatter_dims_to_operand_dims(0); 181 // Since input is of dimension [batch, height(y), width(x), channel], and warp 182 // is of dimension [batch, x, y], the ordering of x, y here needs to be 183 // swapped when scattering. 184 scatter_dim_numbers.add_scatter_dims_to_operand_dims(2); 185 scatter_dim_numbers.add_scatter_dims_to_operand_dims(1); 186 187 return xla::Scatter(grad_data, indices, updates, 188 xla::CreateScalarAddComputation(xla_type, ctx->builder()), 189 scatter_dim_numbers); 190 } 191 192 // Bounds samples to 0 if the warp image indices are out of the (-1, image_size) 193 // bound. 194 // The resulting dimension is given by 'result_dims'. 195 XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp, 196 xla::PrimitiveType warp_type, TensorShape warp_shape, 197 std::vector<int64> result_dims, 198 std::vector<int64> broadcasted_dims, int64 last_warp_dim, 199 xla::Shape data_shape, XlaOp sample) { 200 auto is_gt_minus_one = 201 xla::Gt(warp, 202 xla::ConvertElementType( 203 xla::ConstantR1<float>(ctx->builder(), {-1, -1}), warp_type), 204 /*broadcast_dimensions=*/{warp_shape.dims() - 1}); 205 auto is_lt_image_size = xla::Lt( 206 warp, 207 xla::ConvertElementType( 208 xla::ConstantR1<float>( 209 ctx->builder(), 210 {/*width=*/static_cast<float>(data_shape.dimensions(2)), 211 /*height=*/static_cast<float>(data_shape.dimensions(1))}), 212 warp_type), 213 /*broadcast_dimensions=*/{warp_shape.dims() - 1}); 214 215 auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size); 216 // Reduce along last dimension. The resulting dimension is: 217 // [batch, dim_0, ...dim_n]. 218 auto is_in_bound = xla::Reduce( 219 is_in_bound_padded_x_y, xla::ConstantR0<bool>(ctx->builder(), true), 220 xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()), 221 {last_warp_dim}); 222 223 // Broadcast 'is_in_bound' to the same dimension as 'result_dims'. 224 auto broadcasted_is_in_bound = 225 xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); 226 227 // Set out of bound samples to zero. 228 auto zeros = 229 xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims); 230 return xla::Select(broadcasted_is_in_bound, sample, zeros); 231 } 232 233 // Build computation the backprop into input 'data'. 234 // Where input: 235 // grad_output is of dimension [batch, dim_0, ...dim_n, channel] 236 // ratio is of dimension [batch, dim_0, ...dim_n, 2] 237 // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] 238 // data_shape is of dimension [batch, x(width), y(height), channel] 239 // 240 // Output: 241 // scatter-add to each 2x2 grad_data neighbor: 242 // grad_data[fx, fy, chan] += output_grad * dx * dy 243 // grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy 244 // grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) 245 // grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) 246 // where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then the their 247 // contribution is 0 to 'grad_data'. 248 XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, 249 XlaOp gather_indices, XlaOp warp, 250 xla::PrimitiveType warp_type, TensorShape warp_shape, 251 int64 last_warp_dim, int64 data_channels, 252 xla::Shape data_shape) { 253 // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. 254 auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); 255 256 auto warp_dims = warp_shape.dim_sizes(); 257 std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(), 258 warp_dims.end() - 1); 259 260 std::vector<int64> reshaped_weights_dims = warp_dims_without_last_dims; 261 // Reshape the last dimension of size 4 to two dimensions [2, 2]. 262 reshaped_weights_dims.push_back(2); 263 reshaped_weights_dims.push_back(2); 264 std::vector<int64> reshape_dims(warp_shape.dims()); 265 std::iota(reshape_dims.begin(), reshape_dims.end(), 0); 266 // The dimension is [batch, dim_0,..., dim_n, 2, 2]. 267 auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims, 268 /*new_sizes=*/reshaped_weights_dims); 269 270 std::vector<int64> weights_with_channels_dims = reshaped_weights_dims; 271 weights_with_channels_dims.push_back(data_channels); 272 std::vector<int64> reshaped_weights_indices(reshaped_weights_dims.size()); 273 std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 274 0); 275 276 // Set out of bound weights to 0. 277 // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2]. 278 std::vector<int64> reshaped_result_dims(warp_dims.begin(), 279 warp_dims.end() - 1); 280 reshaped_result_dims.push_back(2); 281 reshaped_result_dims.push_back(2); 282 std::vector<int64> broadcasted_dims(warp_dims.size() - 1); 283 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); 284 reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape, 285 reshaped_result_dims, broadcasted_dims, 286 last_warp_dim, data_shape, reshaped_weights); 287 288 // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. 289 auto broadcast_reshaped_weights = xla::BroadcastInDim( 290 reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); 291 292 std::vector<int64> grad_output_indices(warp_dims_without_last_dims.size()); 293 std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0); 294 grad_output_indices.push_back(weights_with_channels_dims.size() - 1); 295 XlaOp broadcast_grad_output = xla::BroadcastInDim( 296 grad_output, weights_with_channels_dims, grad_output_indices); 297 298 auto grad_output_multiply_weights = 299 broadcast_grad_output * broadcast_reshaped_weights; 300 301 auto grad_data = xla::ConstantLiteral( 302 ctx->builder(), xla::Literal::CreateFromShape(data_shape)); 303 304 // Pad grad data then slice it back. 305 // 306 // After left and right column 0-padding, the new dimension of padded data 307 // will be [batch, x+2, y+2, channel]. 308 auto padded_grad_data = 309 xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), 310 xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); 311 312 auto shifting_value = xla::ConstantR1<int32>( 313 ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); 314 auto shifted_gather_indices = 315 xla::Add(gather_indices, shifting_value, {last_warp_dim}); 316 317 auto updated_grad_data = ScatterToGradData( 318 ctx, padded_grad_data, shifted_gather_indices, 319 grad_output_multiply_weights, warp_shape.dims(), warp_type); 320 321 const int64 batch_size = data_shape.dimensions(0); 322 const int64 width = data_shape.dimensions(1); 323 const int64 height = data_shape.dimensions(2); 324 // Slice out the result accounting for the padding. 325 return xla::Slice( 326 updated_grad_data, /*start_indices=*/{0, 1, 1, 0}, 327 /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels}, 328 /*strides=*/{1, 1, 1, 1}); 329 } 330 331 // Build computation for the backprop into input 'warp'. 332 // Where input: 333 // warp is of dimension [batch, dim_0, ...dim_n, 2] 334 // grad_output is of dimension [batch, dim_0, ...dim_n, channel] 335 // ratio is of dimension [batch, dim_0, ...dim_n, 2] 336 // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last 337 // dimension of size 3 is for {batch, x(width), y(height)}. 338 // data is of dimension [batch, x, y, channel] 339 // 340 // Output (simplified by ignoring the batch dimensions): 341 // Since the forward path has: 342 // output = dot(weights * neighbors) 343 // The backprop into warp will therefore be: 344 // grad_warp = output_grad * d_output / d_warp 345 // = output_grad * (d_weights / d_warp * neighbors + d_neighbors / 346 // d_warp * weight) 347 // Where: 348 // d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py] 349 // d_weights / d_warp_y = [-(1 - px), -px, (1-px), px] 350 // and 351 // d_neighbors / d_warp_x = 0 352 // 353 // Therefore: 354 // grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) 355 // grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) 356 // 357 // where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the 358 // bottom right corner in a 2x2 neighborhood. 359 XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, 360 XlaOp gather_indices, XlaOp data, 361 TensorShape warp_shape, int64 data_channels, 362 xla::PrimitiveType data_type, xla::Shape data_shape) { 363 auto warp_dims = warp_shape.dim_sizes(); 364 std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(), 365 warp_dims.end() - 1); 366 367 // With dimension [batch, dim_0, ...dim_n, 4] 368 std::vector<int64> neighbor_broadcast_dims = warp_dims_without_last_dims; 369 neighbor_broadcast_dims.push_back(4); 370 371 // With dimension [batch, dim_0, ...dim_n, 4] 372 auto neighbor_broadcast_shape = 373 xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); 374 375 const int64 last_warp_dim = warp_shape.dims() - 1; 376 377 // Pad data with 0, before gathering such that 0 will be returned for samples 378 // in the range of (-1, 0) or (image_dimension-1, image_dimension). 379 // After left and right column 0-padding, the new dimension of padded data 380 // will be [batch, x+2, y+2, channel]. 381 auto padded_data = 382 xla::Pad(data, xla::Zero(ctx->builder(), data_type), 383 xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); 384 385 auto shifting_value = xla::ConstantR1<int32>( 386 ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); 387 auto shifted_gather_indices = 388 xla::Add(gather_indices, shifting_value, {last_warp_dim}); 389 390 // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] 391 auto neighbors_data = 392 Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices, 393 data_channels, warp_shape.dims()); 394 395 // Since we will be creating the dot product of: 396 // lhs: [batch, dim_0, ...dim_n, 4] 397 // and 398 // rhs: [batch, dim_0, ...dim_n, 4, data_channels] 399 // we choose the last dimension of lhs and the second last dimension of rhs, 400 // with size 4, as the contracting dimension. 401 xla::DotDimensionNumbers dot_dims; 402 for (int i = 0; i < warp_shape.dims() - 1; ++i) { 403 dot_dims.add_lhs_batch_dimensions(i); 404 dot_dims.add_rhs_batch_dimensions(i); 405 } 406 dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); 407 dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); 408 409 // img_cxcy - img_fxcy 410 auto bottom_right_minus_bottom_left = xla::DotGeneral( 411 xla::BroadcastInDim( 412 xla::ConvertElementType( 413 xla::ConstantR1<float>(ctx->builder(), {0, 0, -1, 1}), data_type), 414 neighbor_broadcast_dims, {last_warp_dim}), 415 neighbors_data, dot_dims, /*precision_config=*/nullptr); 416 417 // img_cxfy - img_fxfy 418 auto top_right_minus_top_left = xla::DotGeneral( 419 xla::BroadcastInDim( 420 xla::ConvertElementType( 421 xla::ConstantR1<float>(ctx->builder(), {-1, 1, 0, 0}), data_type), 422 neighbor_broadcast_dims, {last_warp_dim}), 423 neighbors_data, dot_dims, /*precision_config=*/nullptr); 424 425 // img_cxcy - img_cxfy 426 auto bottom_right_minus_top_right = xla::DotGeneral( 427 xla::BroadcastInDim( 428 xla::ConvertElementType( 429 xla::ConstantR1<float>(ctx->builder(), {0, -1, 0, 1}), data_type), 430 neighbor_broadcast_dims, {last_warp_dim}), 431 neighbors_data, dot_dims, /*precision_config=*/nullptr); 432 433 // img_fxcy - img_fxfy 434 auto bottom_left_minus_top_left = xla::DotGeneral( 435 xla::BroadcastInDim( 436 xla::ConvertElementType( 437 xla::ConstantR1<float>(ctx->builder(), {-1, 0, 1, 0}), data_type), 438 neighbor_broadcast_dims, {last_warp_dim}), 439 neighbors_data, dot_dims, /*precision_config=*/nullptr); 440 441 // Slice out x and y. 442 auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1, 443 /*stride=*/1, /*dimno=*/last_warp_dim); 444 auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2, 445 /*stride=*/1, /*dimno=*/last_warp_dim); 446 447 // Build 1 - y and 1 - x. 448 auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y; 449 auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x; 450 451 auto x_before_reduce = 452 grad_output * weight_y * bottom_right_minus_bottom_left + 453 one_minus_y * top_right_minus_top_left; 454 455 std::vector<int64> reshaped_sizes = warp_dims_without_last_dims; 456 reshaped_sizes.push_back(1); 457 458 std::vector<int64> reshaped_dims(warp_dims_without_last_dims.size()); 459 std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0); 460 461 // Reduce-add along the channel dimension. 462 auto x_result = 463 xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type), 464 xla::CreateScalarAddComputation(data_type, ctx->builder()), 465 {last_warp_dim}); 466 // Reshape before concatenating with y values. 467 XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes); 468 469 auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right + 470 one_minus_x * bottom_left_minus_top_left; 471 // Reduce-add along the channel dimension. 472 auto y_result = 473 xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type), 474 475 xla::CreateScalarAddComputation(data_type, ctx->builder()), 476 {last_warp_dim}); 477 XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes); 478 479 return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y}, 480 last_warp_dim); 481 } 482 483 class ResamplerOp : public XlaOpKernel { 484 public: 485 explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 486 487 void Compile(XlaOpKernelContext* ctx) override { 488 TensorShape data_shape = ctx->InputShape("data"); 489 OP_REQUIRES(ctx, data_shape.dims() == 4, 490 errors::InvalidArgument("data must be 4-dimensional", 491 data_shape.DebugString())); 492 const int64 data_channels = data_shape.dim_size(3); 493 xla::PrimitiveType data_type = ctx->input_xla_type(0); 494 495 TensorShape warp_shape = ctx->InputShape("warp"); 496 OP_REQUIRES(ctx, warp_shape.dims() >= 2, 497 errors::InvalidArgument("warp must be at least 2-dimensional", 498 warp_shape.DebugString())); 499 for (int size : warp_shape.dim_sizes()) { 500 OP_REQUIRES(ctx, size > 0, 501 errors::InvalidArgument("warp sizes must be positive, got [", 502 size, "]")); 503 } 504 const int64 last_warp_dim = warp_shape.dims() - 1; 505 // Last dimension of warp shape must be of size 2. 506 OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, 507 errors::InvalidArgument( 508 "the last dimension of warp must be exactly size 2.")); 509 xla::PrimitiveType warp_type = ctx->input_xla_type(1); 510 511 XlaOp data = ctx->Input("data"); 512 XlaOp warp = ctx->Input("warp"); 513 514 // Find the coordinates of the top left corner for the 2x2 region to be 515 // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the 516 // last dimension of size 2 in turn is [x, y]. 517 XlaOp top_left = xla::ConvertElementType(warp, xla::S32); 518 519 auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); 520 521 // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] 522 auto neighbors_data = Gather2by2Neighbors( 523 ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); 524 525 // Dimensions are [batch, dim_0, ... dim_n, 2]. 526 XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type); 527 528 // Obtain the bilinear blending weights, the dimension is [batch, dim_0, 529 // ...dim_n, 4]. 530 auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type); 531 532 // Since we will be creating the dot product of: 533 // lhs: [batch, dim_0, ...dim_n, 4] 534 // and 535 // rhs: [batch, dim_0, ...dim_n, 4, data_channels] 536 // we choose the last dimension of lhs and the second last dimension of rhs, 537 // with size 4, as the contracting dimension. 538 xla::DotDimensionNumbers dot_dims; 539 for (int i = 0; i < warp_shape.dims() - 1; ++i) { 540 dot_dims.add_lhs_batch_dimensions(i); 541 dot_dims.add_rhs_batch_dimensions(i); 542 } 543 dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); 544 dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); 545 546 // The dimension is [batch, dim_0, ...dim_n, data_channels]. 547 auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims, 548 /*precision_config=*/nullptr); 549 550 // Handle out of boundary cases by constructing a predicate mask array based 551 // on the in-bound condition, and output 0 for the blended pixel value if 552 // out-bound. The dimension is the same as top_left: [batch, dim_0, 553 // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate. 554 555 auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp)); 556 557 auto is_lt_image_size = xla::Lt( 558 warp, 559 xla::ConvertElementType( 560 xla::ConstantR1<float>( 561 ctx->builder(), 562 {/*width=*/static_cast<float>(data_shape.dim_size(2) - 1), 563 /*height=*/static_cast<float>(data_shape.dim_size(1) - 1)}), 564 warp_type), 565 /*broadcast_dimensions=*/{warp_shape.dims() - 1}); 566 567 auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size); 568 // Reduce along last dimension. The resulting dimension is: 569 // [batch, dim_0, ...dim_n]. 570 auto is_in_bound = xla::Reduce( 571 is_in_bound_x_y, xla::ConstantR0<bool>(ctx->builder(), true), 572 xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, 573 ctx->builder()), 574 {last_warp_dim}); 575 576 // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which 577 // is the dimension of the result: 578 // [batch, dim_0, ...dim_n, data_channels]. 579 auto warp_dims = warp_shape.dim_sizes(); 580 std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1); 581 result_dims.push_back(data_channels); 582 583 std::vector<int64> broadcasted_dims(warp_dims.size() - 1); 584 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); 585 auto broadcasted_is_in_bound = 586 xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); 587 588 // Set out of bound samples to zero. 589 auto zeros = 590 xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims); 591 auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros); 592 593 ctx->SetOutput(0, result); 594 } 595 }; 596 597 REGISTER_XLA_OP(Name("Resampler"), ResamplerOp); 598 599 class ResamplerGradOp : public XlaOpKernel { 600 public: 601 explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 602 DataType output_dtype; 603 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); 604 } 605 606 // TODO(b/112295522): note that sampling from image boundary is not currently 607 // being handled properly. 608 void Compile(XlaOpKernelContext* ctx) override { 609 TensorShape data_shape_tf = ctx->InputShape("data"); 610 OP_REQUIRES(ctx, data_shape_tf.dims() == 4, 611 errors::InvalidArgument("data must be 4-dimensional", 612 data_shape_tf.DebugString())); 613 const int64 data_channels = data_shape_tf.dim_size(3); 614 xla::PrimitiveType data_type = ctx->input_xla_type(0); 615 616 TensorShape warp_shape = ctx->InputShape("warp"); 617 OP_REQUIRES(ctx, warp_shape.dims() >= 2, 618 errors::InvalidArgument("warp must be at least 2-dimensional", 619 warp_shape.DebugString())); 620 for (int size : warp_shape.dim_sizes()) { 621 OP_REQUIRES(ctx, size > 0, 622 errors::InvalidArgument("warp sizes must be positive, got [", 623 size, "]")); 624 } 625 // Last dimension of warp shape must be of size 2. 626 const int64 last_warp_dim = warp_shape.dims() - 1; 627 OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, 628 errors::InvalidArgument( 629 "the last dimension of warp must be exactly size 2.")); 630 xla::PrimitiveType warp_type = ctx->input_xla_type(1); 631 632 TensorShape output_grad_shape = ctx->InputShape("grad_output"); 633 OP_REQUIRES( 634 ctx, output_grad_shape.dims() >= 2, 635 errors::InvalidArgument("output_grad must be at least 2-dimensional", 636 output_grad_shape.DebugString())); 637 638 // Dimensions are [batch, x, y, channel]. 639 XlaOp data = ctx->Input("data"); 640 xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf); 641 642 // Dimensions are [batch, dim_0, ...dim_n, 2]. 643 XlaOp warp = ctx->Input("warp"); 644 // Dimensions are [batch, dim_0, ...dim_n, channel]. 645 XlaOp grad_output = ctx->Input("grad_output"); 646 647 // Find the top left corner coordinate for the region to be sampled from. 648 // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension 649 // of size 2 in turn is [x, y]. 650 XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32); 651 652 // Dimensions are [batch, dim_0, ... dim_n, 2]. 653 XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); 654 655 // Indices for gathering neighboring pixels. 656 auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); 657 658 auto grad_data = CalculateGradData( 659 ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape, 660 last_warp_dim, data_channels, data_shape); 661 662 auto grad_warp = 663 CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, 664 warp_shape, data_channels, data_type, data_shape); 665 auto warp_dims = warp_shape.dim_sizes(); 666 std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1); 667 result_dims.push_back(2); 668 std::vector<int64> broadcasted_dims(warp_dims.size() - 1); 669 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); 670 auto grad_warp_bounded = 671 BoundSamples(ctx, warp, warp_type, warp_shape, result_dims, 672 broadcasted_dims, last_warp_dim, data_shape, grad_warp); 673 674 ctx->SetOutput(0, grad_data); 675 ctx->SetOutput(1, grad_warp_bounded); 676 } 677 }; 678 679 REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp); 680 681 } // namespace 682 } // namespace tensorflow 683