Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific reduction Ops.
     17 
     18 #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
     19 #include "tensorflow/compiler/tf2xla/type_util.h"
     20 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/core/framework/kernel_def_builder.h"
     24 
     25 namespace tensorflow {
     26 namespace {
     27 
     28 class SumOp : public XlaReductionOp {
     29  public:
     30   explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
     31   void BuildReducer(xla::ComputationBuilder* builder,
     32                     const xla::ComputationDataHandle& scalar_lhs,
     33                     const xla::ComputationDataHandle& scalar_rhs) override {
     34     builder->Add(scalar_lhs, scalar_rhs);
     35   }
     36 };
     37 
     38 REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp);
     39 
     40 class ProdOp : public XlaReductionOp {
     41  public:
     42   explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
     43 
     44   xla::ComputationDataHandle InitialValue(
     45       xla::ComputationBuilder* builder) override {
     46     return XlaHelpers::One(builder, input_type(0));
     47   }
     48 
     49   void BuildReducer(xla::ComputationBuilder* builder,
     50                     const xla::ComputationDataHandle& scalar_lhs,
     51                     const xla::ComputationDataHandle& scalar_rhs) override {
     52     builder->Mul(scalar_lhs, scalar_rhs);
     53   }
     54 };
     55 
     56 REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"),
     57                 ProdOp);
     58 
     59 class MinOp : public XlaReductionOp {
     60  public:
     61   explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
     62 
     63   xla::ComputationDataHandle InitialValue(
     64       xla::ComputationBuilder* builder) override {
     65     xla::PrimitiveType type;
     66     TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
     67     return builder->ConstantLiteral(xla::Literal::MaxValue(type));
     68   }
     69 
     70   void BuildReducer(xla::ComputationBuilder* builder,
     71                     const xla::ComputationDataHandle& scalar_lhs,
     72                     const xla::ComputationDataHandle& scalar_rhs) override {
     73     builder->Min(scalar_lhs, scalar_rhs);
     74   }
     75 };
     76 
     77 REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp);
     78 
     79 class MaxOp : public XlaReductionOp {
     80  public:
     81   explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
     82 
     83   xla::ComputationDataHandle InitialValue(
     84       xla::ComputationBuilder* builder) override {
     85     xla::PrimitiveType type;
     86     TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
     87     return builder->ConstantLiteral(xla::Literal::MinValue(type));
     88   }
     89 
     90   void BuildReducer(xla::ComputationBuilder* builder,
     91                     const xla::ComputationDataHandle& scalar_lhs,
     92                     const xla::ComputationDataHandle& scalar_rhs) override {
     93     builder->Max(scalar_lhs, scalar_rhs);
     94   }
     95 };
     96 
     97 REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp);
     98 
     99 class MeanOp : public XlaReductionOp {
    100  public:
    101   explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
    102 
    103   void BuildReducer(xla::ComputationBuilder* builder,
    104                     const xla::ComputationDataHandle& scalar_lhs,
    105                     const xla::ComputationDataHandle& scalar_rhs) override {
    106     builder->Add(scalar_lhs, scalar_rhs);
    107   }
    108 
    109   xla::ComputationDataHandle BuildFinalizer(
    110       xla::ComputationBuilder* builder,
    111       const xla::ComputationDataHandle& reduce_output,
    112       int64 num_elements_reduced) override {
    113     auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
    114                                               num_elements_reduced);
    115     return builder->Div(reduce_output, divisor);
    116   }
    117 };
    118 
    119 REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"),
    120                 MeanOp);
    121 
    122 class AllOp : public XlaReductionOp {
    123  public:
    124   explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
    125 
    126   xla::ComputationDataHandle InitialValue(
    127       xla::ComputationBuilder* builder) override {
    128     return builder->ConstantR0<bool>(true);
    129   }
    130 
    131   void BuildReducer(xla::ComputationBuilder* builder,
    132                     const xla::ComputationDataHandle& scalar_lhs,
    133                     const xla::ComputationDataHandle& scalar_rhs) override {
    134     builder->And(scalar_lhs, scalar_rhs);
    135   }
    136 };
    137 
    138 REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp);
    139 
    140 class AnyOp : public XlaReductionOp {
    141  public:
    142   explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
    143 
    144   xla::ComputationDataHandle InitialValue(
    145       xla::ComputationBuilder* builder) override {
    146     return builder->ConstantR0<bool>(false);
    147   }
    148 
    149   void BuildReducer(xla::ComputationBuilder* builder,
    150                     const xla::ComputationDataHandle& scalar_lhs,
    151                     const xla::ComputationDataHandle& scalar_rhs) override {
    152     builder->Or(scalar_lhs, scalar_rhs);
    153   }
    154 };
    155 
    156 REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp);
    157 
    158 }  // namespace
    159 }  // namespace tensorflow
    160