Home | History | Annotate | Download | only in lib
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
     17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
     18 
     19 #include <type_traits>
     20 
     21 #include "tensorflow/compiler/xla/client/xla_builder.h"
     22 #include "tensorflow/compiler/xla/primitive_util.h"
     23 #include "tensorflow/compiler/xla/types.h"
     24 #include "tensorflow/compiler/xla/xla_data.pb.h"
     25 
     26 namespace xla {
     27 
     28 // Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is
     29 // determined at C++ run-time, rather than C++ compile-time.
     30 // If 'value' is floating point but 'type' is not, or if 'value' is complex but
     31 // 'type' is not, an error will be returned. This is to catch accidental
     32 // truncation; in such cases, use an explicit cast.
     33 template <typename T>
     34 XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
     35   if (std::is_floating_point<T>::value &&
     36       !(primitive_util::IsFloatingPointType(type) ||
     37         primitive_util::IsComplexType(type))) {
     38     return builder->ReportError(InvalidArgument(
     39         "Invalid cast from floating point type to %s in ConstantR0WithType.",
     40         PrimitiveType_Name(type)));
     41   }
     42   if (std::is_same<T, complex64>::value &&
     43       !primitive_util::IsComplexType(type)) {
     44     return builder->ReportError(InvalidArgument(
     45         "Invalid cast from complex type to %s in ConstantR0WithType.",
     46         PrimitiveType_Name(type)));
     47   }
     48   switch (type) {
     49     case F16:
     50       return ConstantR0<half>(builder, static_cast<half>(value));
     51     case BF16:
     52       return ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
     53     case F32:
     54       return ConstantR0<float>(builder, static_cast<float>(value));
     55     case F64:
     56       return ConstantR0<double>(builder, static_cast<double>(value));
     57     case C64:
     58       return ConstantR0<complex64>(builder, static_cast<complex64>(value));
     59     case C128:
     60       return ConstantR0<complex128>(builder, static_cast<complex128>(value));
     61     case U8:
     62       return ConstantR0<uint8>(builder, static_cast<uint8>(value));
     63     case U32:
     64       return ConstantR0<uint32>(builder, static_cast<uint32>(value));
     65     case U64:
     66       return ConstantR0<uint64>(builder, static_cast<uint64>(value));
     67     case S8:
     68       return ConstantR0<int8>(builder, static_cast<int8>(value));
     69     case S32:
     70       return ConstantR0<int32>(builder, static_cast<int32>(value));
     71     case S64:
     72       return ConstantR0<int64>(builder, static_cast<int64>(value));
     73     default:
     74       return builder->ReportError(
     75           InvalidArgument("Invalid type for ConstantR0WithType (%s).",
     76                           PrimitiveType_Name(type)));
     77   }
     78 }
     79 
     80 // Returns a scalar containing 'value' cast to the same run-time type as
     81 // 'prototype'.
     82 // If 'value' is floating point but 'prototype' is not, or if 'value' is complex
     83 // 'prototype' is not, an error will be returned.
     84 template <typename T>
     85 XlaOp ScalarLike(XlaOp prototype, T value) {
     86   XlaBuilder* builder = prototype.builder();
     87   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     88     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
     89     return ConstantR0WithType(builder, shape.element_type(), value);
     90   });
     91 }
     92 
     93 // Returns an array or scalar containing copies of `value` cast to the same
     94 // run-type type as `prototype` and broadcast to the same dimensions as
     95 // `prototype`.
     96 //
     97 // If `prototype` is not a scalar or array, returns an error.
     98 template <typename T>
     99 XlaOp FullLike(XlaOp prototype, T value) {
    100   XlaBuilder* builder = prototype.builder();
    101   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    102     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
    103     if (ShapeUtil::IsScalar(shape) || shape.IsArray()) {
    104       return Broadcast(ScalarLike(prototype, value), shape.dimensions());
    105     } else {
    106       return InvalidArgument(
    107           "Prototype shape for BroadcastConstantLike must be a scalar or "
    108           "array, but was %s",
    109           shape.ToString());
    110     }
    111   });
    112 }
    113 
    114 // Returns a scalar with value '0' of 'type'.
    115 XlaOp Zero(XlaBuilder* builder, PrimitiveType type);
    116 
    117 // Returns a zero-filled tensor with shape 'shape'.
    118 XlaOp Zeros(XlaBuilder* builder, const Shape& shape);
    119 
    120 // Returns a zero-filled tensor with the same shape as 'prototype'.
    121 XlaOp ZerosLike(XlaOp prototype);
    122 
    123 // Returns a scalar with value '1' of 'type'.
    124 XlaOp One(XlaBuilder* builder, PrimitiveType type);
    125 
    126 // Returns the machine epsilon for floating-point type 'type', i.e.,
    127 // the difference between 1.0 and the next representable value.
    128 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type);
    129 
    130 // Returns the minimum representable finite or infinite value for 'type'.
    131 // Returns '-inf' for floating-point types.
    132 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type);
    133 
    134 // Returns the minimum representable finite value for 'type'. For a
    135 // floating-point type, this is equal to -MaxFiniteValue().
    136 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type);
    137 
    138 // Returns the minimum positive normal value for floating-point type 'type'.
    139 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type);
    140 
    141 // Returns the maximum representable finite or infinite value for 'type'.
    142 // Returns 'inf' for floating-point types.
    143 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type);
    144 
    145 // Returns the maximum representable finite value for 'type'.
    146 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);
    147 
    148 // Returns a NaN of 'type'. Only valid for real-valued floating-point types.
    149 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type);
    150 
    151 }  // namespace xla
    152 
    153 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
    154