1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/tests/test_utils.h" 17 #include "tensorflow/compiler/xla/primitive_util.h" 18 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 19 #include "tensorflow/compiler/xla/service/hlo_verifier.h" 20 #include "tensorflow/compiler/xla/service/transfer_manager.h" 21 22 namespace xla { 23 24 namespace { 25 26 template <typename FloatT> 27 void PopulateWithRandomFloatingPointData(Literal* literal, 28 std::minstd_rand0* engine) { 29 CHECK_EQ(literal->shape().element_type(), 30 primitive_util::NativeToPrimitiveType<FloatT>()); 31 // Create uniform numbers between 1 and 1.125 to avoid creating denormal 32 // numbers. 33 std::uniform_real_distribution<FloatT> generator(1.0f, 1.125f); 34 const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; 35 TF_CHECK_OK(literal->Populate<FloatT>( 36 [&](tensorflow::gtl::ArraySlice<int64> indices) { 37 // Generate a random uniform number from -0.0625 and 0.0625 and bias it 38 // with a position dependent number with mean 0.037109375. These number 39 // should allow for long chains of accumulation without being too close 40 // to zero or too large to accumulate all numbers accurately. Only do 41 // this for large literals where the number of elements is much greater 42 // than 47 otherwise only negative values are produced. 43 // 44 // The value is positionally biased using a product of the indices. Add 45 // one to each index value to avoid collapsing to zero if any of the 46 // indices are zero. 47 int64 index_product = 1; 48 for (int64 i : indices) { 49 index_product *= (1 + i); 50 } 51 const int64 negative_bias = should_index_bias ? 47 : 0; 52 FloatT index_bias = 53 static_cast<FloatT>(index_product % 113 - negative_bias) / 54 static_cast<FloatT>(256.0f); 55 return (generator(*engine) - 1.0625) + index_bias; 56 })); 57 } 58 59 // The standard library does not have a case for bfloat16, unsurprisingly, so we 60 // handle that one specially. 61 template <> 62 void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal, 63 std::minstd_rand0* engine) { 64 CHECK_EQ(literal->shape().element_type(), BF16); 65 std::uniform_real_distribution<float> generator(-0.9f, 1.0f); 66 TF_CHECK_OK(literal->Populate<bfloat16>( 67 [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { 68 return static_cast<bfloat16>(generator(*engine)); 69 })); 70 } 71 72 template <typename IntT> 73 void PopulateWithRandomIntegralData(Literal* literal, 74 std::minstd_rand0* engine) { 75 CHECK_EQ(literal->shape().element_type(), 76 primitive_util::NativeToPrimitiveType<IntT>()); 77 std::uniform_int_distribution<IntT> generator( 78 std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max()); 79 TF_CHECK_OK(literal->Populate<IntT>( 80 [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { 81 return generator(*engine); 82 })); 83 } 84 85 // Similar to MakeFakeLiteral but takes a random number generator engine to 86 // enable reusing the engine across randomly generated literals. 87 StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( 88 const Shape& shape, std::minstd_rand0* engine) { 89 if (ShapeUtil::IsTuple(shape)) { 90 std::vector<std::unique_ptr<Literal>> elements; 91 for (const Shape& element_shape : shape.tuple_shapes()) { 92 TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element, 93 MakeFakeLiteralInternal(element_shape, engine)); 94 elements.push_back(std::move(element)); 95 } 96 return Literal::MakeTupleOwned(std::move(elements)); 97 } 98 std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); 99 switch (shape.element_type()) { 100 case BF16: 101 PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine); 102 break; 103 case F32: 104 PopulateWithRandomFloatingPointData<float>(literal.get(), engine); 105 break; 106 case F64: 107 PopulateWithRandomFloatingPointData<double>(literal.get(), engine); 108 break; 109 case S8: 110 PopulateWithRandomIntegralData<int8>(literal.get(), engine); 111 break; 112 case U8: 113 PopulateWithRandomIntegralData<uint8>(literal.get(), engine); 114 break; 115 case S16: 116 PopulateWithRandomIntegralData<int16>(literal.get(), engine); 117 break; 118 case U16: 119 PopulateWithRandomIntegralData<uint16>(literal.get(), engine); 120 break; 121 case S32: 122 PopulateWithRandomIntegralData<int32>(literal.get(), engine); 123 break; 124 case U32: 125 PopulateWithRandomIntegralData<uint32>(literal.get(), engine); 126 break; 127 case S64: 128 PopulateWithRandomIntegralData<int64>(literal.get(), engine); 129 break; 130 case U64: 131 PopulateWithRandomIntegralData<uint64>(literal.get(), engine); 132 break; 133 case PRED: { 134 std::uniform_int_distribution<int> generator(0, 1); 135 TF_CHECK_OK(literal->Populate<bool>( 136 [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { 137 return generator(*engine); 138 })); 139 break; 140 } 141 default: 142 return Unimplemented("Unsupported type for fake literal generation: %s", 143 ShapeUtil::HumanString(shape).c_str()); 144 } 145 return std::move(literal); 146 } 147 148 // Matches binary addition computations. 149 bool LooksLikeSum(const HloComputation& computation) { 150 const HloInstruction* const root = computation.root_instruction(); 151 return root->opcode() == HloOpcode::kAdd && 152 computation.num_parameters() == 2 && 153 root->operand(0)->opcode() == HloOpcode::kParameter && 154 root->operand(1)->opcode() == HloOpcode::kParameter && 155 root->operand(0) != root->operand(1); 156 } 157 158 // Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, 159 // which requires an init_value of 0 rather than a random value. 160 bool NeedsZeroInitValue(const HloUse& use) { 161 const HloInstruction* const instruction = use.instruction; 162 const HloOpcode opcode = instruction->opcode(); 163 const int64 op_num = use.operand_number; 164 return ( 165 ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && 166 op_num == 1 && LooksLikeSum(*instruction->to_apply())) || 167 (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && 168 LooksLikeSum(*instruction->scatter()))); 169 } 170 171 // Generate random values that are constrained to the input_shape minus the 172 // output_shape so as not to produce wrapping slices, for instance. 173 std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex( 174 const Shape& input_shape, const Shape& slice_shape, 175 std::minstd_rand0* engine) { 176 const int64 rank = ShapeUtil::Rank(input_shape); 177 std::vector<int32> start_indices(rank); 178 for (int i = 0; i < rank; ++i) { 179 const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - 180 ShapeUtil::GetDimension(slice_shape, i); 181 std::uniform_int_distribution<int32> generator(0, upper_bound); 182 start_indices[i] = generator(*engine); 183 } 184 return Literal::CreateR1<int32>(start_indices); 185 } 186 187 // Use dataflow analysis on each parameter to see if there are uses that would 188 // be problematic when generating input data. Returns the list of instructions 189 // that correspond to their uses. 190 // 191 // Should be paired with the CreateLiteralForConstrainedUses() function below. 192 std::vector<HloInstruction*> FindConstrainedUses( 193 const HloDataflowAnalysis& dataflow, const HloInstruction& param) { 194 std::vector<HloInstruction*> constrained_uses; 195 for (const auto& pair : dataflow.GetInstructionValueSet(¶m)) { 196 const HloValue& value = dataflow.GetUniqueValueAt(¶m, pair.first); 197 for (const HloUse& use : value.uses()) { 198 HloInstruction* instruction = use.instruction; 199 const HloOpcode opcode = instruction->opcode(); 200 const int64 op_num = use.operand_number; 201 if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || 202 (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { 203 constrained_uses.push_back(instruction); 204 } else if (opcode == HloOpcode::kFusion) { 205 const HloInstruction* const to_analyze = 206 instruction->fused_parameter(op_num); 207 auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); 208 constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), 209 fused_uses.end()); 210 } else if (NeedsZeroInitValue(use)) { 211 constrained_uses.push_back(instruction); 212 } else if (opcode == HloOpcode::kConvert || 213 opcode == HloOpcode::kReducePrecision) { 214 auto converted_uses = FindConstrainedUses(dataflow, *instruction); 215 constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), 216 converted_uses.end()); 217 } 218 } 219 } 220 return constrained_uses; 221 } 222 223 // Given a parameter, generate a random Literal to use as input if there exist 224 // no constrained uses in the dataflow graph. If such constraints exist, 225 // generate a constrained literal (either bounded in the case of indices, or 226 // zero in the case of init_values for reductions). 227 StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( 228 const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses, 229 const HloInstruction& param, std::minstd_rand0* engine) { 230 HloInstruction* needs_index = nullptr; 231 HloInstruction* needs_zero = nullptr; 232 for (HloInstruction* use : constrained_uses) { 233 switch (use->opcode()) { 234 case HloOpcode::kDynamicSlice: 235 case HloOpcode::kDynamicUpdateSlice: 236 if (needs_index != nullptr && 237 !ShapeUtil::Equal(needs_index->shape(), use->shape())) { 238 return Unimplemented( 239 "Conflicting operand generation slice index constraints\n"); 240 } 241 needs_index = use; 242 break; 243 244 case HloOpcode::kReduce: 245 case HloOpcode::kReduceWindow: 246 case HloOpcode::kSelectAndScatter: 247 needs_zero = use; 248 break; 249 250 default: 251 return Unimplemented( 252 "Constrained operand generation not implemented for %s.", 253 use->ToString().c_str()); 254 } 255 } 256 if (needs_index != nullptr && needs_zero != nullptr) { 257 return Unimplemented( 258 "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " 259 "zero: %s\n", 260 needs_index->ToString().c_str(), needs_zero->ToString().c_str()); 261 } 262 if (needs_index != nullptr) { 263 return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), 264 needs_index->shape(), engine); 265 } else if (needs_zero != nullptr) { 266 return Literal::CreateFromShape(param.shape()); 267 } else { 268 return MakeFakeLiteralInternal(param.shape(), engine); 269 } 270 } 271 272 // Given a module entry parameter, use the dataflow analysis to see if a 273 // special case literal must be created, or if we can generate fake data. 274 StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument( 275 const HloDataflowAnalysis& dataflow, const HloInstruction& param, 276 std::minstd_rand0* engine) { 277 const auto constrained_uses = FindConstrainedUses(dataflow, param); 278 return CreateLiteralForConstrainedUses(constrained_uses, param, engine); 279 } 280 281 } // namespace 282 283 StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { 284 std::minstd_rand0 engine; 285 return MakeFakeLiteralInternal(shape, &engine); 286 } 287 288 StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( 289 HloModule* const module) { 290 TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); 291 const auto params = module->entry_computation()->parameter_instructions(); 292 std::minstd_rand0 engine; 293 std::vector<std::unique_ptr<Literal>> arguments(params.size()); 294 for (int i = 0; i < params.size(); ++i) { 295 TF_ASSIGN_OR_RETURN( 296 arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); 297 } 298 return std::move(arguments); 299 } 300 301 Status VerifyHloModule(const perftools::gputools::Platform& platform, 302 HloModule* const module) { 303 return HloVerifier().Run(module).status(); 304 } 305 306 } // namespace xla 307