Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     17 
     18 #include <memory>
     19 #include <set>
     20 #include <string>
     21 #include <utility>
     22 
     23 #include "tensorflow/compiler/xla/layout_util.h"
     24 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
     25 #include "tensorflow/compiler/xla/ptr_util.h"
     26 #include "tensorflow/compiler/xla/service/platform_util.h"
     27 #include "tensorflow/compiler/xla/shape_util.h"
     28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     29 #include "tensorflow/compiler/xla/tests/test_utils.h"
     30 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
     31 #include "tensorflow/compiler/xla/types.h"
     32 #include "tensorflow/core/lib/core/status_test_util.h"
     33 #include "tensorflow/core/lib/gtl/array_slice.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/test.h"
     36 #include "tensorflow/core/platform/types.h"
     37 
     38 namespace se = ::perftools::gputools;
     39 
     40 namespace xla {
     41 
     42 namespace {
     43 
     44 using tensorflow::StringPiece;
     45 using tensorflow::gtl::ArraySlice;
     46 using tensorflow::gtl::optional;
     47 
     48 constexpr char kInterpreter[] = "interpreter";
     49 
     50 // Helper functions to get test and reference platforms.
     51 se::Platform* GetReferencePlatform() {
     52   auto result = PlatformUtil::GetPlatform(kInterpreter);
     53   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
     54   return result.ValueOrDie();
     55 }
     56 
     57 se::Platform* GetTestPlatform() {
     58   auto result = PlatformUtil::GetDefaultPlatform();
     59   TF_CHECK_OK(result.status()) << "could not get test platform";
     60   return result.ValueOrDie();
     61 }
     62 
     63 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
     64   if (lhs.parameters_size() != rhs.parameters_size()) {
     65     return false;
     66   }
     67   for (int i = 0; i < lhs.parameters_size(); i++) {
     68     if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
     69       return false;
     70     }
     71   }
     72   return ShapeUtil::Equal(lhs.result(), rhs.result());
     73 }
     74 
     75 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
     76   ProgramShape program_shape;
     77   const auto* entry = module.entry_computation();
     78   for (const auto* param : entry->parameter_instructions()) {
     79     *program_shape.add_parameters() = param->shape();
     80     *program_shape.add_parameter_names() = param->name();
     81   }
     82   *program_shape.mutable_result() = entry->root_instruction()->shape();
     83   return program_shape;
     84 }
     85 
     86 }  // namespace
     87 
     88 HloTestBase::HloTestBase()
     89     : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {}
     90 
     91 HloTestBase::HloTestBase(se::Platform* test_platform,
     92                          se::Platform* reference_platform)
     93     : test_runner_(test_platform), reference_runner_(reference_platform) {
     94   hlo_verifier_ = MakeUnique<HloVerifier>();
     95 }
     96 
     97 /* static */
     98 std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
     99   HloModuleConfig config;
    100   config.set_debug_options(GetDebugOptionsForTest());
    101   return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
    102                                config);
    103 }
    104 
    105 /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
    106   auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
    107   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
    108   debug_options.add_xla_disable_hlo_passes("constant_folding");
    109   return debug_options;
    110 }
    111 
    112 StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute(
    113     std::unique_ptr<HloModule> module,
    114     tensorflow::gtl::ArraySlice<Literal*> arguments) {
    115   return test_runner_.Execute(std::move(module), arguments);
    116 }
    117 
    118 std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
    119     std::unique_ptr<HloModule> module,
    120     tensorflow::gtl::ArraySlice<Literal*> arguments) {
    121   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
    122 }
    123 
    124 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
    125     const HloModule& test_module,
    126     const std::function<void(HloModule*)>& reference_preprocessor) {
    127   std::unique_ptr<HloModule> reference_module = test_module.Clone();
    128   const auto& program_shape = GetProgramShapeWithLayout(test_module);
    129 
    130   if (reference_preprocessor != nullptr) {
    131     reference_preprocessor(reference_module.get());
    132     if (!ProgramShapesEqual(program_shape,
    133                             GetProgramShapeWithLayout(*reference_module))) {
    134       return InvalidArgument(
    135           "reference preprocessor must not modify the program shape");
    136     }
    137   }
    138   TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(),
    139                                      reference_module.get()));
    140   return std::move(reference_module);
    141 }
    142 
    143 template <typename LiteralPtr>
    144 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
    145     std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
    146     const optional<ErrorSpec>& error, bool run_hlo_passes,
    147     const std::function<void(HloModule*)>& reference_preprocessor) {
    148   static_assert(
    149       std::is_same<Literal*, LiteralPtr>::value ||
    150           std::is_same<std::unique_ptr<Literal>, LiteralPtr>::value,
    151       "The LiteralPtr type only accepts Literal* or std::unique_ptr<Literal>.");
    152   TF_RETURN_IF_ERROR(
    153       VerifyHloModule(*test_runner_.backend().platform(), module.get()));
    154   TF_ASSIGN_OR_RETURN(auto reference_module,
    155                       MakeReferenceModule(*module, reference_preprocessor));
    156 
    157   // Execute on two backends.
    158   TF_ASSIGN_OR_RETURN(
    159       auto test,
    160       test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
    161   TF_ASSIGN_OR_RETURN(auto reference,
    162                       reference_runner_.Execute(std::move(reference_module),
    163                                                 arguments, run_hlo_passes));
    164   return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
    165                                       error);
    166 }
    167 
    168 template <typename LiteralPtr>
    169 ::testing::AssertionResult HloTestBase::RunAndCompare(
    170     std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
    171     const optional<ErrorSpec>& error,
    172     const std::function<void(HloModule*)>& reference_preprocessor) {
    173   auto result =
    174       RunAndCompareInternal(std::move(module), arguments, error,
    175                             /*run_hlo_passes=*/true, reference_preprocessor);
    176   if (!result.ok()) {
    177     return ::testing::AssertionFailure() << result.status();
    178   }
    179   return result.ValueOrDie();
    180 }
    181 
    182 template <typename LiteralPtr>
    183 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    184     std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
    185     const optional<ErrorSpec>& error,
    186     const std::function<void(HloModule*)>& reference_preprocessor) {
    187   auto result =
    188       RunAndCompareInternal(std::move(module), arguments, error,
    189                             /*run_hlo_passes=*/false, reference_preprocessor);
    190   if (!result.ok()) {
    191     return ::testing::AssertionFailure() << result.status();
    192   }
    193   return result.ValueOrDie();
    194 }
    195 
    196 ::testing::AssertionResult HloTestBase::RunAndCompare(
    197     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
    198     const std::function<void(HloModule*)>& reference_preprocessor) {
    199   const auto& fake_arguments =
    200       MakeFakeArguments(module.get()).ConsumeValueOrDie();
    201   return RunAndCompare<std::unique_ptr<Literal>>(
    202       std::move(module), fake_arguments, error, reference_preprocessor);
    203 }
    204 
    205 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    206     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
    207     const std::function<void(HloModule*)>& reference_preprocessor) {
    208   const auto& fake_arguments =
    209       MakeFakeArguments(module.get()).ConsumeValueOrDie();
    210   return RunAndCompareNoHloPasses<std::unique_ptr<Literal>>(
    211       std::move(module), fake_arguments, error, reference_preprocessor);
    212 }
    213 
    214 ::testing::AssertionResult HloTestBase::RunAndCompare(
    215     const StringPiece hlo_string,
    216     const tensorflow::gtl::optional<ErrorSpec>& error,
    217     const std::function<void(HloModule*)>& reference_preprocessor) {
    218   auto module_or_status =
    219       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
    220   if (!module_or_status.ok()) {
    221     return ::testing::AssertionFailure()
    222            << "Error while parsing HLO text format: "
    223            << module_or_status.status().ToString();
    224   }
    225   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
    226                        reference_preprocessor);
    227 }
    228 
    229 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
    230     const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
    231     const std::function<void(HloModule*)>& reference_preprocessor) {
    232   auto module_or_status =
    233       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
    234   if (!module_or_status.ok()) {
    235     return ::testing::AssertionFailure()
    236            << "failed reading hlo module from file";
    237   }
    238   return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
    239                        reference_preprocessor);
    240 }
    241 
    242 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    243     const StringPiece hlo_string,
    244     const tensorflow::gtl::optional<ErrorSpec>& error,
    245     const std::function<void(HloModule*)>& reference_preprocessor) {
    246   auto module_or_status =
    247       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
    248   if (!module_or_status.ok()) {
    249     return ::testing::AssertionFailure()
    250            << "Error while parsing HLO text format: "
    251            << module_or_status.status().ToString();
    252   }
    253   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
    254                                   reference_preprocessor);
    255 }
    256 
    257 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
    258     const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
    259     const std::function<void(HloModule*)>& reference_preprocessor) {
    260   auto module_or_status =
    261       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
    262   if (!module_or_status.ok()) {
    263     return ::testing::AssertionFailure()
    264            << "failed reading hlo module from file";
    265   }
    266   return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
    267                                   reference_preprocessor);
    268 }
    269 
    270 Backend& HloTestBase::backend() { return test_runner_.backend(); }
    271 
    272 /* static */
    273 string HloTestBase::TestName() {
    274   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
    275 }
    276 
    277 }  // namespace xla
    278