Home | History | Annotate | Download | only in lib
      1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
     17 
     18 #include <limits>
     19 #include <string>
     20 #include <vector>
     21 
     22 #include "absl/strings/str_cat.h"
     23 #include "absl/types/span.h"
     24 #include "tensorflow/compiler/xla/client/lib/constants.h"
     25 #include "tensorflow/compiler/xla/client/xla_builder.h"
     26 #include "tensorflow/compiler/xla/client/xla_computation.h"
     27 #include "tensorflow/compiler/xla/primitive_util.h"
     28 #include "tensorflow/compiler/xla/shape_util.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/xla_data.pb.h"
     31 
     32 namespace xla {
     33 namespace {
     34 
     35 using XlaOpGenerator = XlaOp (*)(const XlaOp&, const XlaOp&,
     36                                  absl::Span<const int64>);
     37 
     38 XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
     39                                             int64 bit_width) {
     40   PrimitiveType signed_type;
     41   PrimitiveType unsigned_type;
     42   XlaOp max_value;
     43   switch (bit_width) {
     44     case 16:
     45       max_value =
     46           ConstantR0(value.builder(),
     47                      static_cast<uint16>(std::numeric_limits<int16>::max()));
     48       signed_type = S16;
     49       unsigned_type = U16;
     50       break;
     51     case 32:
     52       max_value =
     53           ConstantR0(value.builder(),
     54                      static_cast<uint32>(std::numeric_limits<int32>::max()));
     55       signed_type = S32;
     56       unsigned_type = U32;
     57       break;
     58     case 64:
     59       max_value =
     60           ConstantR0(value.builder(),
     61                      static_cast<uint64>(std::numeric_limits<int64>::max()));
     62       signed_type = S64;
     63       unsigned_type = U64;
     64       break;
     65     default:
     66       return value.builder()->ReportError(
     67           InvalidArgument("Invalid bit width %lld for Comparator floating "
     68                           "point parameter.",
     69                           bit_width));
     70   }
     71   // Switch from a floating point value to a integer value in such a way that
     72   // when using the integer value to compare, we get the same result for normal
     73   // values, and -Nan is treated as the smallest value, and Nan is treated as
     74   // the largest value.
     75   // If f is a float, and
     76   // x = bit_cast<int32>(f);
     77   // y = x < 0 ? numeric_limits<int32>::max() - x : x;
     78   // then y is ordered as an int32 such that finite values have the obvious
     79   // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
     80   // and end of the ordering.
     81   // Note that in order to avoid -x to overflow, we calculate
     82   // numeric_limits<int32>::max() - x as unsigned, and then convert back to
     83   // signed.
     84   auto signed_value = BitcastConvertType(value, signed_type);
     85   auto unsigned_value = BitcastConvertType(value, unsigned_type);
     86   auto flipped_value =
     87       BitcastConvertType(Sub(max_value, unsigned_value), signed_type);
     88   auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type));
     89   return Select(is_negative, flipped_value, signed_value);
     90 }
     91 
     92 XlaComputation CreateScalarComparisonComputation(
     93     const string& name, const std::vector<PrimitiveType>& operand_types,
     94     XlaBuilder* builder, XlaOpGenerator generator) {
     95   // Create a default computation where we compare only the first two
     96   // parameters of type 'operand_types[0]'.
     97   auto b = builder->CreateSubBuilder(name);
     98   if (operand_types.empty()) {
     99     b->ReportError(InvalidArgument("operand_types should not be empty"));
    100     return b->BuildAndNoteError();
    101   }
    102 
    103   int64 parameter_count = 0;
    104   XlaOp first_lhs_param;
    105   XlaOp first_rhs_param;
    106 
    107   // For each type in 'operand_types' we create two parameters of this type. The
    108   // idea is that this computation can be used by n-ary Sort, and potentially
    109   // should support comparing also the other operands of sort. In this default
    110   // computation, however, we will not actually use any parameters except the
    111   // first two.
    112   for (auto operand_type : operand_types) {
    113     auto scalar_shape = ShapeUtil::MakeShape(operand_type, {});
    114     auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape,
    115                                absl::StrCat("p.", parameter_count, ".lhs"));
    116     auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
    117                                absl::StrCat("p.", parameter_count, ".rhs"));
    118     if (parameter_count == 0) {
    119       first_lhs_param = lhs_param;
    120       first_rhs_param = rhs_param;
    121     }
    122     ++parameter_count;
    123   }
    124   if (primitive_util::IsFloatingPointType(operand_types[0])) {
    125     PrimitiveType compare_type = operand_types[0];
    126     // Special-case handling for BF16. We currently do not support direct
    127     // comparisons with BF16, so we convert to F32 and then use the F32
    128     // comparison logic.
    129     if (compare_type == BF16) {
    130       compare_type = F32;
    131       first_lhs_param = ConvertElementType(first_lhs_param, F32);
    132       first_rhs_param = ConvertElementType(first_rhs_param, F32);
    133     }
    134     int64 bit_width = primitive_util::BitWidth(compare_type);
    135     first_lhs_param =
    136         BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width);
    137     first_rhs_param =
    138         BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width);
    139   }
    140   generator(first_lhs_param, first_rhs_param, {});
    141   return b->BuildAndNoteError();
    142 }
    143 }  // namespace
    144 
    145 // Creates a scalar less-than computation and returns it.
    146 XlaComputation CreateScalarLtComputation(
    147     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
    148   return CreateScalarComparisonComputation("compare-less-than", operand_types,
    149                                            builder, Lt);
    150 }
    151 
    152 // Creates a scalar greater-than computation and returns it.
    153 XlaComputation CreateScalarGtComputation(
    154     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
    155   return CreateScalarComparisonComputation("compare-greater-than",
    156                                            operand_types, builder, Gt);
    157 }
    158 
    159 }  // namespace xla
    160