Home | History | Annotate | Download | only in lib
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/client/lib/constants.h"
     17 
     18 #include "tensorflow/compiler/xla/literal_util.h"
     19 #include "tensorflow/compiler/xla/util.h"
     20 
     21 namespace xla {
     22 
     23 XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
     24   return ConstantLiteral(builder, LiteralUtil::Zero(type));
     25 }
     26 
     27 XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
     28   return Broadcast(Zero(builder, shape.element_type()),
     29                    AsInt64Slice(shape.dimensions()));
     30 }
     31 
     32 XlaOp ZerosLike(XlaOp prototype) {
     33   XlaBuilder* builder = prototype.builder();
     34   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     35     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
     36     return Zeros(builder, shape);
     37   });
     38 }
     39 
     40 XlaOp One(XlaBuilder* builder, PrimitiveType type) {
     41   return ConstantLiteral(builder, LiteralUtil::One(type));
     42 }
     43 
     44 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
     45   switch (type) {
     46     case F16:
     47       return ConstantR0<Eigen::half>(
     48           builder,
     49           static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
     50     case BF16:
     51       return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
     52     case F32:
     53       return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
     54     case F64:
     55       return ConstantR0<double>(builder,
     56                                 std::numeric_limits<double>::epsilon());
     57     default:
     58       return builder->ReportError(InvalidArgument(
     59           "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
     60   }
     61 }
     62 
     63 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
     64   return ConstantLiteral(builder, LiteralUtil::MinValue(type));
     65 }
     66 
     67 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
     68   switch (type) {
     69     case F16:
     70       return ConstantR0<Eigen::half>(builder,
     71                                      Eigen::NumTraits<Eigen::half>::lowest());
     72     case BF16:
     73       return ConstantR0<bfloat16>(builder, bfloat16::lowest());
     74     case F32:
     75       return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
     76     case F64:
     77       return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
     78     default:
     79       return MinValue(builder, type);
     80   }
     81 }
     82 
     83 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
     84   switch (type) {
     85     case F16:
     86       return ConstantR0<Eigen::half>(builder,
     87                                      std::numeric_limits<Eigen::half>::min());
     88     case BF16:
     89       return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
     90     case F32:
     91       return ConstantR0<float>(builder, std::numeric_limits<float>::min());
     92     case F64:
     93       return ConstantR0<double>(builder, std::numeric_limits<double>::min());
     94     default:
     95       return builder->ReportError(
     96           InvalidArgument("Invalid type for MinPositiveNormalValue (%s).",
     97                           PrimitiveType_Name(type)));
     98   }
     99 }
    100 
    101 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
    102   return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
    103 }
    104 
    105 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
    106   switch (type) {
    107     case F16:
    108       return ConstantR0<Eigen::half>(builder,
    109                                      Eigen::NumTraits<Eigen::half>::highest());
    110     case BF16:
    111       return ConstantR0<bfloat16>(builder, bfloat16::highest());
    112     case F32:
    113       return ConstantR0<float>(builder, std::numeric_limits<float>::max());
    114     case F64:
    115       return ConstantR0<double>(builder, std::numeric_limits<double>::max());
    116     default:
    117       return MaxValue(builder, type);
    118   }
    119 }
    120 
    121 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
    122   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    123     switch (type) {
    124       case F16:
    125         return ConstantR0<Eigen::half>(
    126             builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
    127       case BF16:
    128         return ConstantR0<bfloat16>(
    129             builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
    130       case F32:
    131         return ConstantR0<float>(builder,
    132                                  std::numeric_limits<float>::quiet_NaN());
    133       case F64:
    134         return ConstantR0<double>(builder,
    135                                   std::numeric_limits<double>::quiet_NaN());
    136       default:
    137         return InvalidArgument(
    138             "Operand to NanValue was %s, but must be a real-valued "
    139             "floating-point type.",
    140             PrimitiveType_Name(type));
    141     }
    142   });
    143 }
    144 
    145 }  // namespace xla
    146