1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 17 18 #include <memory> 19 #include <set> 20 #include <string> 21 #include <utility> 22 23 #include "tensorflow/compiler/xla/layout_util.h" 24 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" 25 #include "tensorflow/compiler/xla/ptr_util.h" 26 #include "tensorflow/compiler/xla/service/platform_util.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 29 #include "tensorflow/compiler/xla/tests/test_utils.h" 30 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" 31 #include "tensorflow/compiler/xla/types.h" 32 #include "tensorflow/core/lib/core/status_test_util.h" 33 #include "tensorflow/core/lib/gtl/array_slice.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/platform/test.h" 36 #include "tensorflow/core/platform/types.h" 37 38 namespace se = ::perftools::gputools; 39 40 namespace xla { 41 42 namespace { 43 44 using tensorflow::StringPiece; 45 using tensorflow::gtl::ArraySlice; 46 using tensorflow::gtl::optional; 47 48 constexpr char kInterpreter[] = "interpreter"; 49 50 // Helper functions to get test and reference platforms. 51 se::Platform* GetReferencePlatform() { 52 auto result = PlatformUtil::GetPlatform(kInterpreter); 53 TF_CHECK_OK(result.status()) << "could not get interpreter platform"; 54 return result.ValueOrDie(); 55 } 56 57 se::Platform* GetTestPlatform() { 58 auto result = PlatformUtil::GetDefaultPlatform(); 59 TF_CHECK_OK(result.status()) << "could not get test platform"; 60 return result.ValueOrDie(); 61 } 62 63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { 64 if (lhs.parameters_size() != rhs.parameters_size()) { 65 return false; 66 } 67 for (int i = 0; i < lhs.parameters_size(); i++) { 68 if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) { 69 return false; 70 } 71 } 72 return ShapeUtil::Equal(lhs.result(), rhs.result()); 73 } 74 75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) { 76 ProgramShape program_shape; 77 const auto* entry = module.entry_computation(); 78 for (const auto* param : entry->parameter_instructions()) { 79 *program_shape.add_parameters() = param->shape(); 80 *program_shape.add_parameter_names() = param->name(); 81 } 82 *program_shape.mutable_result() = entry->root_instruction()->shape(); 83 return program_shape; 84 } 85 86 } // namespace 87 88 HloTestBase::HloTestBase() 89 : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} 90 91 HloTestBase::HloTestBase(se::Platform* test_platform, 92 se::Platform* reference_platform) 93 : test_runner_(test_platform), reference_runner_(reference_platform) { 94 hlo_verifier_ = MakeUnique<HloVerifier>(); 95 } 96 97 /* static */ 98 std::unique_ptr<HloModule> HloTestBase::CreateNewModule() { 99 HloModuleConfig config; 100 config.set_debug_options(GetDebugOptionsForTest()); 101 return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(), 102 config); 103 } 104 105 /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { 106 auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); 107 // TODO(b/38354253): Change tests to use Parameters instead of Constants. 108 debug_options.add_xla_disable_hlo_passes("constant_folding"); 109 return debug_options; 110 } 111 112 StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute( 113 std::unique_ptr<HloModule> module, 114 tensorflow::gtl::ArraySlice<Literal*> arguments) { 115 return test_runner_.Execute(std::move(module), arguments); 116 } 117 118 std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer( 119 std::unique_ptr<HloModule> module, 120 tensorflow::gtl::ArraySlice<Literal*> arguments) { 121 return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); 122 } 123 124 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule( 125 const HloModule& test_module, 126 const std::function<void(HloModule*)>& reference_preprocessor) { 127 std::unique_ptr<HloModule> reference_module = test_module.Clone(); 128 const auto& program_shape = GetProgramShapeWithLayout(test_module); 129 130 if (reference_preprocessor != nullptr) { 131 reference_preprocessor(reference_module.get()); 132 if (!ProgramShapesEqual(program_shape, 133 GetProgramShapeWithLayout(*reference_module))) { 134 return InvalidArgument( 135 "reference preprocessor must not modify the program shape"); 136 } 137 } 138 TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), 139 reference_module.get())); 140 return std::move(reference_module); 141 } 142 143 template <typename LiteralPtr> 144 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( 145 std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments, 146 const optional<ErrorSpec>& error, bool run_hlo_passes, 147 const std::function<void(HloModule*)>& reference_preprocessor) { 148 static_assert( 149 std::is_same<Literal*, LiteralPtr>::value || 150 std::is_same<std::unique_ptr<Literal>, LiteralPtr>::value, 151 "The LiteralPtr type only accepts Literal* or std::unique_ptr<Literal>."); 152 TF_RETURN_IF_ERROR( 153 VerifyHloModule(*test_runner_.backend().platform(), module.get())); 154 TF_ASSIGN_OR_RETURN(auto reference_module, 155 MakeReferenceModule(*module, reference_preprocessor)); 156 157 // Execute on two backends. 158 TF_ASSIGN_OR_RETURN( 159 auto test, 160 test_runner_.Execute(std::move(module), arguments, run_hlo_passes)); 161 TF_ASSIGN_OR_RETURN(auto reference, 162 reference_runner_.Execute(std::move(reference_module), 163 arguments, run_hlo_passes)); 164 return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, 165 error); 166 } 167 168 template <typename LiteralPtr> 169 ::testing::AssertionResult HloTestBase::RunAndCompare( 170 std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments, 171 const optional<ErrorSpec>& error, 172 const std::function<void(HloModule*)>& reference_preprocessor) { 173 auto result = 174 RunAndCompareInternal(std::move(module), arguments, error, 175 /*run_hlo_passes=*/true, reference_preprocessor); 176 if (!result.ok()) { 177 return ::testing::AssertionFailure() << result.status(); 178 } 179 return result.ValueOrDie(); 180 } 181 182 template <typename LiteralPtr> 183 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( 184 std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments, 185 const optional<ErrorSpec>& error, 186 const std::function<void(HloModule*)>& reference_preprocessor) { 187 auto result = 188 RunAndCompareInternal(std::move(module), arguments, error, 189 /*run_hlo_passes=*/false, reference_preprocessor); 190 if (!result.ok()) { 191 return ::testing::AssertionFailure() << result.status(); 192 } 193 return result.ValueOrDie(); 194 } 195 196 ::testing::AssertionResult HloTestBase::RunAndCompare( 197 std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error, 198 const std::function<void(HloModule*)>& reference_preprocessor) { 199 const auto& fake_arguments = 200 MakeFakeArguments(module.get()).ConsumeValueOrDie(); 201 return RunAndCompare<std::unique_ptr<Literal>>( 202 std::move(module), fake_arguments, error, reference_preprocessor); 203 } 204 205 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( 206 std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error, 207 const std::function<void(HloModule*)>& reference_preprocessor) { 208 const auto& fake_arguments = 209 MakeFakeArguments(module.get()).ConsumeValueOrDie(); 210 return RunAndCompareNoHloPasses<std::unique_ptr<Literal>>( 211 std::move(module), fake_arguments, error, reference_preprocessor); 212 } 213 214 ::testing::AssertionResult HloTestBase::RunAndCompare( 215 const StringPiece hlo_string, 216 const tensorflow::gtl::optional<ErrorSpec>& error, 217 const std::function<void(HloModule*)>& reference_preprocessor) { 218 auto module_or_status = 219 HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); 220 if (!module_or_status.ok()) { 221 return ::testing::AssertionFailure() 222 << "Error while parsing HLO text format: " 223 << module_or_status.status().ToString(); 224 } 225 return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, 226 reference_preprocessor); 227 } 228 229 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( 230 const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error, 231 const std::function<void(HloModule*)>& reference_preprocessor) { 232 auto module_or_status = 233 HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); 234 if (!module_or_status.ok()) { 235 return ::testing::AssertionFailure() 236 << "failed reading hlo module from file"; 237 } 238 return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, 239 reference_preprocessor); 240 } 241 242 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( 243 const StringPiece hlo_string, 244 const tensorflow::gtl::optional<ErrorSpec>& error, 245 const std::function<void(HloModule*)>& reference_preprocessor) { 246 auto module_or_status = 247 HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); 248 if (!module_or_status.ok()) { 249 return ::testing::AssertionFailure() 250 << "Error while parsing HLO text format: " 251 << module_or_status.status().ToString(); 252 } 253 return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, 254 reference_preprocessor); 255 } 256 257 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( 258 const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error, 259 const std::function<void(HloModule*)>& reference_preprocessor) { 260 auto module_or_status = 261 HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); 262 if (!module_or_status.ok()) { 263 return ::testing::AssertionFailure() 264 << "failed reading hlo module from file"; 265 } 266 return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, 267 reference_preprocessor); 268 } 269 270 Backend& HloTestBase::backend() { return test_runner_.backend(); } 271 272 /* static */ 273 string HloTestBase::TestName() { 274 return ::testing::UnitTest::GetInstance()->current_test_info()->name(); 275 } 276 277 } // namespace xla 278