Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/data_flow_ops.cc.
     17 
     18 #include <limits.h>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/common_runtime/device.h"
     22 #include "tensorflow/core/framework/device_base.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/register_types.h"
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 #include "tensorflow/core/framework/types.h"
     28 #include "tensorflow/core/lib/core/errors.h"
     29 #include "tensorflow/core/lib/gtl/map_util.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/platform/macros.h"
     32 #include "tensorflow/core/platform/mutex.h"
     33 #include "tensorflow/core/platform/thread_annotations.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace tensorflow {
     37 
     38 class GetSessionHandleOp : public OpKernel {
     39  public:
     40   explicit GetSessionHandleOp(OpKernelConstruction* context)
     41       : OpKernel(context) {}
     42 
     43   void Compute(OpKernelContext* ctx) override {
     44     const Tensor& val = ctx->input(0);
     45     int64 id = ctx->session_state()->GetNewId();
     46     TensorStore::TensorAndKey tk{val, id, requested_device()};
     47     OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
     48 
     49     Tensor* handle = nullptr;
     50     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
     51     if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
     52       ResourceHandle resource_handle = MakeResourceHandle<Tensor>(
     53           ctx, SessionState::kTensorHandleResourceTypeName,
     54           tk.GetHandle(name()));
     55       resource_handle.set_maybe_type_name(
     56           SessionState::kTensorHandleResourceTypeName);
     57       handle->scalar<ResourceHandle>()() = resource_handle;
     58     } else {
     59       // Legacy behavior in V1.
     60       handle->flat<string>().setConstant(tk.GetHandle(name()));
     61     }
     62   }
     63 
     64   TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp);
     65 };
     66 
     67 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU),
     68                         GetSessionHandleOp);
     69 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU),
     70                         GetSessionHandleOp);
     71 
     72 #define REGISTER_GPU_KERNEL(type)                         \
     73   REGISTER_KERNEL_BUILDER(Name("GetSessionHandle")        \
     74                               .Device(DEVICE_GPU)         \
     75                               .HostMemory("handle")       \
     76                               .TypeConstraint<type>("T"), \
     77                           GetSessionHandleOp)             \
     78   REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2")      \
     79                               .Device(DEVICE_GPU)         \
     80                               .HostMemory("handle")       \
     81                               .TypeConstraint<type>("T"), \
     82                           GetSessionHandleOp)
     83 
     84 TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
     85 REGISTER_GPU_KERNEL(bool);
     86 #undef REGISTER_GPU_KERNEL
     87 
     88 #ifdef TENSORFLOW_USE_SYCL
     89 #define REGISTER_SYCL_KERNEL(type)                        \
     90   REGISTER_KERNEL_BUILDER(Name("GetSessionHandle")        \
     91                               .Device(DEVICE_SYCL)        \
     92                               .HostMemory("handle")       \
     93                               .TypeConstraint<type>("T"), \
     94                           GetSessionHandleOp)             \
     95   REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2")      \
     96                               .Device(DEVICE_SYCL)        \
     97                               .HostMemory("handle")       \
     98                               .TypeConstraint<type>("T"), \
     99                           GetSessionHandleOp)
    100 
    101 TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
    102 REGISTER_SYCL_KERNEL(bool);
    103 #undef REGISTER_SYCL_KERNEL
    104 #endif  // TENSORFLOW_USE_SYCL
    105 
    106 class GetSessionTensorOp : public OpKernel {
    107  public:
    108   explicit GetSessionTensorOp(OpKernelConstruction* context)
    109       : OpKernel(context) {}
    110 
    111   void Compute(OpKernelContext* ctx) override {
    112     const Tensor& handle = ctx->input(0);
    113     const string& name = handle.scalar<string>()();
    114     Tensor val;
    115     OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val));
    116     ctx->set_output(0, val);
    117   }
    118 
    119   TF_DISALLOW_COPY_AND_ASSIGN(GetSessionTensorOp);
    120 };
    121 
    122 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor").Device(DEVICE_CPU),
    123                         GetSessionTensorOp);
    124 
    125 #define REGISTER_GPU_KERNEL(type)                             \
    126   REGISTER_KERNEL_BUILDER(Name("GetSessionTensor")            \
    127                               .Device(DEVICE_GPU)             \
    128                               .HostMemory("handle")           \
    129                               .TypeConstraint<type>("dtype"), \
    130                           GetSessionTensorOp)
    131 
    132 TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
    133 REGISTER_GPU_KERNEL(bool);
    134 #undef REGISTER_GPU_KERNEL
    135 
    136 #ifdef TENSORFLOW_USE_SYCL
    137 #define REGISTER_SYCL_KERNEL(type)                            \
    138   REGISTER_KERNEL_BUILDER(Name("GetSessionTensor")            \
    139                               .Device(DEVICE_SYCL)            \
    140                               .HostMemory("handle")           \
    141                               .TypeConstraint<type>("dtype"), \
    142                           GetSessionTensorOp)
    143 
    144 TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
    145 REGISTER_SYCL_KERNEL(bool);
    146 #undef REGISTER_SYCL_KERNEL
    147 #endif  // TENSORFLOW_USE_SYCL
    148 
    149 class DeleteSessionTensorOp : public OpKernel {
    150  public:
    151   explicit DeleteSessionTensorOp(OpKernelConstruction* context)
    152       : OpKernel(context) {}
    153 
    154   void Compute(OpKernelContext* ctx) override {
    155     const Tensor& handle = ctx->input(0);
    156     const string& name = handle.scalar<string>()();
    157     OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name));
    158   }
    159 
    160   TF_DISALLOW_COPY_AND_ASSIGN(DeleteSessionTensorOp);
    161 };
    162 
    163 REGISTER_KERNEL_BUILDER(Name("DeleteSessionTensor").Device(DEVICE_CPU),
    164                         DeleteSessionTensorOp);
    165 REGISTER_KERNEL_BUILDER(
    166     Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"),
    167     DeleteSessionTensorOp);
    168 
    169 #ifdef TENSORFLOW_USE_SYCL
    170 REGISTER_KERNEL_BUILDER(
    171     Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"),
    172     DeleteSessionTensorOp);
    173 #endif  // TENSORFLOW_USE_SYCL
    174 }  // namespace tensorflow
    175