1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 16 #define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "tensorflow/core/framework/function.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/kernels/data/dataset.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/lib/random/random.h" 27 #include "tensorflow/core/platform/macros.h" 28 29 namespace tensorflow { 30 31 class Device; 32 class OpKernelContext; 33 class ResourceMgr; 34 35 // A `CapturedFunction` encapsulates a TensorFlow function and all of 36 // the runtime support required to execute it. 37 // 38 // The `Dataset`-related classes use `CapturedFunction` to execute 39 // TensorFlow functions outside a the normal `OpKernel::Compute()` 40 // context. 41 class CapturedFunction { 42 public: 43 // NOTE(mrry): The `captured_inputs` are passed by value. For 44 // efficiency, you are recommended to move this argument into the call. 45 static Status Create(const NameAttrList& func, 46 std::vector<Tensor> captured_inputs, 47 std::unique_ptr<CapturedFunction>* out_function); 48 49 ~CapturedFunction(); 50 51 // Runs the "Captured function" using the given FLR and caches the lib and 52 // handle generated during instantiation. If Run is called with a different 53 // lib afterwards, generates an error. This method takes ownership of the 54 // tensors in `args`, in order to be able to deallocate them as early as 55 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain 56 // ownership of the `args`. 57 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args, 58 std::vector<Tensor>* rets); 59 60 // Synchronously runs the captured function on the given `args`, and stores 61 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 62 // possible. 63 Status RunWithBorrowedArgs(IteratorContext* ctx, 64 const std::vector<Tensor>& args, 65 std::vector<Tensor>* rets); 66 67 // Asynchronously runs the captured function on the given `args`, stores 68 // the results in `*rets`, and calls the given `done` callback when the 69 // function returns. This method takes ownership of the tensors in `args`, 70 // in order to be able to deallocate them as early as possible. 71 void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, 72 std::vector<Tensor>* rets, 73 FunctionLibraryRuntime::DoneCallback done); 74 75 // Returns that additional captured inputs that will be passed to the function 76 // when `Run*()` is called. 77 const std::vector<Tensor>& captured_inputs() { return captured_inputs_; } 78 79 // Returns a step ID for use when running a `CapturedFunction`. 80 static int64 generate_step_id() { 81 // Choose a step ID that is guaranteed not to clash with any 82 // Session-generated step ID. DirectSession only generates 83 // non-negative step IDs (contiguous, starting from 0), and 84 // MasterSession generates 56-bit random step IDs whose MSB is 85 // always 0, so a negative random step ID should suffice. 86 return -std::abs(static_cast<int64>(random::New64())); 87 } 88 89 private: 90 CapturedFunction(const NameAttrList& func, 91 std::vector<Tensor> captured_inputs); 92 93 Status MaybeInstantiate(IteratorContext* ctx, 94 FunctionLibraryRuntime::Handle* out_handle); 95 96 mutex mu_; 97 const NameAttrList func_; 98 FunctionLibraryRuntime* lib_ GUARDED_BY(mu_); 99 FunctionLibraryRuntime::Handle f_handle_ GUARDED_BY(mu_); 100 const std::vector<Tensor> captured_inputs_; 101 DataTypeSlice ret_types_; 102 103 TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); 104 }; 105 106 } // namespace tensorflow 107 108 #endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 109