Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     19 
     20 namespace tensorflow {
     21 namespace {
     22 
     23 class CrossOp : public XlaOpKernel {
     24  public:
     25   explicit CrossOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
     26 
     27   void Compile(XlaOpKernelContext* ctx) override {
     28     TensorShape in0_shape = ctx->InputShape(0);
     29     TensorShape in1_shape = ctx->InputShape(1);
     30     OP_REQUIRES(ctx, in0_shape == in1_shape,
     31                 errors::InvalidArgument("Both inputs must be of same shape: ",
     32                                         in0_shape.DebugString(), " vs. ",
     33                                         in1_shape.DebugString()));
     34     OP_REQUIRES(ctx, in0_shape.dims() >= 1,
     35                 errors::InvalidArgument("Input must be at least 1D",
     36                                         in0_shape.DebugString()));
     37 
     38     auto inner_dim = in0_shape.dim_size(in0_shape.dims() - 1);
     39     OP_REQUIRES(ctx, inner_dim == 3,
     40                 errors::FailedPrecondition(
     41                     "Cross-products are only defined for 3-element vectors."));
     42 
     43     // in0 is a [...,X,Y,Z,3]
     44     // in1 is the same shape as in0
     45     // So slice 0 is: in0[...,:,:,:,0:1]
     46     // So slice 1 is: in0[...,:,:,:,1:2]
     47     // So slice 2 is: in0[...,:,:,:,2:3]
     48 
     49     std::vector<int64> starts(in0_shape.dims(), 0);
     50     std::vector<int64> limits;
     51     for (auto dim_size : in0_shape.dim_sizes()) {
     52       limits.push_back(dim_size);
     53     }
     54     std::vector<int64> strides(in0_shape.dims(), 1);
     55 
     56     xla::ComputationBuilder* b = ctx->builder();
     57     auto in0 = ctx->Input(0);
     58     auto in1 = ctx->Input(1);
     59     starts.back() = 0;
     60     limits.back() = 1;
     61     auto u1 = b->Slice(in0, starts, limits, strides);
     62     auto v1 = b->Slice(in1, starts, limits, strides);
     63     starts.back() = 1;
     64     limits.back() = 2;
     65     auto u2 = b->Slice(in0, starts, limits, strides);
     66     auto v2 = b->Slice(in1, starts, limits, strides);
     67     starts.back() = 2;
     68     limits.back() = 3;
     69     auto u3 = b->Slice(in0, starts, limits, strides);
     70     auto v3 = b->Slice(in1, starts, limits, strides);
     71 
     72     auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2));
     73     auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3));
     74     auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1));
     75     auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1);
     76 
     77     ctx->SetOutput(0, output);
     78   }
     79 
     80  private:
     81   TF_DISALLOW_COPY_AND_ASSIGN(CrossOp);
     82 };
     83 
     84 REGISTER_XLA_OP(Name("Cross"), CrossOp);
     85 
     86 }  // namespace
     87 }  // namespace tensorflow
     88