1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ 17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ 18 19 #include <type_traits> 20 21 #include "tensorflow/compiler/xla/client/xla_builder.h" 22 #include "tensorflow/compiler/xla/primitive_util.h" 23 #include "tensorflow/compiler/xla/types.h" 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 26 namespace xla { 27 28 // Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is 29 // determined at C++ run-time, rather than C++ compile-time. 30 // If 'value' is floating point but 'type' is not, or if 'value' is complex but 31 // 'type' is not, an error will be returned. This is to catch accidental 32 // truncation; in such cases, use an explicit cast. 
33 template <typename T> 34 XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { 35 if (std::is_floating_point<T>::value && 36 !(primitive_util::IsFloatingPointType(type) || 37 primitive_util::IsComplexType(type))) { 38 return builder->ReportError(InvalidArgument( 39 "Invalid cast from floating point type to %s in ConstantR0WithType.", 40 PrimitiveType_Name(type))); 41 } 42 if (std::is_same<T, complex64>::value && 43 !primitive_util::IsComplexType(type)) { 44 return builder->ReportError(InvalidArgument( 45 "Invalid cast from complex type to %s in ConstantR0WithType.", 46 PrimitiveType_Name(type))); 47 } 48 switch (type) { 49 case F16: 50 return ConstantR0<half>(builder, static_cast<half>(value)); 51 case BF16: 52 return ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value)); 53 case F32: 54 return ConstantR0<float>(builder, static_cast<float>(value)); 55 case F64: 56 return ConstantR0<double>(builder, static_cast<double>(value)); 57 case C64: 58 return ConstantR0<complex64>(builder, static_cast<complex64>(value)); 59 case C128: 60 return ConstantR0<complex128>(builder, static_cast<complex128>(value)); 61 case U8: 62 return ConstantR0<uint8>(builder, static_cast<uint8>(value)); 63 case U32: 64 return ConstantR0<uint32>(builder, static_cast<uint32>(value)); 65 case U64: 66 return ConstantR0<uint64>(builder, static_cast<uint64>(value)); 67 case S8: 68 return ConstantR0<int8>(builder, static_cast<int8>(value)); 69 case S32: 70 return ConstantR0<int32>(builder, static_cast<int32>(value)); 71 case S64: 72 return ConstantR0<int64>(builder, static_cast<int64>(value)); 73 default: 74 return builder->ReportError( 75 InvalidArgument("Invalid type for ConstantR0WithType (%s).", 76 PrimitiveType_Name(type))); 77 } 78 } 79 80 // Returns a scalar containing 'value' cast to the same run-time type as 81 // 'prototype'. 82 // If 'value' is floating point but 'prototype' is not, or if 'value' is complex 83 // 'prototype' is not, an error will be returned. 
84 template <typename T> 85 XlaOp ScalarLike(XlaOp prototype, T value) { 86 XlaBuilder* builder = prototype.builder(); 87 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { 88 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); 89 return ConstantR0WithType(builder, shape.element_type(), value); 90 }); 91 } 92 93 // Returns an array or scalar containing copies of `value` cast to the same 94 // run-type type as `prototype` and broadcast to the same dimensions as 95 // `prototype`. 96 // 97 // If `prototype` is not a scalar or array, returns an error. 98 template <typename T> 99 XlaOp FullLike(XlaOp prototype, T value) { 100 XlaBuilder* builder = prototype.builder(); 101 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { 102 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); 103 if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { 104 return Broadcast(ScalarLike(prototype, value), shape.dimensions()); 105 } else { 106 return InvalidArgument( 107 "Prototype shape for BroadcastConstantLike must be a scalar or " 108 "array, but was %s", 109 shape.ToString()); 110 } 111 }); 112 } 113 114 // Returns a scalar with value '0' of 'type'. 115 XlaOp Zero(XlaBuilder* builder, PrimitiveType type); 116 117 // Returns a zero-filled tensor with shape `shape`. 118 XlaOp Zeros(XlaBuilder* builder, const Shape& shape); 119 120 // Returns a zero-filled tensor with the same shape as `prototype`. 121 XlaOp ZerosLike(XlaOp prototype); 122 123 // Returns a scalar with value '1' of 'type'. 124 XlaOp One(XlaBuilder* builder, PrimitiveType type); 125 126 // Returns the machine epsilon for floating-point type `type`, i.e., 127 // the difference between 1.0 and the next representable value. 128 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); 129 130 // Returns the minimum representable finite or infinite value for 'type'. 131 // Returns '-inf' for floating-point types. 
XlaOp MinValue(XlaBuilder* builder, PrimitiveType type);

// Returns the minimum representable finite value for 'type'. For a
// floating-point type, this is equal to -MaxFiniteValue().
XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type);

// Returns the minimum positive normal value for floating-point type `type`.
XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type);

// Returns the maximum representable finite or infinite value for 'type'.
// Returns 'inf' for floating-point types.
XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type);

// Returns the maximum representable finite value for 'type'.
XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);

// Returns a NaN scalar of 'type'. Only valid for real-valued floating-point
// types.
XlaOp NanValue(XlaBuilder* builder, PrimitiveType type);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_