Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     17 
     18 #include <memory>
     19 #include <set>
     20 #include <string>
     21 #include <utility>
     22 
     23 #include "absl/algorithm/container.h"
     24 #include "absl/memory/memory.h"
     25 #include "absl/types/span.h"
     26 #include "tensorflow/compiler/xla/debug_options_flags.h"
     27 #include "tensorflow/compiler/xla/layout_util.h"
     28 #include "tensorflow/compiler/xla/service/hlo_module.h"
     29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
     30 #include "tensorflow/compiler/xla/service/platform_util.h"
     31 #include "tensorflow/compiler/xla/shape_util.h"
     32 #include "tensorflow/compiler/xla/statusor.h"
     33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     34 #include "tensorflow/compiler/xla/tests/test_utils.h"
     35 #include "tensorflow/compiler/xla/types.h"
     36 #include "tensorflow/core/lib/core/status_test_util.h"
     37 #include "tensorflow/core/platform/logging.h"
     38 #include "tensorflow/core/platform/test.h"
     39 #include "tensorflow/core/platform/types.h"
     40 
     41 namespace xla {
     42 
     43 namespace {
     44 
     45 using absl::optional;
     46 using absl::string_view;
     47 
     48 constexpr char kInterpreter[] = "interpreter";
     49 
     50 // Helper functions to get test and reference platforms.
     51 se::Platform* GetReferencePlatform() {
     52   auto result = PlatformUtil::GetPlatform(kInterpreter);
     53   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
     54   return result.ValueOrDie();
     55 }
     56 
     57 se::Platform* GetTestPlatform() {
     58   auto result = PlatformUtil::GetDefaultPlatform();
     59   TF_CHECK_OK(result.status()) << "could not get test platform";
     60   return result.ValueOrDie();
     61 }
     62 
     63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
     64   if (lhs.parameters_size() != rhs.parameters_size()) {
     65     return false;
     66   }
     67   for (int i = 0; i < lhs.parameters_size(); i++) {
     68     if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
     69       return false;
     70     }
     71   }
     72   return ShapeUtil::Equal(lhs.result(), rhs.result());
     73 }
     74 
     75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
     76   ProgramShape program_shape;
     77   const auto* entry = module.entry_computation();
     78   for (const auto* param : entry->parameter_instructions()) {
     79     *program_shape.add_parameters() = param->shape();
     80     *program_shape.add_parameter_names() = param->name();
     81   }
     82   *program_shape.mutable_result() = entry->root_instruction()->shape();
     83   return program_shape;
     84 }
     85 
     86 }  // namespace
     87 
     88 Status VerifiedHloModule::Verify() {
     89   if (computation_count() == 0) {
     90     // The computation was never built. Nothing to verify.
     91     return Status::OK();
     92   }
     93   return verifier_.Run(this).status();
     94 }
     95 
     96 void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
     97   Status status = Verify();
     98   if (!status.ok()) {
     99     ADD_FAILURE() << "HloVerifier failed on module " << name()
    100                   << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
    101                   << ": " << status;
    102     LOG(ERROR) << "Contents of bad module:";
    103     XLA_LOG_LINES(tensorflow::ERROR, ToString());
    104   }
    105 }
    106 
    107 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
    108                          bool allow_mixed_precision_in_hlo_verifier,
    109                          std::function<bool(const HloInstruction*)>
    110                              instruction_can_change_layout_func)
    111     : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
    112                   verifier_layout_sensitive,
    113                   allow_mixed_precision_in_hlo_verifier,
    114                   instruction_can_change_layout_func) {}
    115 
    116 HloTestBase::HloTestBase(se::Platform* test_platform,
    117                          se::Platform* reference_platform,
    118                          bool verifier_layout_sensitive,
    119                          bool allow_mixed_precision_in_hlo_verifier,
    120                          std::function<bool(const HloInstruction*)>
    121                              instruction_can_change_layout_func)
    122     : test_runner_(test_platform),
    123       reference_runner_(reference_platform),
    124       verifier_layout_sensitive_(verifier_layout_sensitive),
    125       allow_mixed_precision_in_hlo_verifier_(
    126           allow_mixed_precision_in_hlo_verifier) {
    127   hlo_verifier_ = absl::make_unique<HloVerifier>(
    128       /*layout_sensitive=*/verifier_layout_sensitive,
    129       /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
    130       instruction_can_change_layout_func);
    131 }
    132 
    133 std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
    134     const string& name) {
    135   return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
    136 }
    137 
    138 std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
    139     const string& name) {
    140   return absl::make_unique<VerifiedHloModule>(
    141       name, GetModuleConfigForTest(), verifier_layout_sensitive_,
    142       allow_mixed_precision_in_hlo_verifier_,
    143       backend().compiler()->ShapeSizeBytesFunction());
    144 }
    145 
    146 StatusOr<std::unique_ptr<VerifiedHloModule>>
    147 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
    148                                           const HloModuleConfig& config) {
    149   auto module = absl::make_unique<VerifiedHloModule>(
    150       TestName(), config, verifier_layout_sensitive_,
    151       allow_mixed_precision_in_hlo_verifier_,
    152       backend().compiler()->ShapeSizeBytesFunction());
    153   TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
    154   TF_RETURN_IF_ERROR(module->Verify());
    155   return std::move(module);
    156 }
    157 
    158 /* static */
    159 StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
    160                                        HloModule* module) {
    161   const string module_str_before_run = module->ToProto().ShortDebugString();
    162   const auto status_or = hlo_pass->Run(module);
    163   if (status_or.status().ok()) {
    164     const string module_str_after_run = module->ToProto().ShortDebugString();
    165     if (!status_or.ValueOrDie()) {
    166       // Check that the proto remains same.
    167       EXPECT_EQ(module_str_after_run, module_str_before_run);
    168     }
    169   }
    170   return status_or;
    171 }
    172 
    173 /* static */
    174 PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
    175   PrecisionConfig precision_config;
    176   precision_config.mutable_operand_precision()->Resize(
    177       operands, PrecisionConfig::DEFAULT);
    178   return precision_config;
    179 }
    180 
    181 DebugOptions HloTestBase::GetDebugOptionsForTest() {
    182   auto debug_options = GetDebugOptionsFromFlags();
    183   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
    184   debug_options.add_xla_disable_hlo_passes("constant_folding");
    185   debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
    186   debug_options.set_xla_hlo_evaluator_use_fast_path(true);
    187   return debug_options;
    188 }
    189 
    190 StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
    191                                        absl::Span<Literal* const> arguments) {
    192   return test_runner_.Execute(std::move(module), arguments);
    193 }
    194 
    195 Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
    196                                         absl::Span<Literal* const> arguments) {
    197   return test_runner_
    198       .Execute(std::move(module), arguments,
    199                /*run_hlo_passes=*/false)
    200       .ValueOrDie();
    201 }
    202 
    203 Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
    204                                         absl::Span<Literal* const> arguments) {
    205   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
    206 }
    207 
    208 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
    209     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
    210     int64 num_replicas, bool use_threads) {
    211   HloRunner::ReplicatedExecuteOptions options;
    212   options.num_replicas = num_replicas;
    213   for (auto argument : arguments) {
    214     options.arguments.push_back(argument);
    215   }
    216   return test_runner_.ExecuteReplicated(std::move(module), options,
    217                                         use_threads);
    218 }
    219 
    220 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
    221     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
    222     int64 num_replicas, DeviceAssignment* device_assignment,
    223     bool run_hlo_passes, bool use_threads) {
    224   HloRunner::ReplicatedExecuteOptions options;
    225   options.num_replicas = num_replicas;
    226   options.run_hlo_passes = run_hlo_passes;
    227   for (auto argument : arguments) {
    228     options.arguments.push_back(argument);
    229   }
    230   return test_runner_.ExecuteReplicated(std::move(module), options,
    231                                         device_assignment, use_threads);
    232 }
    233 
    234 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
    235     const HloModule& test_module,
    236     const std::function<void(HloModule*)>& reference_preprocessor) {
    237   std::unique_ptr<HloModule> reference_module = test_module.Clone();
    238   const auto& program_shape = GetProgramShapeWithLayout(test_module);
    239 
    240   if (reference_preprocessor != nullptr) {
    241     reference_preprocessor(reference_module.get());
    242     if (!ProgramShapesEqual(program_shape,
    243                             GetProgramShapeWithLayout(*reference_module))) {
    244       return InvalidArgument(
    245           "reference preprocessor must not modify the program shape");
    246     }
    247   }
    248   TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
    249   return std::move(reference_module);
    250 }
    251 
    252 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
    253     std::unique_ptr<HloModule> module,
    254     const absl::Span<Literal* const> arguments,
    255     const optional<ErrorSpec>& error, bool run_hlo_passes,
    256     const std::function<void(HloModule*)>& reference_preprocessor) {
    257   TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
    258   TF_ASSIGN_OR_RETURN(auto reference_module,
    259                       MakeReferenceModule(*module, reference_preprocessor));
    260 
    261   // Execute on two backends.
    262   TF_ASSIGN_OR_RETURN(
    263       auto test,
    264       test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
    265   TF_ASSIGN_OR_RETURN(auto reference,
    266                       reference_runner_.Execute(std::move(reference_module),
    267                                                 arguments, run_hlo_passes));
    268   return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
    269                                       error);
    270 }
    271 
    272 ::testing::AssertionResult HloTestBase::RunAndCompare(
    273     std::unique_ptr<HloModule> module,
    274     const absl::Span<Literal* const> arguments,
    275     const optional<ErrorSpec>& error,
    276     const std::function<void(HloModule*)>& reference_preprocessor) {
    277   auto result =
    278       RunAndCompareInternal(std::move(module), arguments, error,
    279                             /*run_hlo_passes=*/true, reference_preprocessor);
    280   if (!result.ok()) {
    281     return ::testing::AssertionFailure() << result.status();
    282   }
    283   return result.ValueOrDie();
    284 }
    285 
    286 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    287     std::unique_ptr<HloModule> module,
    288     const absl::Span<Literal* const> arguments,
    289     const optional<ErrorSpec>& error,
    290     const std::function<void(HloModule*)>& reference_preprocessor) {
    291   auto result =
    292       RunAndCompareInternal(std::move(module), arguments, error,
    293                             /*run_hlo_passes=*/false, reference_preprocessor);
    294   if (!result.ok()) {
    295     return ::testing::AssertionFailure() << result.status();
    296   }
    297   return result.ValueOrDie();
    298 }
    299 
    300 ::testing::AssertionResult HloTestBase::RunAndCompare(
    301     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
    302     const std::function<void(HloModule*)>& reference_preprocessor) {
    303   auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
    304 
    305   std::vector<Literal*> fake_argument_ptrs;
    306   absl::c_transform(
    307       fake_arguments, std::back_inserter(fake_argument_ptrs),
    308       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
    309 
    310   return RunAndCompare(std::move(module), fake_argument_ptrs, error,
    311                        reference_preprocessor);
    312 }
    313 
    314 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    315     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
    316     const std::function<void(HloModule*)>& reference_preprocessor) {
    317   const auto& fake_arguments =
    318       MakeFakeArguments(module.get()).ConsumeValueOrDie();
    319   std::vector<Literal*> fake_argument_ptrs;
    320   absl::c_transform(
    321       fake_arguments, std::back_inserter(fake_argument_ptrs),
    322       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
    323 
    324   return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
    325                                   reference_preprocessor);
    326 }
    327 
    328 ::testing::AssertionResult HloTestBase::RunAndCompare(
    329     string_view hlo_string, const absl::optional<ErrorSpec>& error,
    330     const std::function<void(HloModule*)>& reference_preprocessor) {
    331   auto module_or_status =
    332       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
    333   if (!module_or_status.ok()) {
    334     return ::testing::AssertionFailure()
    335            << "Error while parsing HLO text format: "
    336            << module_or_status.status().ToString();
    337   }
    338   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
    339                        reference_preprocessor);
    340 }
    341 
    342 ::testing::AssertionResult HloTestBase::Run(string_view hlo_string,
    343                                             bool run_hlo_passes,
    344                                             ExecutionProfile* profile,
    345                                             string backend_config) {
    346   auto module_or_status =
    347       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
    348   if (!module_or_status.ok()) {
    349     return ::testing::AssertionFailure()
    350            << "Error while parsing HLO text format: "
    351            << module_or_status.status().ToString();
    352   }
    353 
    354   std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
    355   const auto& fake_arguments =
    356       MakeFakeArguments(module.get()).ConsumeValueOrDie();
    357   std::vector<Literal*> fake_argument_ptrs;
    358   absl::c_transform(
    359       fake_arguments, std::back_inserter(fake_argument_ptrs),
    360       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
    361 
    362   if (profile != nullptr) {
    363     // We have to enable HLO profiling since otherwise currently the
    364     // ExecutionProfile is not correct.
    365     //
    366     // TODO(b/119432044): Fix collection of the ExecutionProfile
    367     // so that this is not necessary.
    368     HloModuleConfig config = module->config();
    369     DebugOptions debug_options = config.debug_options();
    370     debug_options.set_xla_hlo_profile(true);
    371     config.set_debug_options(debug_options);
    372     module->set_config(config);
    373   }
    374 
    375   if (!backend_config.empty()) {
    376     // Set backend configuration if it is given.
    377     HloInstruction* instruction =
    378         module->entry_computation()->root_instruction();
    379     instruction->set_raw_backend_config_string(backend_config);
    380   }
    381 
    382   // return ::testing::AssertionSuccess();
    383   auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
    384                                      /*run_hlo_passes=*/run_hlo_passes,
    385                                      /*profile=*/profile);
    386 
    387   return output.ok()
    388              ? ::testing::AssertionSuccess()
    389              : ::testing::AssertionFailure() << output.status().error_message();
    390 }
    391 
    392 ::testing::AssertionResult HloTestBase::RunMultipleTimes(
    393     string_view hlo_string, bool run_hlo_passes,
    394     std::vector<ExecutionProfile>* profiles, string backend_config) {
    395   int n = profiles->size();
    396   std::vector<std::vector<Literal*>> fake_argument_ptrs(n);
    397   std::vector<std::vector<Literal>> fake_arguments(n);
    398   std::vector<std::unique_ptr<Executable>> executables(n);
    399 
    400   for (int i = 0; i < n; ++i) {
    401     auto module_or_status =
    402         HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
    403     if (!module_or_status.ok()) {
    404       return ::testing::AssertionFailure()
    405              << "Error while parsing HLO text format: "
    406              << module_or_status.status().ToString();
    407     }
    408     std::unique_ptr<HloModule> module =
    409         std::move(module_or_status.ValueOrDie());
    410 
    411     fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
    412     absl::c_transform(
    413         fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
    414         [](const Literal& literal) { return const_cast<Literal*>(&literal); });
    415 
    416     if (profiles != nullptr) {
    417       // We have to enable HLO profiling since otherwise currently the
    418       // ExecutionProfile is not correct.
    419       //
    420       // TODO(b/119432044): Fix collection of the ExecutionProfile
    421       // so that this is not necessary.
    422       HloModuleConfig config = module->config();
    423       DebugOptions debug_options = config.debug_options();
    424       debug_options.set_xla_hlo_profile(true);
    425       config.set_debug_options(debug_options);
    426       module->set_config(config);
    427     }
    428 
    429     if (!backend_config.empty()) {
    430       // Set backend configuration if it is given.
    431       HloInstruction* instruction =
    432           module->entry_computation()->root_instruction();
    433       instruction->set_raw_backend_config_string(backend_config);
    434     }
    435 
    436     auto executable =
    437         test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
    438     if (!executable.ok()) {
    439       return ::testing::AssertionFailure()
    440              << executable.status().error_message();
    441     }
    442     executables[i] = std::move(executable.ValueOrDie());
    443   }
    444 
    445   for (int i = 0; i < n; ++i) {
    446     auto output =
    447         test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
    448                              /*profile=*/&((*profiles)[i]));
    449     if (!output.ok()) {
    450       return ::testing::AssertionFailure() << output.status().error_message();
    451     }
    452   }
    453 
    454   return ::testing::AssertionSuccess();
    455 }
    456 
    457 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
    458     const string& filename, const absl::optional<ErrorSpec>& error,
    459     const std::function<void(HloModule*)>& reference_preprocessor) {
    460   auto module_or_status =
    461       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
    462   if (!module_or_status.ok()) {
    463     return ::testing::AssertionFailure()
    464            << "failed reading hlo module from file";
    465   }
    466   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
    467                        reference_preprocessor);
    468 }
    469 
    470 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    471     string_view hlo_string, const absl::optional<ErrorSpec>& error,
    472     const std::function<void(HloModule*)>& reference_preprocessor) {
    473   auto module_or_status =
    474       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
    475   if (!module_or_status.ok()) {
    476     return ::testing::AssertionFailure()
    477            << "Error while parsing HLO text format: "
    478            << module_or_status.status().ToString();
    479   }
    480   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
    481                                   reference_preprocessor);
    482 }
    483 
    484 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
    485     const string& filename, const absl::optional<ErrorSpec>& error,
    486     const std::function<void(HloModule*)>& reference_preprocessor) {
    487   auto module_or_status =
    488       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
    489   if (!module_or_status.ok()) {
    490     return ::testing::AssertionFailure()
    491            << "failed reading hlo module from file";
    492   }
    493   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
    494                                   reference_preprocessor);
    495 }
    496 
    497 HloComputation* HloTestBase::FindComputation(HloModule* module,
    498                                              absl::string_view name) {
    499   auto computations = module->computations();
    500   auto it = absl::c_find_if(
    501       computations, [&](HloComputation* c) { return c->name() == name; });
    502   if (it == computations.end()) {
    503     return nullptr;
    504   }
    505   return *it;
    506 }
    507 
    508 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
    509                                              absl::string_view name) {
    510   for (const HloComputation* c : module->computations()) {
    511     auto instructions = c->instructions();
    512     auto it = absl::c_find_if(
    513         instructions, [&](HloInstruction* i) { return i->name() == name; });
    514     if (it != instructions.end()) {
    515       return *it;
    516     }
    517   }
    518   return nullptr;
    519 }
    520 
    521 Backend& HloTestBase::backend() { return test_runner_.backend(); }
    522 
    523 /* static */
    524 string HloTestBase::TestName() {
    525   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
    526 }
    527 
    528 }  // namespace xla
    529