1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/lib/util.h" 17 18 #include <memory> 19 #include <vector> 20 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/compiler/xla/shape_util.h" 23 #include "tensorflow/compiler/xla/status_macros.h" 24 #include "tensorflow/compiler/xla/statusor.h" 25 #include "tensorflow/compiler/xla/util.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 28 namespace tensorflow { 29 30 xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, 31 const xla::Shape& shape) { 32 return builder->Broadcast( 33 builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), 34 xla::AsInt64Slice(shape.dimensions())); 35 } 36 37 xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, 38 xla::PrimitiveType type, double value) { 39 switch (type) { 40 case xla::F16: 41 return builder->ConstantR0<xla::half>(static_cast<xla::half>(value)); 42 break; 43 case xla::BF16: 44 return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value)); 45 break; 46 case xla::F32: 47 return builder->ConstantR0<float>(static_cast<float>(value)); 48 break; 49 case xla::F64: 50 return builder->ConstantR0<double>(value); 51 break; 52 case xla::C64: 53 return builder->ConstantR0<xla::complex64>(value); 54 break; 55 default: 56 LOG(FATAL) << "unhandled element type " << type; 57 } 58 } 59 60 xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, 61 xla::PrimitiveType type, 62 int64 value) { 63 xla::Literal literal; 64 switch (type) { 65 case xla::U8: 66 literal = std::move(*xla::Literal::CreateR0<uint8>(value)); 67 break; 68 case xla::U32: 69 literal = std::move(*xla::Literal::CreateR0<uint32>(value)); 70 break; 71 case xla::U64: 72 literal = std::move(*xla::Literal::CreateR0<uint64>(value)); 73 break; 74 case xla::S8: 75 literal = std::move(*xla::Literal::CreateR0<int8>(value)); 76 break; 77 case xla::S32: 78 literal = std::move(*xla::Literal::CreateR0<int32>(value)); 79 break; 80 case xla::S64: 81 literal = std::move(*xla::Literal::CreateR0<int64>(value)); 82 break; 83 case xla::F32: 84 literal = std::move(*xla::Literal::CreateR0<float>(value)); 85 break; 86 case xla::F64: 87 literal = std::move(*xla::Literal::CreateR0<double>(value)); 88 break; 89 case xla::C64: 90 literal = std::move(*xla::Literal::CreateR0<complex64>(value)); 91 break; 92 case xla::PRED: 93 LOG(FATAL) << "pred element type is not integral"; 94 case xla::S16: 95 case xla::U16: 96 LOG(FATAL) << "u16/s16 literals not yet implemented"; 97 case xla::BF16: 98 literal = std::move( 99 *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value))); 100 break; 101 case xla::F16: 102 literal = std::move( 103 *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value))); 104 break; 105 case xla::TUPLE: 106 LOG(FATAL) << "tuple element type is not integral"; 107 case xla::OPAQUE: 108 LOG(FATAL) << "opaque element type is not integral"; 109 default: 110 LOG(FATAL) << "unhandled element type " << type; 111 } 112 return builder->ConstantLiteral(literal); 113 } 114 115 xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims( 116 xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, 117 gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) { 118 TF_RET_CHECK(start.size() == end.size()); 119 int64 n_minor_dims = start.size(); 120 121 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x)); 122 123 const int64 n_dims = xla::ShapeUtil::Rank(*shape); 124 TF_RET_CHECK(n_minor_dims <= n_dims); 125 gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()), 126 /*pos=*/0, 127 /*len=*/n_dims - n_minor_dims); 128 129 // Prepends 0s in the major dim 130 std::vector<int64> padded_start(n_dims, 0); 131 std::copy(start.begin(), start.end(), 132 padded_start.begin() + major_dims.size()); 133 134 // Prepends the shape of the major dims. 135 std::vector<int64> padded_end(n_dims); 136 std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); 137 std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); 138 139 std::vector<int64> strides(n_dims, 1); 140 return builder->Slice(x, padded_start, padded_end, strides); 141 } 142 143 xla::StatusOr<xla::ComputationDataHandle> UpdateSlice( 144 xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, 145 const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) { 146 // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 147 std::vector<int32> start_as_int32(start.begin(), start.end()); 148 return builder->DynamicUpdateSlice( 149 x, update, builder->ConstantR1<int32>(start_as_int32)); 150 } 151 152 xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims( 153 xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, 154 const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) { 155 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x)); 156 const int64 n_dims = xla::ShapeUtil::Rank(*shape); 157 const int64 n_minor_dims = start.size(); 158 TF_RET_CHECK(n_minor_dims <= n_dims); 159 std::vector<int64> padded_start(n_dims, 0); 160 std::copy(start.begin(), start.end(), 161 padded_start.begin() + (n_dims - n_minor_dims)); 162 return UpdateSlice(builder, x, update, padded_start); 163 } 164 165 xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims( 166 xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { 167 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x)); 168 const int64 n_dims = xla::ShapeUtil::Rank(*shape); 169 TF_RET_CHECK(n_dims >= 2); 170 std::vector<int64> permutation(n_dims); 171 std::iota(permutation.begin(), permutation.end(), 0); 172 std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); 173 return builder->Transpose(x, permutation); 174 } 175 176 } // namespace tensorflow 177