Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
     17 #define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
     18 
     19 #include <memory>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/common_runtime/device.h"
     23 #include "tensorflow/core/common_runtime/device_factory.h"
     24 #include "tensorflow/core/framework/allocator.h"
     25 #include "tensorflow/core/framework/device_base.h"
     26 #include "tensorflow/core/framework/graph.pb.h"
     27 #include "tensorflow/core/framework/node_def.pb.h"
     28 #include "tensorflow/core/framework/op_kernel.h"
     29 #include "tensorflow/core/framework/resource_mgr.h"
     30 #include "tensorflow/core/framework/tensor.h"
     31 #include "tensorflow/core/framework/tensor_testutil.h"
     32 #include "tensorflow/core/framework/types.h"
     33 #include "tensorflow/core/framework/types.pb.h"
     34 #include "tensorflow/core/lib/core/status.h"
     35 #include "tensorflow/core/lib/core/status_test_util.h"
     36 #include "tensorflow/core/lib/gtl/array_slice.h"
     37 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     38 #include "tensorflow/core/lib/gtl/stl_util.h"
     39 #include "tensorflow/core/platform/env.h"
     40 #include "tensorflow/core/platform/logging.h"
     41 #include "tensorflow/core/platform/macros.h"
     42 #include "tensorflow/core/platform/mutex.h"
     43 #include "tensorflow/core/platform/test.h"
     44 #include "tensorflow/core/platform/types.h"
     45 #include "tensorflow/core/public/session_options.h"
     46 #include "tensorflow/core/public/version.h"
     47 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
     48 
     49 namespace tensorflow {
     50 namespace test {
     51 
     52 inline void SetOutputAttrs(OpKernelContext::Params* params,
     53                            std::vector<AllocatorAttributes>* attrs) {
     54   attrs->clear();
     55   for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
     56     AllocatorAttributes attr;
     57     const bool on_host =
     58         (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
     59     attr.set_on_host(on_host);
     60     attrs->push_back(attr);
     61   }
     62   params->output_attr_array = gtl::vector_as_array(attrs);
     63 }
     64 
     65 }  // namespace test
     66 
     67 // Helpful functions to test operators.
     68 //
     69 // This class will eventually be replaced / heavily modified
     70 // to use the BrainClient interface.
     71 class OpsTestBase : public ::testing::Test {
     72  public:
     73   OpsTestBase()
     74       : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
     75         device_type_(DEVICE_CPU) {
     76     CHECK(device_.get()) << "Could not create CPU device";
     77     allocator_ = device_->GetAllocator(AllocatorAttributes());
     78   }
     79 
     80   ~OpsTestBase() override {
     81     gtl::STLDeleteElements(&tensors_);
     82     gtl::STLDeleteElements(&managed_outputs_);
     83     context_.reset(nullptr);
     84     params_.reset(nullptr);
     85   }
     86 
     87   // Allow kernel unit tests to run on GPU
     88   void SetDevice(const DeviceType& device_type, std::unique_ptr<Device> device);
     89 
     90   void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
     91 
     92   // Clients can manipulate the underlying NodeDef via this accessor.
     93   NodeDef* node_def() { return &node_def_; }
     94 
     95   // Initializes an operator that takes in 'input_types' as input
     96   // and output types as output.
     97   //
     98   // Returns the status of initialization.
     99   Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); }
    100 
    101   // Only use this directly if you have a deprecated op that you need to test.
    102   Status InitOpWithGraphVersion(int graph_def_version) {
    103     Status status;
    104     kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
    105                              node_def_, graph_def_version, &status);
    106     if (kernel_ != nullptr) input_types_ = kernel_->input_types();
    107     return status;
    108   }
    109 
    110   // Adds an input for every element described by the shape.
    111   // 'input_mapping' maps an index (0...NumElements(shape)) to a
    112   // value.
    113   //
    114   // TODO(vrv): Replace with something like a BrainClient Feed.
    115   template <typename T>
    116   void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
    117     test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping);
    118   }
    119 
    120   // Like AddInput but takes in an explicit arrayslice of data.
    121   template <typename T>
    122   void AddInputFromArray(const TensorShape& shape,
    123                          const gtl::ArraySlice<T>& data) {
    124     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
    125   }
    126 
    127   // Convenience function to add an input and populate it with the elements from
    128   // an initializer list converting the types as needed.
    129   template <typename T, typename SrcType>
    130   void AddInputFromList(const TensorShape& shape,
    131                         std::initializer_list<SrcType> data) {
    132     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
    133   }
    134 
    135   // Adds a Resource type as input. If <container> is empty, uses the default
    136   // container name.
    137   template <typename T>
    138   void AddResourceInput(const string& container, const string& name,
    139                         T* resource) {
    140     CHECK_GT(input_types_.size(), inputs_.size())
    141         << "Adding more inputs than types; perhaps you need to call MakeOp";
    142     ResourceMgr* rm = device_->resource_manager();
    143     EXPECT_TRUE(
    144         rm->Create(container == "" ? rm->default_container() : container, name,
    145                    resource)
    146             .ok());
    147     TypeIndex type_index = MakeTypeIndex<T>();
    148     ResourceHandle handle;
    149     handle.set_device(device_->name());
    150     handle.set_container(container);
    151     handle.set_name(name);
    152     handle.set_hash_code(type_index.hash_code());
    153     handle.set_maybe_type_name(type_index.name());
    154     Tensor* input = new Tensor(allocator(), DT_RESOURCE, TensorShape({}));
    155     input->scalar<ResourceHandle>()() = handle;
    156     tensors_.push_back(input);
    157     inputs_.push_back({nullptr, input});
    158   }
    159 
    160   // Runs an operation producing 'num_outputs' outputs.
    161   //
    162   // Returns the context's status after running the operation.
    163   Status RunOpKernel() {
    164     // Make sure the old OpKernelContext is deleted before the Params
    165     // it was using.
    166     context_.reset(nullptr);
    167 
    168     params_.reset(new OpKernelContext::Params);
    169     params_.get()->device = device_.get();
    170     params_.get()->frame_iter = FrameAndIter(0, 0);
    171     params_.get()->inputs = &inputs_;
    172     params_.get()->op_kernel = kernel_.get();
    173     step_container_.reset(new ScopedStepContainer(0, [](const string&) {}));
    174     params_->step_container = step_container_.get();
    175     std::vector<AllocatorAttributes> attrs;
    176     test::SetOutputAttrs(params_.get(), &attrs);
    177     checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
    178     params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
    179     params_.get()->resource_manager = device_.get()->resource_manager();
    180 
    181     context_.reset(new OpKernelContext(params_.get()));
    182     device_->Compute(kernel_.get(), context_.get());
    183     return context_->status();
    184   }
    185 
    186   // Returns the tensor input for 'input_index'.
    187   //
    188   // REQUIRES: 0 <= input_index < context_->num_inputs()
    189   const Tensor& GetInput(int input_index) const {
    190     CHECK_LT(input_index, context_->num_inputs());
    191     CHECK(!IsRefType(context_->input_dtype(input_index)));
    192     return context_->input(input_index);
    193   }
    194 
    195   TensorValue mutable_input(int input_index) {
    196     CHECK_LT(input_index, inputs_.size());
    197     return inputs_[input_index];
    198   }
    199   // Returns the tensor output for 'output_index'.
    200   //
    201   // REQUIRES: 0 <= output_index < context_->num_outputs()
    202   Tensor* GetOutput(int output_index);
    203 
    204   Allocator* allocator() { return allocator_; }
    205 
    206   const DataTypeVector& output_types() const { return kernel_->output_types(); }
    207 
    208  private:
    209   Tensor* AddInput(DataType dtype, const TensorShape& shape) {
    210     CHECK_GT(input_types_.size(), inputs_.size())
    211         << "Adding more inputs than types; perhaps you need to call MakeOp";
    212     bool is_ref = IsRefType(input_types_[inputs_.size()]);
    213     Tensor* input = new Tensor(allocator(), dtype, shape);
    214     tensors_.push_back(input);
    215     if (is_ref) {
    216       CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), dtype);
    217       inputs_.push_back({&lock_for_refs_, input});
    218     } else {
    219       CHECK_EQ(input_types_[inputs_.size()], dtype);
    220       inputs_.push_back({nullptr, input});
    221     }
    222     return input;
    223   }
    224 
    225  protected:
    226   std::unique_ptr<Device> device_;
    227   // The device allocator, or the managed_allocator_ below if running on GPU.
    228   Allocator* allocator_;
    229 
    230   std::unique_ptr<OpKernel> kernel_;
    231   std::unique_ptr<ScopedStepContainer> step_container_;
    232   NodeDef node_def_;
    233   DataTypeVector input_types_;
    234   DeviceType device_type_;
    235 
    236   mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs
    237 
    238   gtl::InlinedVector<TensorValue, 4> inputs_;
    239   // Owns Tensors.
    240   std::vector<Tensor*> tensors_;
    241   // Copies of the outputs in unified memory (host and device accessible).
    242   std::vector<Tensor*> managed_outputs_;
    243 
    244   std::unique_ptr<OpKernelContext::Params> params_;
    245   std::unique_ptr<OpKernelContext> context_;
    246   // Unified memory allocator, only used when running on GPU.
    247   std::unique_ptr<Allocator> managed_allocator_;
    248 
    249  private:
    250   TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
    251 };
    252 
    253 }  // namespace tensorflow
    254 
    255 #endif  // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
    256