Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
     17 #define TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
     18 
#include <initializer_list>
#include <memory>
#include <random>
#include <type_traits>
     22 
     23 #include "tensorflow/compiler/xla/layout_util.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/ptr_util.h"
     26 #include "tensorflow/compiler/xla/service/hlo_module.h"
     27 #include "tensorflow/compiler/xla/xla_data.pb.h"
     28 #include "tensorflow/core/lib/gtl/array_slice.h"
     29 #include "tensorflow/core/platform/types.h"
     30 #include "tensorflow/stream_executor/platform.h"
     31 
     32 namespace xla {
     33 
     34 // A class which generates pseudorandom numbers of a given type within a given
     35 // range. Not cryptographically secure and likely not perfectly evenly
     36 // distributed across the range but sufficient for most tests.
     37 template <typename NativeT>
     38 class PseudorandomGenerator {
     39  public:
     40   explicit PseudorandomGenerator(NativeT min_value, NativeT max_value,
     41                                  uint32 seed)
     42       : min_(min_value), max_(max_value), generator_(seed) {}
     43 
     44   // Get a pseudorandom value.
     45   NativeT get() {
     46     std::uniform_real_distribution<> distribution;
     47     return static_cast<NativeT>(min_ +
     48                                 (max_ - min_) * distribution(generator_));
     49   }
     50 
     51  private:
     52   NativeT min_;
     53   NativeT max_;
     54   std::mt19937 generator_;
     55 };
     56 
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data generation.
//
// On success the returned StatusOr holds a newly created Literal whose
// ownership passes to the caller through the std::unique_ptr.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
     60 
// Generates a vector of arguments containing fake data. The number, shape and
// layout of the arguments are appropriate for the given HLO module.
//
// Handles special cases such as making sure that indices used for dynamic
// slices are in bounds, and that reduces whose computation is an add use 0 as
// the init value.
//
// On success, returns one caller-owned Literal per argument (std::unique_ptr
// elements); otherwise returns an error status.
// NOTE(review): `module` is passed as a non-const pointer; whether it is
// modified is not visible from this header — confirm against the
// implementation.
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
    HloModule* const module);
     68 
// Checks that a given module satisfies various constraints before trying to
// execute it.
//
// NOTE(review): presumably returns a non-OK Status when a constraint is
// violated; which constraints are checked, and how `platform` influences
// them, is defined in the implementation file rather than this header.
Status VerifyHloModule(const perftools::gputools::Platform& platform,
                       HloModule* const module);
     73 
     74 }  // namespace xla
     75 
     76 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
     77