1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific reduction Ops. 17 18 #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" 19 #include "tensorflow/compiler/tf2xla/type_util.h" 20 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/core/framework/kernel_def_builder.h" 24 25 namespace tensorflow { 26 namespace { 27 28 class SumOp : public XlaReductionOp { 29 public: 30 explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} 31 void BuildReducer(xla::ComputationBuilder* builder, 32 const xla::ComputationDataHandle& scalar_lhs, 33 const xla::ComputationDataHandle& scalar_rhs) override { 34 builder->Add(scalar_lhs, scalar_rhs); 35 } 36 }; 37 38 REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); 39 40 class ProdOp : public XlaReductionOp { 41 public: 42 explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} 43 44 xla::ComputationDataHandle InitialValue( 45 xla::ComputationBuilder* builder) override { 46 return XlaHelpers::One(builder, input_type(0)); 47 } 48 49 void BuildReducer(xla::ComputationBuilder* builder, 50 const xla::ComputationDataHandle& scalar_lhs, 51 const xla::ComputationDataHandle& scalar_rhs) override { 52 builder->Mul(scalar_lhs, scalar_rhs); 53 } 54 }; 55 56 REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), 57 ProdOp); 58 59 class MinOp : public XlaReductionOp { 60 public: 61 explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} 62 63 xla::ComputationDataHandle InitialValue( 64 xla::ComputationBuilder* builder) override { 65 xla::PrimitiveType type; 66 TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); 67 return builder->ConstantLiteral(xla::Literal::MaxValue(type)); 68 } 69 70 void BuildReducer(xla::ComputationBuilder* builder, 71 const xla::ComputationDataHandle& scalar_lhs, 72 const xla::ComputationDataHandle& scalar_rhs) override { 73 builder->Min(scalar_lhs, scalar_rhs); 74 } 75 }; 76 77 REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); 78 79 class MaxOp : public XlaReductionOp { 80 public: 81 explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} 82 83 xla::ComputationDataHandle InitialValue( 84 xla::ComputationBuilder* builder) override { 85 xla::PrimitiveType type; 86 TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); 87 return builder->ConstantLiteral(xla::Literal::MinValue(type)); 88 } 89 90 void BuildReducer(xla::ComputationBuilder* builder, 91 const xla::ComputationDataHandle& scalar_lhs, 92 const xla::ComputationDataHandle& scalar_rhs) override { 93 builder->Max(scalar_lhs, scalar_rhs); 94 } 95 }; 96 97 REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); 98 99 class MeanOp : public XlaReductionOp { 100 public: 101 explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} 102 103 void BuildReducer(xla::ComputationBuilder* builder, 104 const xla::ComputationDataHandle& scalar_lhs, 105 const xla::ComputationDataHandle& scalar_rhs) override { 106 builder->Add(scalar_lhs, scalar_rhs); 107 } 108 109 xla::ComputationDataHandle BuildFinalizer( 110 xla::ComputationBuilder* builder, 111 const xla::ComputationDataHandle& reduce_output, 112 int64 num_elements_reduced) override { 113 auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), 114 num_elements_reduced); 115 return builder->Div(reduce_output, divisor); 116 } 117 }; 118 119 REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), 120 MeanOp); 121 122 class AllOp : public XlaReductionOp { 123 public: 124 explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} 125 126 xla::ComputationDataHandle InitialValue( 127 xla::ComputationBuilder* builder) override { 128 return builder->ConstantR0<bool>(true); 129 } 130 131 void BuildReducer(xla::ComputationBuilder* builder, 132 const xla::ComputationDataHandle& scalar_lhs, 133 const xla::ComputationDataHandle& scalar_rhs) override { 134 builder->And(scalar_lhs, scalar_rhs); 135 } 136 }; 137 138 REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); 139 140 class AnyOp : public XlaReductionOp { 141 public: 142 explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} 143 144 xla::ComputationDataHandle InitialValue( 145 xla::ComputationBuilder* builder) override { 146 return builder->ConstantR0<bool>(false); 147 } 148 149 void BuildReducer(xla::ComputationBuilder* builder, 150 const xla::ComputationDataHandle& scalar_lhs, 151 const xla::ComputationDataHandle& scalar_rhs) override { 152 builder->Or(scalar_lhs, scalar_rhs); 153 } 154 }; 155 156 REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp); 157 158 } // namespace 159 } // namespace tensorflow 160