Home | History | Annotate | Download | only in tests
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and comparing the
// results. This is a lower level of abstraction than using the client
// interface and enables, for one, explicitly building a graph of HLO
// instructions to run.
//
// This can also be used to write text/file-based test cases. Note that the
// test target is responsible for linking the needed backends. A convenient way
// to do this is to make it an xla_test: it will generate test targets linking
// with the respective backends, which will be used as the test backend; the
// interpreter backend is already linked with hlo_test_base so it will be the
// default reference backend. For example, if you want to compare both cpu vs.
// interpreter, and gpu vs. interpreter, you can:
//
//  xla_test (
//    name = "sample_text_test",
//    srcs = ["sample_text_test.cc"],
//    backends = [
//      "cpu",
//      "gpu",
//    ],
//    deps = [
//      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
//      ...
//    ],
//  )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
     68 class HloTestBase : public ::testing::Test {
     69  protected:
     70   // This uses the interpreter backend as the reference backend and
     71   // automatically finds another supported backend as the test backend. If the
     72   // interpreter is the only supported backend, it will be both the test backend
     73   // and the reference backend.
     74   HloTestBase();
     75 
     76   // If your test doesn't use interpreter as the reference backend, you can use
     77   // this constructor. Note that your test target is responsible for linking in
     78   // both needed backends.
     79   HloTestBase(::perftools::gputools::Platform* test_platform,
     80               ::perftools::gputools::Platform* reference_platform);
     81 
     82   ~HloTestBase() override {}
     83 
     84   // Creates a new HLO module for a test. The module created will have
     85   // TestName() for its name; it will also automatically populate its debug
     86   // options from command-line flags. If you want a fresh HloModule object and
     87   // then add HloComputations to it, it's recommended to use this method in your
     88   // tests.
     89   static std::unique_ptr<HloModule> CreateNewModule();
     90 
     91   // Populates debug options from command-line flags and adjusts the options for
     92   // testing. It is recommended to use this when you need to pass in
     93   // DebugOptions, e.g. when creating a module from a string or a file.
     94   static DebugOptions GetDebugOptionsForTest();
     95 
     96   // Executes the given module and return the result as a Literal.
     97   StatusOr<std::unique_ptr<Literal>> Execute(
     98       std::unique_ptr<HloModule> module,
     99       tensorflow::gtl::ArraySlice<Literal*> arguments);
    100 
    101   std::unique_ptr<Literal> ExecuteAndTransfer(
    102       std::unique_ptr<HloModule> module,
    103       tensorflow::gtl::ArraySlice<Literal*> arguments);
    104 
    105   // Executes the given hlo module on two backends and compares results.
    106   //
    107   // 'arguments': the input of the hlo module. The LiteralPtr type accepts
    108   // Literal* or std::unique_ptr<Literal>.
    109   //
    110   // 'error': if has value, expects the results to be near (within the error
    111   // bound). Otherwise, expects the results to be equal.
    112   //
    113   // 'reference_preprocessor': the module should be ready to run on the test
    114   // backend, but it might need to be tailored so that it is able to run on the
    115   // reference backend. Note that the program shape of the module must not be
    116   // modified.
    117   template <typename LiteralPtr>
    118   ::testing::AssertionResult RunAndCompare(
    119       std::unique_ptr<HloModule> module,
    120       const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
    121       const tensorflow::gtl::optional<ErrorSpec>& error,
    122       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    123       TF_MUST_USE_RESULT;
    124 
    125   // Same as above, except that the module will be executed without Hlo
    126   // optimization.
    127   template <typename LiteralPtr>
    128   ::testing::AssertionResult RunAndCompareNoHloPasses(
    129       std::unique_ptr<HloModule> module,
    130       const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
    131       const tensorflow::gtl::optional<ErrorSpec>& error,
    132       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    133       TF_MUST_USE_RESULT;
    134 
    135   // Executes an hlo module with fake inputs and compares the results.
    136   ::testing::AssertionResult RunAndCompare(
    137       std::unique_ptr<HloModule> module,
    138       const tensorflow::gtl::optional<ErrorSpec>& error,
    139       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    140       TF_MUST_USE_RESULT;
    141 
    142   // Same as above, except that the module will be executed without Hlo
    143   // optimization.
    144   ::testing::AssertionResult RunAndCompareNoHloPasses(
    145       std::unique_ptr<HloModule> module,
    146       const tensorflow::gtl::optional<ErrorSpec>& error,
    147       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    148       TF_MUST_USE_RESULT;
    149 
    150   // Convenient wrappers for executing and comparing an hlo module with fake
    151   // input. Module can be passed in directly, or parsed from an hlo_string,
    152   // or loaded from a file.
    153   ::testing::AssertionResult RunAndCompare(
    154       const tensorflow::StringPiece hlo_string,
    155       const tensorflow::gtl::optional<ErrorSpec>& error,
    156       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    157       TF_MUST_USE_RESULT;
    158   ::testing::AssertionResult RunAndCompareFromFile(
    159       const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
    160       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    161       TF_MUST_USE_RESULT;
    162   ::testing::AssertionResult RunAndCompareNoHloPasses(
    163       const tensorflow::StringPiece hlo_string,
    164       const tensorflow::gtl::optional<ErrorSpec>& error,
    165       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    166       TF_MUST_USE_RESULT;
    167   ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
    168       const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
    169       const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
    170       TF_MUST_USE_RESULT;
    171 
    172   // Convenience method to force the layout of a given parameter in a module.
    173   // The layout of parameter number 'param_no' in the 'module' is set to
    174   // 'layout'.
    175   void ForceParameterLayout(HloModule* module, int64 param_no,
    176                             const Layout& layout) {
    177     ASSERT_LT(param_no,
    178               module->mutable_entry_computation_layout()->parameter_count());
    179     module->mutable_entry_computation_layout()
    180         ->mutable_parameter_layout(param_no)
    181         ->ResetLayout(layout);
    182   }
    183 
    184   // Convenience method to force the layout of the computation result in a
    185   // module. The result layout of 'module' is set to 'layout'.
    186   void ForceResultLayout(HloModule* module, const Layout& layout) {
    187     module->mutable_entry_computation_layout()
    188         ->mutable_result_layout()
    189         ->ResetLayout(layout);
    190   }
    191 
    192   // Convenience method to clear the layout of the computation result in
    193   // 'module'.
    194   void ForceClearResultLayout(HloModule* module) {
    195     module->mutable_entry_computation_layout()
    196         ->mutable_result_layout()
    197         ->Clear();
    198   }
    199 
    200   // Return an HLO verifier constructed for the test backend.
    201   HloVerifier& verifier() const { return *hlo_verifier_; }
    202 
    203   static string TestName();
    204 
    205   // Returns the backend owned by the test runner.
    206   Backend& backend();
    207 
    208   HloRunner test_runner_;
    209   HloRunner reference_runner_;
    210 
    211   std::unique_ptr<HloVerifier> hlo_verifier_;
    212 
    213   ErrorSpec error_spec_{0.0001};
    214 
    215  private:
    216   // Given the test module, makes a reference module that is ready to run on the
    217   // reference platform. This assumes that the given module is ready to run on
    218   // the test platform.
    219   StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
    220       const HloModule& test_module,
    221       const std::function<void(HloModule*)>& reference_preprocessor);
    222 
    223   // Runs the module on two platforms with or without running hlo passes and
    224   // compares the results. Returns whether the results are near or equal. If any
    225   // error happens before the results are computed, returns the error status.
    226   template <typename LiteralPtr>
    227   StatusOr<::testing::AssertionResult> RunAndCompareInternal(
    228       std::unique_ptr<HloModule> module,
    229       const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
    230       const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes,
    231       const std::function<void(HloModule*)>& reference_preprocessor);
    232 };
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_