/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific base classes for Unary and Binary Ops.

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
  const TensorShape lhs_shape = ctx->InputShape(0);
  const TensorShape rhs_shape = ctx->InputShape(1);

  // By TensorFlow convention the inputs need not have the same
  // shape, in which case they are automatically broadcast, if
  // possible, before mapping. Use the standard TensorFlow helper to
  // compute valid broadcast shapes, and rely below on XLA to perform
  // the broadcast itself: the shapes XLA can broadcast are a superset
  // of the shapes TensorFlow accepts.
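  // For example, shapes [2,3,5] and [5] broadcast to [2,3,5], while
  // shapes [2,3] and [3,2] are rejected as incompatible.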
  BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
  if (!bcast.IsValid()) {
    ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
                                           lhs_shape.DebugString(), " vs. ",
                                           rhs_shape.DebugString()));
    return;
  }
  TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());

  // Fetch the expressions containing the input tensors.
  auto lhs_handle = ctx->Input(0);
  auto rhs_handle = ctx->Input(1);

  // If the ranks of the inputs don't match, TensorFlow automatically
  // reshapes the lower-rank input by padding its shape with dimensions
  // of size 1 as a prefix. For example, to pad a 5-vector to rank 3 it
  // is reshaped to shape [1,1,5]. XLA's automatic broadcast code can
  // broadcast from lower to higher rank, but it does not assume the
  // padding is a prefix of the dimensions; instead it must be told
  // which dimensions of the higher-rank tensor to match against the
  // lower-rank tensor. In this example that would be dimension [2]; if
  // we were matching a matrix against a 4-D tensor the dimensions to
  // match would be [2,3], and so on. extend_dimension encodes the
  // general case.
  std::vector<int64> extend_dimension;
  int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims());
  int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims());
  if (min_rank != max_rank) {
    for (int i = 0; i < min_rank; ++i) {
      // Match the lower-rank tensor along the larger-numbered
      // dimensions of the higher-rank tensor.
      extend_dimension.push_back(max_rank - min_rank + i);
    }
  }
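  // For instance, combining a 5-vector with a [7,11,5] tensor gives
  // max_rank = 3, min_rank = 1, and extend_dimension = {2}: the vector
  // is matched against dimension 2 of the 3-D tensor.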

  // Call the virtual method to emit the computation.
  xla::ComputationDataHandle output =
      Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle,
                  rhs_shape.dim_sizes(), bcast, extend_dimension);

  // The TensorFlow helper computed the post-broadcast shape in
  // bcast_shape above; we rely on the subclass's Computation to
  // implement the same broadcast semantics.
  ctx->SetOutput(0, output);
}

/* static */ std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder,
                       const xla::ComputationDataHandle& lhs,
                       const xla::ComputationDataHandle& rhs,
                       const BCast& broadcast_helper) {
  // Construct the broadcast manually, since MapN does not broadcast
  // automatically. The BCast helper ensures that
  // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
  // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
  // the same shape, so they can be operated on by MapN.

  // First reshape the inputs, which should be a metadata-only
  // operation since we are flattening the dimensions in order.
  auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape());
  auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape());

  // Next broadcast the necessary input dimensions. We rely on the
  // XLA optimizer to be smart about the fact that we are asking
  // it to broadcast size 1 on some of these dimensions, to avoid
  // adding complexity to this code.
  auto lhs_broadcast =
      builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast());
  int lhs_size = broadcast_helper.x_bcast().size();
  auto rhs_broadcast =
      builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast());
  int rhs_size = broadcast_helper.y_bcast().size();
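  // Note: x_bcast() and y_bcast() each have one entry per collapsed
  // output dimension, so lhs_size == rhs_size == the collapsed output
  // rank, and lhs_broadcast/rhs_broadcast each have twice that rank.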

  // Now reshape them to the correct output shape. After the broadcast
  // each side has twice the rank it should, since the broadcast
  // dimensions were prepended to the shape. Reshape, flattening each
  // original dimension together with its prepended broadcast
  // dimension. E.g. if we started out with lhs_shaped of shape
  // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
  // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
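  // Continuing that example, the interleaving below produces
  // lhs_reorder = {0,3,1,4,2,5}, i.e. transpose [2,1,7,5,2,3] to
  // [2,5,1,2,7,3] and then collapse adjacent pairs to [10,2,21].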
  std::vector<int64> lhs_reorder;
  for (int i = 0; i < lhs_size; ++i) {
    lhs_reorder.push_back(i);
    lhs_reorder.push_back(i + lhs_size);
  }
  auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder,
                                     broadcast_helper.output_shape());
  std::vector<int64> rhs_reorder;
  for (int i = 0; i < rhs_size; ++i) {
    rhs_reorder.push_back(i);
    rhs_reorder.push_back(i + rhs_size);
  }
  auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder,
                                     broadcast_helper.output_shape());

  return {lhs_output, rhs_output};
}
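
// A subclass implements the Computation() virtual method, typically by
// handing extend_dimension to one of the builder's binary ops so that XLA
// performs the broadcast itself. A minimal sketch of a hypothetical Add
// kernel, with the signature inferred from the call site in Compile()
// above (the authoritative declaration is in cwise_ops.h):
//
//   xla::ComputationDataHandle XlaAddOp::Computation(
//       XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
//       const gtl::ArraySlice<int64>& lhs_shape,
//       const xla::ComputationDataHandle& rhs,
//       const gtl::ArraySlice<int64>& rhs_shape,
//       const BCast& broadcast_helper,
//       const std::vector<int64>& extend_dimension) {
//     // ComputationBuilder::Add accepts broadcast_dimensions directly.
//     return ctx->builder()->Add(lhs, rhs, extend_dimension);
//   }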

}  // namespace tensorflow