Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific Ops for broadcasting used in gradient
     17 // code.
     18 
     19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/core/platform/macros.h"
     24 #include "tensorflow/core/platform/types.h"
     25 #include "tensorflow/core/util/bcast.h"
     26 
     27 namespace tensorflow {
     28 namespace {
     29 
     30 // Given shapes of two tensors, computes the broadcast shape.
     31 class BCastArgsOp : public XlaOpKernel {
     32  public:
     33   explicit BCastArgsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     34     OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32}));
     35   }
     36 
     37   void Compile(XlaOpKernelContext* ctx) override {
     38     OP_REQUIRES(
     39         ctx, ctx->num_inputs() == 2,
     40         errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
     41     gtl::InlinedVector<BCast::Vec, 2> shapes;
     42     for (int i = 0; i < ctx->num_inputs(); ++i) {
     43       const TensorShape in_shape = ctx->InputShape(i);
     44       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
     45                   errors::InvalidArgument("In[", i, "] must be a vector.",
     46                                           in_shape.DebugString()));
     47       std::vector<int64> shape;
     48       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &shape));
     49       shapes.push_back(BCast::Vec(shape.begin(), shape.end()));
     50     }
     51     BCast bcast(shapes[0], shapes[1]);
     52     OP_REQUIRES(ctx, bcast.IsValid(),
     53                 errors::InvalidArgument(
     54                     "Incompatible shapes: [", str_util::Join(shapes[0], ","),
     55                     "] vs. [", str_util::Join(shapes[1], ","), "]"));
     56 
     57     const int64 len = bcast.output_shape().size();
     58     Tensor output(DT_INT32, TensorShape({len}));
     59     for (int64 i = 0; i < len; ++i) {
     60       output.flat<int32>()(i) = static_cast<int32>(bcast.output_shape()[i]);
     61     }
     62     ctx->SetConstantOutput(0, output);
     63   }
     64 
     65  private:
     66   TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp);
     67 };
     68 REGISTER_XLA_OP(Name("BroadcastArgs")
     69                     .CompileTimeConstInput("s0")
     70                     .CompileTimeConstInput("s1"),
     71                 BCastArgsOp);
     72 
     73 // Given shapes of two tensors, computes the reduction indices for the
     74 // gradient computation.
     75 //
     76 // TODO(zhifengc):
     77 //   1. Adds support for n-ary (n >= 2).
     78 class BCastGradArgsOp : public XlaOpKernel {
     79  public:
     80   explicit BCastGradArgsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     81     OP_REQUIRES_OK(
     82         ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}));
     83   }
     84 
     85   void Compile(XlaOpKernelContext* ctx) override {
     86     OP_REQUIRES(
     87         ctx, ctx->num_inputs() == 2,
     88         errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
     89 
     90     gtl::InlinedVector<BCast::Vec, 4> shapes;
     91     for (int i = 0; i < ctx->num_inputs(); ++i) {
     92       const TensorShape in_shape = ctx->InputShape(i);
     93       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
     94                   errors::InvalidArgument("In[", i, "] must be a vector.",
     95                                           in_shape.DebugString()));
     96       xla::Literal literal;
     97       OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal));
     98 
     99       BCast::Vec vec;
    100       for (int64 i = 0; i < in_shape.num_elements(); ++i) {
    101         vec.push_back(literal.Get<int>({i}));
    102       }
    103       shapes.push_back(vec);
    104     }
    105     BCast bcast(shapes[0], shapes[1]);
    106     OP_REQUIRES(ctx, bcast.IsValid(),
    107                 errors::InvalidArgument(
    108                     "Incompatible shapes: [", str_util::Join(shapes[0], ","),
    109                     "] vs. [", str_util::Join(shapes[1], ","), "]"));
    110     Output(ctx, 0, bcast.grad_x_reduce_idx());
    111     Output(ctx, 1, bcast.grad_y_reduce_idx());
    112   }
    113 
    114  private:
    115   void Output(XlaOpKernelContext* ctx, int idx, const BCast::Vec& v) {
    116     const int64 len = v.size();
    117     Tensor constant(DT_INT32, TensorShape({len}));
    118     for (int64 i = 0; i < len; ++i) {
    119       constant.flat<int32>()(i) = static_cast<int32>(v[i]);
    120     }
    121     ctx->SetConstantOutput(idx, constant);
    122   }
    123 
    124   TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
    125 };
    126 
    127 REGISTER_XLA_OP(Name("BroadcastGradientArgs")
    128                     .CompileTimeConstInput("s0")
    129                     .CompileTimeConstInput("s1"),
    130                 BCastGradArgsOp);
    131 
    132 }  // namespace
    133 }  // namespace tensorflow
    134