1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific base classes for Unary and Binary Ops. 17 18 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" 19 20 #include "tensorflow/compiler/tf2xla/type_util.h" 21 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 24 #include "tensorflow/compiler/xla/client/client_library.h" 25 #include "tensorflow/compiler/xla/client/computation_builder.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/util/bcast.h" 29 30 namespace tensorflow { 31 32 void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { 33 const TensorShape lhs_shape = ctx->InputShape(0); 34 const TensorShape rhs_shape = ctx->InputShape(1); 35 36 // By TensorFlow conventions the inputs may not have the same 37 // shapes, in which case they will be automatically broadcast if 38 // possible before mapping. Use the standard TensorFlow helper to 39 // compute valid broadcast shapes, but rely below on XLA to 40 // automatically perform the broadcast assuming its valid shapes are 41 // a superset of TensorFlow's valid shapes. 42 BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); 43 if (!bcast.IsValid()) { 44 ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", 45 lhs_shape.DebugString(), " vs. ", 46 rhs_shape.DebugString())); 47 return; 48 } 49 TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); 50 51 // Fetch the expressions containing the input tensors. 52 auto lhs_handle = ctx->Input(0); 53 auto rhs_handle = ctx->Input(1); 54 55 // If the ranks of the inputs don't match, TensorFlow automatically 56 // reshapes the smaller by padding with dimensions of size 1 as a 57 // prefix. In other words to pad a 5-vector to a 3-dimensional 58 // tensor it is reshaped to have shape [1,1,5]. XLA's automatic 59 // broadcast code is able to broadcast from lower to higher rank, 60 // but doesn't assume you want to pad as a prefix of the dimensions, 61 // and instead needs to be told which dimensions of the higher rank 62 // tensor to match to the lower rank tensor. In this example it 63 // would be dimensions [2]. If we were matching a matrix against a 64 // 4-D tensor the dimensions to match would be [2,3], 65 // etc. extend_dimension encodes the general case. 66 std::vector<int64> extend_dimension; 67 int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims()); 68 int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims()); 69 if (min_rank != max_rank) { 70 for (int i = 0; i < min_rank; ++i) { 71 // Match the lower rank tensor along the larger-numbered 72 // dimensions of the higher rank tensor. 73 extend_dimension.push_back(max_rank - min_rank + i); 74 } 75 } 76 77 // Call virtual method to emit the computation. 78 xla::ComputationDataHandle output = 79 Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle, 80 rhs_shape.dim_sizes(), bcast, extend_dimension); 81 82 // The TensorFlow helper computed the post-broadcast shape in 83 // output_shape: we rely on subclassed Computations to implement the 84 // same broadcast semantics. 85 ctx->SetOutput(0, output); 86 } 87 88 /* static */ std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle> 89 XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder, 90 const xla::ComputationDataHandle& lhs, 91 const xla::ComputationDataHandle& rhs, 92 const BCast& broadcast_helper) { 93 // Manually construct the broadcasting since MapN does not do 94 // automatic broadcasting. The bcast helper ensures that 95 // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and 96 // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have 97 // the same shape, so can be operated on by MapN. 98 99 // First reshape the inputs, which should be a metadata-only 100 // operation since we are flattening the dimensions in order. 101 auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape()); 102 auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape()); 103 104 // Next broadcast the necessary input dimensions. We rely on the 105 // XLA optimizer to be smart about the fact that we are asking 106 // it to broadcast size 1 on some of these dimensions, to avoid 107 // adding complexity to this code. 108 auto lhs_broadcast = 109 builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast()); 110 int lhs_size = broadcast_helper.x_bcast().size(); 111 auto rhs_broadcast = 112 builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast()); 113 int rhs_size = broadcast_helper.y_bcast().size(); 114 115 // Now reshape them to the correct output shape. After the 116 // broadcast each side is twice as wide as it should be, since the 117 // broadcast dimensions were prepended to the shape. Reshape 118 // flattening each original dimension with the prepended broadcast 119 // dimension. E.g. if we started out with lhs_shaped with shape 120 // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have 121 // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21]. 122 std::vector<int64> lhs_reorder; 123 for (int i = 0; i < lhs_size; ++i) { 124 lhs_reorder.push_back(i); 125 lhs_reorder.push_back(i + lhs_size); 126 } 127 auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder, 128 broadcast_helper.output_shape()); 129 std::vector<int64> rhs_reorder; 130 for (int i = 0; i < rhs_size; ++i) { 131 rhs_reorder.push_back(i); 132 rhs_reorder.push_back(i + rhs_size); 133 } 134 auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder, 135 broadcast_helper.output_shape()); 136 137 return {lhs_output, rhs_output}; 138 } 139 140 } // namespace tensorflow 141