1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 17 18 #include "llvm/IR/Constants.h" 19 #include "llvm/IR/Instructions.h" 20 #include "tensorflow/compiler/xla/layout_util.h" 21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 22 #include "tensorflow/compiler/xla/shape_util.h" 23 #include "tensorflow/compiler/xla/statusor.h" 24 #include "tensorflow/compiler/xla/util.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace xla { 30 namespace llvm_ir { 31 32 IrArray::Index::Index(llvm::Value* linear, const Shape& shape, 33 llvm::IRBuilder<>* ir_builder) 34 : multidim_(ShapeUtil::Rank(shape)), 35 linear_(linear), 36 layout_(shape.layout()), 37 dims_(shape.dimensions().begin(), shape.dimensions().end()) { 38 CHECK(LayoutUtil::HasLayout(shape)) 39 << "Shape " << ShapeUtil::HumanStringWithLayout(shape) 40 << " should have a layout."; 41 int64 divisor = 1; 42 for (int64 i = 0; i < layout_.minor_to_major_size(); ++i) { 43 int64 dimension = layout_.minor_to_major(i); 44 int64 size_of_current_dimension = shape.dimensions(dimension); 45 46 // If i is not the last dimension, compute 47 // (linear_index / divisor) % current_dimension. 48 // If i is the last dimension, we can skip the mod, because we assume that 49 // linear is in bounds. 50 // 51 // TODO(jlebar): We could add bounds checks here and elsewhere in this file, 52 // guarded under some sort of xla-memcheck flag. This might be particularly 53 // useful because cuda-memcheck can't help us much in XLA: Most of our 54 // memory lives in one big allocation, so cuda-memcheck can't detect 55 // out-of-bounds accesses. 56 auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); 57 if (i < layout_.minor_to_major_size() - 1) { 58 multidim_[dimension] = ir_builder->CreateURem( 59 quot, ir_builder->getInt64(size_of_current_dimension)); 60 } else { 61 multidim_[dimension] = quot; 62 } 63 divisor *= size_of_current_dimension; 64 } 65 } 66 67 IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim, 68 llvm::Value* linear, const Shape& shape) 69 : multidim_(multidim.begin(), multidim.end()), 70 linear_(linear), 71 layout_(shape.layout()), 72 dims_(shape.dimensions().begin(), shape.dimensions().end()) { 73 CHECK_EQ(shape.dimensions_size(), multidim.size()); 74 CHECK(LayoutUtil::HasLayout(shape)) 75 << "Shape " << ShapeUtil::HumanStringWithLayout(shape) 76 << " should have a layout."; 77 } 78 79 IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim, 80 const Shape& shape, llvm::IRBuilder<>* ir_builder) 81 : multidim_(multidim.begin(), multidim.end()), 82 layout_(shape.layout()), 83 dims_(shape.dimensions().begin(), shape.dimensions().end()) { 84 CHECK_EQ(shape.dimensions_size(), multidim.size()); 85 CHECK(LayoutUtil::HasLayout(shape)); 86 linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder); 87 } 88 89 IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) 90 : base_ptr_(base_ptr), shape_(&shape) { 91 TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); 92 CHECK(base_ptr_->getType()->isPointerTy()); 93 int depth = 0; 94 element_type_ = 95 llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType(); 96 while (llvm::ArrayType* array_type = 97 llvm::dyn_cast<llvm::ArrayType>(element_type_)) { 98 element_type_ = array_type->getElementType(); 99 ++depth; 100 } 101 102 if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { 103 DCHECK(depth == 1 || depth == 0) << depth; 104 } else { 105 DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); 106 } 107 } 108 109 // Returns whether given linear index valid on given shape. 110 bool IrArray::Index::LinearValidOnShape(const Shape& a) const { 111 auto b = ShapeUtil::MakeShape(PRED /* irrelevant */, dims_); 112 *b.mutable_layout() = layout_; 113 return linear_ != nullptr && 114 ContainersEqual( 115 ShapeUtil::StripDegenerateDimensions(a).dimensions(), 116 ShapeUtil::StripDegenerateDimensions(b).dimensions()) && 117 LayoutUtil::Equal(ShapeUtil::StripDegenerateDimensions(a).layout(), 118 ShapeUtil::StripDegenerateDimensions(b).layout()); 119 } 120 121 IrArray::Index IrArray::Index::SourceIndexOfReshape( 122 const Shape& output_shape, const Shape& input_shape, 123 llvm::IRBuilder<>* builder) const { 124 const auto& target_index = *this; 125 CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape)); 126 std::vector<std::pair<int64, int64>> common_factors = 127 CommonFactors(AsInt64Slice(input_shape.dimensions()), 128 AsInt64Slice(output_shape.dimensions())); 129 std::vector<llvm::Value*> source_multidim_index( 130 ShapeUtil::Rank(input_shape), 131 llvm::UndefValue::get(builder->getInt64Ty())); 132 // We compute the source indices in each common factor from only the target 133 // indices in the same common factor. 134 for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { 135 llvm::Value* logical_linear_index = 136 Index(tensorflow::gtl::ArraySlice<llvm::Value*>( 137 multidim_, common_factors[k].second, 138 common_factors[k + 1].second - common_factors[k].second)) 139 .Linearize( 140 tensorflow::gtl::ArraySlice<int64>( 141 AsInt64Slice(output_shape.dimensions()), 142 common_factors[k].second, 143 common_factors[k + 1].second - common_factors[k].second), 144 builder); 145 // Delinearizes logical_linear_index for the source array in row-major 146 // collapsed order. The first rank-1 indices are the remainder of the 147 // linear index by each dimension size. 148 for (int64 i = common_factors[k + 1].first - 1; 149 i >= common_factors[k].first; --i) { 150 llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i)); 151 if (input_shape.dimensions(i) == 1) { 152 source_multidim_index[i] = builder->getInt64(0); 153 } else if (i == common_factors[k].first) { 154 source_multidim_index[i] = logical_linear_index; 155 } else { 156 source_multidim_index[i] = 157 builder->CreateURem(logical_linear_index, divisor); 158 } 159 logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor); 160 } 161 } 162 163 if (linear() != nullptr && 164 ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { 165 return Index(source_multidim_index, linear(), input_shape); 166 } 167 return Index(source_multidim_index); 168 } 169 170 IrArray::Index IrArray::Index::SourceIndexOfSlice( 171 const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts, 172 tensorflow::gtl::ArraySlice<int64> strides, 173 llvm::IRBuilder<>* builder) const { 174 Index source_index(multidim_.size()); 175 for (int i = 0; i < multidim_.size(); ++i) { 176 int64 stride = strides[i]; 177 auto type = multidim_[i]->getType(); 178 179 if (stride != 1) { 180 source_index[i] = builder->CreateAdd( 181 builder->CreateMul(multidim_[i], 182 llvm::ConstantInt::get(type, stride)), 183 llvm::ConstantInt::get(type, starts[i])); 184 } else { 185 source_index[i] = builder->CreateAdd( 186 multidim_[i], llvm::ConstantInt::get(type, starts[i])); 187 } 188 } 189 return source_index; 190 } 191 192 IrArray::Index IrArray::Index::SourceIndexOfTranspose( 193 const Shape& shape, const Shape& operand_shape, 194 tensorflow::gtl::ArraySlice<int64> dimension_mapping, 195 llvm::IRBuilder<>* builder) const { 196 std::vector<llvm::Value*> operand_multidim_index = 197 Permute(dimension_mapping, multidim()); 198 if (linear() != nullptr && 199 ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { 200 return Index(operand_multidim_index, linear(), operand_shape); 201 } 202 return Index(operand_multidim_index); 203 } 204 205 llvm::Value* IrArray::Index::Linearize( 206 tensorflow::gtl::ArraySlice<int64> dimensions, 207 llvm::IRBuilder<>* builder) const { 208 // Each dimension is multiplied by the product of the sizes of all 209 // earlier dimensions and added to the accumulator logical_linear_index. 210 llvm::Value* logical_linear_index = builder->getInt64(0); 211 int64 multiplier = 1; 212 for (ssize_t i = size() - 1; i >= 0; --i) { 213 llvm::Value* addend = 214 builder->CreateMul((*this)[i], builder->getInt64(multiplier), "", 215 /*HasNUW=*/true, /*HasNSW=*/true); 216 logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", 217 /*HasNUW=*/true, /*HasNSW=*/true); 218 multiplier *= dimensions[i]; 219 } 220 return logical_linear_index; 221 } 222 223 llvm::Value* IrArray::EmitArrayElementAddress( 224 const IrArray::Index& index, llvm::IRBuilder<>* ir_builder, 225 tensorflow::StringPiece name) const { 226 if (ShapeUtil::IsScalar(*shape_)) { 227 // Special handling of scalars: a scalar pretends to have the same value for 228 // every index, thus effectively implementing broadcasting of its value 229 // over higher-rank arrays. 230 return base_ptr_; 231 } 232 CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); 233 234 std::vector<llvm::Value*> actual_index; 235 bool is_implicit_broadcast = false; 236 // We perform broadcasting when the operand shape has dimension(s) of size 237 // 1. In this case we fix the index value for that dimension to zero. This 238 // effectively broadcasts along this dimension. 239 for (int64 i = 0; i < index.size(); ++i) { 240 auto dim = shape_->dimensions(i); 241 actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); 242 is_implicit_broadcast |= dim == 1; 243 } 244 245 if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { 246 llvm::Module* module = 247 ir_builder->GetInsertBlock()->getParent()->getParent(); 248 return ir_builder->CreateInBoundsGEP( 249 ir_builder->CreateBitCast( 250 base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module) 251 ->getPointerTo()), 252 {index.linear()}, llvm_ir::AsStringRef(name)); 253 } 254 255 // "base_ptr_" has the type of "<ir_type_for_its_shape>*" 256 // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element 257 // should be computed by 258 // 259 // getelementptr base_ptr_, 0, most major index, ..., most minor index 260 std::vector<llvm::Value*> gep_indices(1, ir_builder->getInt64(0)); 261 for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { 262 int64 dimension = LayoutUtil::Major(shape_->layout(), i); 263 gep_indices.push_back(actual_index[dimension]); 264 } 265 return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices, 266 llvm_ir::AsStringRef(name)); 267 } 268 269 void IrArray::AnnotateLoadStoreInstructionWithMetadata( 270 llvm::Instruction* instruction) const { 271 CHECK(llvm::isa<llvm::LoadInst>(instruction) || 272 llvm::isa<llvm::StoreInst>(instruction)); 273 CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_) 274 << "Trying to create a store to an invariant IRArray."; 275 276 for (const auto& kind_md_pair : metadata_) { 277 instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); 278 } 279 } 280 281 llvm::Value* IrArray::EmitReadArrayElement(const Index& index, 282 llvm::IRBuilder<>* ir_builder, 283 tensorflow::StringPiece name) const { 284 llvm::Value* element_address = 285 EmitArrayElementAddress(index, ir_builder, name); 286 llvm::LoadInst* load = ir_builder->CreateLoad(element_address); 287 AnnotateLoadStoreInstructionWithMetadata(load); 288 return load; 289 } 290 291 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, 292 llvm::IRBuilder<>* ir_builder) const { 293 llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder); 294 llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); 295 AnnotateLoadStoreInstructionWithMetadata(store); 296 } 297 298 IrArray IrArray::CastToShape(const Shape& new_shape, 299 llvm::IRBuilder<>* ir_builder) const { 300 llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); 301 llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module); 302 return IrArray( 303 ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), 304 new_shape); 305 } 306 307 /* static */ IrArray::Index IrArray::BumpIndex(const Index& index, 308 int64 which_dimension, 309 int64 addend, 310 llvm::IRBuilder<>* ir_builder) { 311 Index new_index = index; 312 new_index[which_dimension] = ir_builder->CreateAdd( 313 index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true, 314 /*HasNSW=*/true); 315 return new_index; 316 } 317 318 } // namespace llvm_ir 319 } // namespace xla 320