Home | History | Annotate | Download | only in data
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
     16 #define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
     17 
     18 #include <memory>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/framework/function.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/kernels/data/dataset.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 #include "tensorflow/core/lib/gtl/array_slice.h"
     26 #include "tensorflow/core/lib/random/random.h"
     27 #include "tensorflow/core/platform/macros.h"
     28 
     29 namespace tensorflow {
     30 
     31 class Device;
     32 class OpKernelContext;
     33 class ResourceMgr;
     34 
     35 // A `CapturedFunction` encapsulates a TensorFlow function and all of
     36 // the runtime support required to execute it.
     37 //
     38 // The `Dataset`-related classes use `CapturedFunction` to execute
     39 // TensorFlow functions outside a the normal `OpKernel::Compute()`
     40 // context.
     41 class CapturedFunction {
     42  public:
     43   // NOTE(mrry): The `captured_inputs` are passed by value. For
     44   // efficiency, you are recommended to move this argument into the call.
     45   static Status Create(const NameAttrList& func,
     46                        std::vector<Tensor> captured_inputs,
     47                        std::unique_ptr<CapturedFunction>* out_function);
     48 
     49   ~CapturedFunction();
     50 
     51   // Runs the "Captured function" using the given FLR and caches the lib and
     52   // handle generated during instantiation. If Run is called with a different
     53   // lib afterwards, generates an error. This method takes ownership of the
     54   // tensors in `args`, in order to be able to deallocate them as early as
     55   // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
     56   // ownership of the `args`.
     57   Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
     58              std::vector<Tensor>* rets);
     59 
     60   // Synchronously runs the captured function on the given `args`, and stores
     61   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
     62   // possible.
     63   Status RunWithBorrowedArgs(IteratorContext* ctx,
     64                              const std::vector<Tensor>& args,
     65                              std::vector<Tensor>* rets);
     66 
     67   // Asynchronously runs the captured function on the given `args`, stores
     68   // the results in `*rets`, and calls the given `done` callback when the
     69   // function returns. This method takes ownership of the tensors in `args`,
     70   // in order to be able to deallocate them as early as possible.
     71   void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
     72                 std::vector<Tensor>* rets,
     73                 FunctionLibraryRuntime::DoneCallback done);
     74 
     75   // Returns that additional captured inputs that will be passed to the function
     76   // when `Run*()` is called.
     77   const std::vector<Tensor>& captured_inputs() { return captured_inputs_; }
     78 
     79   // Returns a step ID for use when running a `CapturedFunction`.
     80   static int64 generate_step_id() {
     81     // Choose a step ID that is guaranteed not to clash with any
     82     // Session-generated step ID. DirectSession only generates
     83     // non-negative step IDs (contiguous, starting from 0), and
     84     // MasterSession generates 56-bit random step IDs whose MSB is
     85     // always 0, so a negative random step ID should suffice.
     86     return -std::abs(static_cast<int64>(random::New64()));
     87   }
     88 
     89  private:
     90   CapturedFunction(const NameAttrList& func,
     91                    std::vector<Tensor> captured_inputs);
     92 
     93   Status MaybeInstantiate(IteratorContext* ctx,
     94                           FunctionLibraryRuntime::Handle* out_handle);
     95 
     96   mutex mu_;
     97   const NameAttrList func_;
     98   FunctionLibraryRuntime* lib_ GUARDED_BY(mu_);
     99   FunctionLibraryRuntime::Handle f_handle_ GUARDED_BY(mu_);
    100   const std::vector<Tensor> captured_inputs_;
    101   DataTypeSlice ret_types_;
    102 
    103   TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
    104 };
    105 
    106 }  // namespace tensorflow
    107 
    108 #endif  // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
    109