1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 17 #define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/service/backend.h" 24 #include "tensorflow/compiler/xla/service/computation_layout.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_runner.h" 27 #include "tensorflow/compiler/xla/service/hlo_verifier.h" 28 #include "tensorflow/compiler/xla/service/platform_util.h" 29 #include "tensorflow/compiler/xla/shape_layout.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/lib/gtl/array_slice.h" 35 #include "tensorflow/core/lib/gtl/optional.h" 36 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 37 #include "tensorflow/core/platform/test.h" 38 39 namespace xla { 40 41 // A base class for tests which build and/or run HLO code. The class includes 42 // support for running an HLO module on two platforms and compare the results. 43 // This is a lower level of abstraction than using the client interface and 44 // enables, for one, explicitly building a graph of HLO instructions to run. 45 // 46 // This can also be used to write text/file-based test cases. Note that the test 47 // target is responsible for linking the needed backends. A covenient way to do 48 // this is to make it an xla_test: it will generate test targets linking with 49 // the respective backends, which will be used as the test backend; the 50 // interpreter backend is already linked with hlo_test_base so it will be the 51 // default reference backend. For example, if you want to compare both cpu vs. 52 // interpreter, and gpu vs. interpreter, you can: 53 // 54 // xla_test ( 55 // name = "sample_text_test", 56 // srcs = ["sample_text_test.cc"], 57 // backends = [ 58 // "cpu", 59 // "gpu", 60 // ], 61 // deps = [ 62 // "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", 63 // ... 64 // ], 65 // ) 66 // 67 // For a more detailed example, see "../tests/sample_text_test.cc". 68 class HloTestBase : public ::testing::Test { 69 protected: 70 // This uses the interpreter backend as the reference backend and 71 // automatically finds another supported backend as the test backend. If the 72 // interpreter is the only supported backend, it will be both the test backend 73 // and the reference backend. 74 HloTestBase(); 75 76 // If your test doesn't use interpreter as the reference backend, you can use 77 // this constructor. Note that your test target is responsible for linking in 78 // both needed backends. 79 HloTestBase(::perftools::gputools::Platform* test_platform, 80 ::perftools::gputools::Platform* reference_platform); 81 82 ~HloTestBase() override {} 83 84 // Creates a new HLO module for a test. The module created will have 85 // TestName() for its name; it will also automatically populate its debug 86 // options from command-line flags. If you want a fresh HloModule object and 87 // then add HloComputations to it, it's recommended to use this method in your 88 // tests. 89 static std::unique_ptr<HloModule> CreateNewModule(); 90 91 // Populates debug options from command-line flags and adjusts the options for 92 // testing. It is recommended to use this when you need to pass in 93 // DebugOptions, e.g. when creating a module from a string or a file. 94 static DebugOptions GetDebugOptionsForTest(); 95 96 // Executes the given module and return the result as a Literal. 97 StatusOr<std::unique_ptr<Literal>> Execute( 98 std::unique_ptr<HloModule> module, 99 tensorflow::gtl::ArraySlice<Literal*> arguments); 100 101 std::unique_ptr<Literal> ExecuteAndTransfer( 102 std::unique_ptr<HloModule> module, 103 tensorflow::gtl::ArraySlice<Literal*> arguments); 104 105 // Executes the given hlo module on two backends and compares results. 106 // 107 // 'arguments': the input of the hlo module. The LiteralPtr type accepts 108 // Literal* or std::unique_ptr<Literal>. 109 // 110 // 'error': if has value, expects the results to be near (within the error 111 // bound). Otherwise, expects the results to be equal. 112 // 113 // 'reference_preprocessor': the module should be ready to run on the test 114 // backend, but it might need to be tailored so that it is able to run on the 115 // reference backend. Note that the program shape of the module must not be 116 // modified. 117 template <typename LiteralPtr> 118 ::testing::AssertionResult RunAndCompare( 119 std::unique_ptr<HloModule> module, 120 const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, 121 const tensorflow::gtl::optional<ErrorSpec>& error, 122 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 123 TF_MUST_USE_RESULT; 124 125 // Same as above, except that the module will be executed without Hlo 126 // optimization. 127 template <typename LiteralPtr> 128 ::testing::AssertionResult RunAndCompareNoHloPasses( 129 std::unique_ptr<HloModule> module, 130 const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, 131 const tensorflow::gtl::optional<ErrorSpec>& error, 132 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 133 TF_MUST_USE_RESULT; 134 135 // Executes an hlo module with fake inputs and compares the results. 136 ::testing::AssertionResult RunAndCompare( 137 std::unique_ptr<HloModule> module, 138 const tensorflow::gtl::optional<ErrorSpec>& error, 139 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 140 TF_MUST_USE_RESULT; 141 142 // Same as above, except that the module will be executed without Hlo 143 // optimization. 144 ::testing::AssertionResult RunAndCompareNoHloPasses( 145 std::unique_ptr<HloModule> module, 146 const tensorflow::gtl::optional<ErrorSpec>& error, 147 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 148 TF_MUST_USE_RESULT; 149 150 // Convenient wrappers for executing and comparing an hlo module with fake 151 // input. Module can be passed in directly, or parsed from an hlo_string, 152 // or loaded from a file. 153 ::testing::AssertionResult RunAndCompare( 154 const tensorflow::StringPiece hlo_string, 155 const tensorflow::gtl::optional<ErrorSpec>& error, 156 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 157 TF_MUST_USE_RESULT; 158 ::testing::AssertionResult RunAndCompareFromFile( 159 const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error, 160 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 161 TF_MUST_USE_RESULT; 162 ::testing::AssertionResult RunAndCompareNoHloPasses( 163 const tensorflow::StringPiece hlo_string, 164 const tensorflow::gtl::optional<ErrorSpec>& error, 165 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 166 TF_MUST_USE_RESULT; 167 ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( 168 const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error, 169 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 170 TF_MUST_USE_RESULT; 171 172 // Convenience method to force the layout of a given parameter in a module. 173 // The layout of parameter number 'param_no' in the 'module' is set to 174 // 'layout'. 175 void ForceParameterLayout(HloModule* module, int64 param_no, 176 const Layout& layout) { 177 ASSERT_LT(param_no, 178 module->mutable_entry_computation_layout()->parameter_count()); 179 module->mutable_entry_computation_layout() 180 ->mutable_parameter_layout(param_no) 181 ->ResetLayout(layout); 182 } 183 184 // Convenience method to force the layout of the computation result in a 185 // module. The result layout of 'module' is set to 'layout'. 186 void ForceResultLayout(HloModule* module, const Layout& layout) { 187 module->mutable_entry_computation_layout() 188 ->mutable_result_layout() 189 ->ResetLayout(layout); 190 } 191 192 // Convenience method to clear the layout of the computation result in 193 // 'module'. 194 void ForceClearResultLayout(HloModule* module) { 195 module->mutable_entry_computation_layout() 196 ->mutable_result_layout() 197 ->Clear(); 198 } 199 200 // Return an HLO verifier constructed for the test backend. 201 HloVerifier& verifier() const { return *hlo_verifier_; } 202 203 static string TestName(); 204 205 // Returns the backend owned by the test runner. 206 Backend& backend(); 207 208 HloRunner test_runner_; 209 HloRunner reference_runner_; 210 211 std::unique_ptr<HloVerifier> hlo_verifier_; 212 213 ErrorSpec error_spec_{0.0001}; 214 215 private: 216 // Given the test module, makes a reference module that is ready to run on the 217 // reference platform. This assumes that the given module is ready to run on 218 // the test platform. 219 StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule( 220 const HloModule& test_module, 221 const std::function<void(HloModule*)>& reference_preprocessor); 222 223 // Runs the module on two platforms with or without running hlo passes and 224 // compares the results. Returns whether the results are near or equal. If any 225 // error happens before the results are computed, returns the error status. 226 template <typename LiteralPtr> 227 StatusOr<::testing::AssertionResult> RunAndCompareInternal( 228 std::unique_ptr<HloModule> module, 229 const tensorflow::gtl::ArraySlice<LiteralPtr> arguments, 230 const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes, 231 const std::function<void(HloModule*)>& reference_preprocessor); 232 }; 233 234 } // namespace xla 235 236 #endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 237