1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/data_flow_ops.cc. 17 18 #include <limits.h> 19 #include <vector> 20 21 #include "tensorflow/core/common_runtime/device.h" 22 #include "tensorflow/core/framework/device_base.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/gtl/map_util.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace tensorflow { 37 38 class GetSessionHandleOp : public OpKernel { 39 public: 40 explicit GetSessionHandleOp(OpKernelConstruction* context) 41 : OpKernel(context) {} 42 43 void Compute(OpKernelContext* ctx) override { 44 const Tensor& val = ctx->input(0); 45 int64 id = ctx->session_state()->GetNewId(); 46 TensorStore::TensorAndKey tk{val, id, requested_device()}; 47 OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk)); 48 49 Tensor* handle = nullptr; 50 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); 51 if (ctx->expected_output_dtype(0) == DT_RESOURCE) { 52 ResourceHandle resource_handle = MakeResourceHandle<Tensor>( 53 ctx, SessionState::kTensorHandleResourceTypeName, 54 tk.GetHandle(name())); 55 resource_handle.set_maybe_type_name( 56 SessionState::kTensorHandleResourceTypeName); 57 handle->scalar<ResourceHandle>()() = resource_handle; 58 } else { 59 // Legacy behavior in V1. 60 handle->flat<string>().setConstant(tk.GetHandle(name())); 61 } 62 } 63 64 TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp); 65 }; 66 67 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU), 68 GetSessionHandleOp); 69 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU), 70 GetSessionHandleOp); 71 72 #define REGISTER_GPU_KERNEL(type) \ 73 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \ 74 .Device(DEVICE_GPU) \ 75 .HostMemory("handle") \ 76 .TypeConstraint<type>("T"), \ 77 GetSessionHandleOp) \ 78 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \ 79 .Device(DEVICE_GPU) \ 80 .HostMemory("handle") \ 81 .TypeConstraint<type>("T"), \ 82 GetSessionHandleOp) 83 84 TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL); 85 REGISTER_GPU_KERNEL(bool); 86 #undef REGISTER_GPU_KERNEL 87 88 #ifdef TENSORFLOW_USE_SYCL 89 #define REGISTER_SYCL_KERNEL(type) \ 90 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle") \ 91 .Device(DEVICE_SYCL) \ 92 .HostMemory("handle") \ 93 .TypeConstraint<type>("T"), \ 94 GetSessionHandleOp) \ 95 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2") \ 96 .Device(DEVICE_SYCL) \ 97 .HostMemory("handle") \ 98 .TypeConstraint<type>("T"), \ 99 GetSessionHandleOp) 100 101 TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); 102 REGISTER_SYCL_KERNEL(bool); 103 #undef REGISTER_SYCL_KERNEL 104 #endif // TENSORFLOW_USE_SYCL 105 106 class GetSessionTensorOp : public OpKernel { 107 public: 108 explicit GetSessionTensorOp(OpKernelConstruction* context) 109 : OpKernel(context) {} 110 111 void Compute(OpKernelContext* ctx) override { 112 const Tensor& handle = ctx->input(0); 113 const string& name = handle.scalar<string>()(); 114 Tensor val; 115 OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val)); 116 ctx->set_output(0, val); 117 } 118 119 TF_DISALLOW_COPY_AND_ASSIGN(GetSessionTensorOp); 120 }; 121 122 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor").Device(DEVICE_CPU), 123 GetSessionTensorOp); 124 125 #define REGISTER_GPU_KERNEL(type) \ 126 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \ 127 .Device(DEVICE_GPU) \ 128 .HostMemory("handle") \ 129 .TypeConstraint<type>("dtype"), \ 130 GetSessionTensorOp) 131 132 TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL); 133 REGISTER_GPU_KERNEL(bool); 134 #undef REGISTER_GPU_KERNEL 135 136 #ifdef TENSORFLOW_USE_SYCL 137 #define REGISTER_SYCL_KERNEL(type) \ 138 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor") \ 139 .Device(DEVICE_SYCL) \ 140 .HostMemory("handle") \ 141 .TypeConstraint<type>("dtype"), \ 142 GetSessionTensorOp) 143 144 TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); 145 REGISTER_SYCL_KERNEL(bool); 146 #undef REGISTER_SYCL_KERNEL 147 #endif // TENSORFLOW_USE_SYCL 148 149 class DeleteSessionTensorOp : public OpKernel { 150 public: 151 explicit DeleteSessionTensorOp(OpKernelConstruction* context) 152 : OpKernel(context) {} 153 154 void Compute(OpKernelContext* ctx) override { 155 const Tensor& handle = ctx->input(0); 156 const string& name = handle.scalar<string>()(); 157 OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name)); 158 } 159 160 TF_DISALLOW_COPY_AND_ASSIGN(DeleteSessionTensorOp); 161 }; 162 163 REGISTER_KERNEL_BUILDER(Name("DeleteSessionTensor").Device(DEVICE_CPU), 164 DeleteSessionTensorOp); 165 REGISTER_KERNEL_BUILDER( 166 Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"), 167 DeleteSessionTensorOp); 168 169 #ifdef TENSORFLOW_USE_SYCL 170 REGISTER_KERNEL_BUILDER( 171 Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"), 172 DeleteSessionTensorOp); 173 #endif // TENSORFLOW_USE_SYCL 174 } // namespace tensorflow 175