Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// XLA-specific base class for coefficient-wise binary Ops.
     17 
     18 #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
     19 #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
     20 
     21 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     22 #include "tensorflow/compiler/xla/client/client_library.h"
     23 #include "tensorflow/compiler/xla/client/computation_builder.h"
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/util/bcast.h"
     26 
     27 namespace tensorflow {
     28 
     29 // Coefficient-wise binary operations. Each binary Op expects two
     30 // inputs that can be broadcast to the same shape. The base class
     31 // contains pure virtual methods to override: description is a textual
     32 // description of the operation; and Computation adds the
     33 // implementation of the operation to a xla::ComputationBuilder. For most
     34 // arithmetic Ops XLA handles the broadcasting automatically given the input
     35 // tensors.
     36 class XlaBinaryOp : public XlaOpKernel {
     37  public:
     38   explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     39     const DataType lhs = BaseType(input_type(0));
     40     const DataType rhs = BaseType(input_type(1));
     41     OP_REQUIRES(ctx, lhs == rhs,
     42                 errors::InvalidArgument("Input types of binary op must match"));
     43   }
     44   ~XlaBinaryOp() override {}
     45 
     46   // Implement the (tensor,tensor)->tensor lambda that should be
     47   // applied to the inputs. The desired computation should be added to
     48   // 'tc->builder()' and '(lhs,rhs)' are the function's inputs and
     49   // (lhs_shape,rhs_shape) are their respective
     50   // shapes. 'broadcast_helper' contains metadata about the shapes of
     51   // the inputs and the dimensions that need to be broadcast, which
     52   // may be useful for Ops that can't use standard XLA automatic
     53   // broadcasting. 'extend_dimension' is non-empty if lhs and rhs have
     54   // different ranks, and indicates which dimensions of the
     55   // higher-rank input should be matched when broadcasting the
     56   // lower-rank input. See comment below and the documentation on broadcasting
     57   // in the XLA documentation.
     58   virtual xla::ComputationDataHandle Computation(
     59       XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
     60       const gtl::ArraySlice<int64>& lhs_shape,
     61       const xla::ComputationDataHandle& rhs,
     62       const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
     63       const std::vector<int64>& extend_dimensions) = 0;
     64 
     65   void Compile(XlaOpKernelContext* ctx) override;
     66 
     67   // Helper function that performs the broadcasting described by
     68   // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
     69   // shape.
     70   static std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
     71   Broadcast(xla::ComputationBuilder* builder,
     72             const xla::ComputationDataHandle& lhs,
     73             const xla::ComputationDataHandle& rhs,
     74             const BCast& broadcast_helper);
     75 };
     76 
     77 }  // namespace tensorflow
     78 
     79 #endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
     80