/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 // XLA implementation of OneHot operator.
     17 
     18 #include "tensorflow/compiler/tf2xla/literal_util.h"
     19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     22 
     23 namespace tensorflow {
     24 namespace {
     25 
     26 class OneHotOp : public XlaOpKernel {
     27  public:
     28   explicit OneHotOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     29     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
     30   }
     31 
     32   void Compile(XlaOpKernelContext* ctx) override {
     33     const TensorShape indices_shape = ctx->InputShape(0);
     34     const TensorShape depth_shape = ctx->InputShape(1);
     35     const TensorShape on_value_shape = ctx->InputShape(2);
     36     const TensorShape off_value_shape = ctx->InputShape(3);
     37 
     38     const int indices_dims = indices_shape.dims();
     39     const int output_dims = indices_dims + 1;
     40 
     41     // Preliminary validation of sizes.
     42     OP_REQUIRES(
     43         ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims),
     44         errors::InvalidArgument("Expected axis to be -1 or between [0, ",
     45                                 output_dims, ").  But received: ", axis_));
     46     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth_shape),
     47                 errors::InvalidArgument("depth must be a scalar, but got: ",
     48                                         depth_shape.DebugString()));
     49     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value_shape),
     50                 errors::InvalidArgument("on_value must be a scalar, but got: ",
     51                                         on_value_shape.DebugString()));
     52     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value_shape),
     53                 errors::InvalidArgument("off_value must be a scalar, but got: ",
     54                                         off_value_shape.DebugString()));
     55 
     56     const int axis = (axis_ == -1) ? indices_dims : axis_;
     57 
     58     // The one-hot dimension.
     59     int64 depth;
     60     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &depth));
     61     OP_REQUIRES(
     62         ctx, depth >= 0,
     63         errors::InvalidArgument("depth must be non-negative, got: ", depth));
     64 
     65     xla::ComputationDataHandle one_hot;
     66     OP_REQUIRES_OK(
     67         ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0),
     68                                 indices_shape, ctx->Input(0), ctx->Input(2),
     69                                 ctx->Input(3), &one_hot));
     70     ctx->SetOutput(0, one_hot);
     71   }
     72 
     73  private:
     74   int32 axis_;
     75 
     76   TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
     77 };
     78 
     79 REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp);
     80 
     81 }  // namespace
     82 }  // namespace tensorflow
     83