Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 #include "tensorflow/core/kernels/variable_ops.h"
     18 
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/framework/register_types.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/platform/types.h"
     23 
     24 namespace tensorflow {
     25 
     26 // Resource stored by variables in the resource manager
     27 // (legacy, ref-style version).
     28 class LegacyVar : public ResourceBase {
     29  public:
     30   explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
     31   // Not copyable or movable.
     32   LegacyVar(const LegacyVar&) = delete;
     33   LegacyVar& operator=(const LegacyVar&) = delete;
     34 
     35   mutex* mu() { return &mu_; }
     36   Tensor* tensor() { return &tensor_; }
     37 
     38   string DebugString() override {
     39     return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
     40                            tensor_.shape().DebugString());
     41   }
     42 
     43  private:
     44   mutex mu_;
     45   Tensor tensor_;
     46 
     47   ~LegacyVar() override {}
     48 };
     49 
     50 VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
     51   OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
     52   dtype_ = RemoveRefType(context->output_type(0));
     53 }
     54 
     55 void VariableOp::Compute(OpKernelContext* ctx) {
     56   mutex_lock l(init_mu_);
     57   if (!initialized_) {
     58     OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
     59                                     true /* use name() */));
     60     initialized_ = true;
     61   }
     62   auto creator = [this](LegacyVar** var) {
     63     *var = new LegacyVar(dtype_);
     64     (*var)->tensor()->set_shape(shape_);
     65     return Status::OK();
     66   };
     67   LegacyVar* var;
     68   OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
     69                           cinfo_.container(), cinfo_.name(), &var, creator));
     70   // Output a reference to our tensor, so it may be updated.
     71   //
     72   // As long as the resource manager hasn't been cleared the ref we return
     73   // here is valid because it owns a ref on var.
     74   ctx->set_output_ref(0, var->mu(), var->tensor());
     75   if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
     76     AllocatorAttributes attr;
     77     attr.set_gpu_compatible(true);
     78     attr.set_nic_compatible(true);
     79     ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
     80   }
     81   var->Unref();
     82 }
     83 
     84 class TemporaryVariableOp : public OpKernel {
     85  public:
     86   explicit TemporaryVariableOp(OpKernelConstruction* context)
     87       : OpKernel(context) {
     88     OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
     89     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
     90     OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
     91     // Variable name defaults to op name if not specified explicitly.
     92     if (var_name_.empty()) var_name_ = name();
     93   }
     94 
     95   void Compute(OpKernelContext* context) override {
     96     Status s;
     97     ResourceMgr* rm = context->resource_manager();
     98     OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
     99     auto* tmp_var = new TmpVar;
    100     OP_REQUIRES(context, tmp_var,
    101                 errors::ResourceExhausted("Could not allocate TmpVar."));
    102     tmp_var->name = var_name_;
    103     s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
    104     if (!s.ok()) tmp_var->Unref();
    105     OP_REQUIRES_OK(context, s);
    106     OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
    107                                        var_name_, tmp_var));
    108     context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
    109     if (context->track_allocations()) {
    110       context->record_persistent_memory_allocation(
    111           tmp_var->val.AllocatedBytes());
    112     }
    113   }
    114 
    115  private:
    116   // Refcounted temporary variable resource.
    117   friend class DestroyTemporaryVariableOp;
    118   struct TmpVar : public ResourceBase {
    119     mutex mu;
    120     Tensor val;
    121     string name;
    122     string DebugString() override { return name; }
    123     ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
    124   };
    125 
    126   TensorShape shape_;
    127   DataType dtype_;
    128   string var_name_;
    129 };
    130 
    131 class DestroyTemporaryVariableOp : public OpKernel {
    132  public:
    133   explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
    134       : OpKernel(context) {
    135     OP_REQUIRES(context, IsRefType(context->input_type(0)),
    136                 errors::InvalidArgument("lhs input needs to be a ref type"));
    137     OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
    138     OP_REQUIRES(context, !var_name_.empty(),
    139                 errors::InvalidArgument("Missing var_name attribute"));
    140   }
    141 
    142   void Compute(OpKernelContext* context) override {
    143     // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
    144     // their execution before this DestroyTemporaryVariable op executes.
    145     // This is typically achieved using control dependencies.
    146     CHECK(IsRefType(context->input_dtype(0)));
    147     Tensor tmpvar = context->mutable_input(0, false);
    148     context->set_output(0, tmpvar);
    149     ResourceMgr* rm = context->resource_manager();
    150     OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
    151     OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
    152                                 context->step_container()->name(), var_name_));
    153     if (context->track_allocations()) {
    154       context->record_persistent_memory_allocation(
    155           -static_cast<int64>(tmpvar.AllocatedBytes()));
    156     }
    157   }
    158 
    159  private:
    160   string var_name_;
    161 };
    162 
    163 class IsVariableInitializedOp : public OpKernel {
    164  public:
    165   explicit IsVariableInitializedOp(OpKernelConstruction* context)
    166       : OpKernel(context) {}
    167 
    168   void Compute(OpKernelContext* context) override {
    169     // Get a mutable input tensor of the Ref input.
    170     const Tensor& input_tensor = context->mutable_input(0, false);
    171     Tensor* output = nullptr;
    172     OP_REQUIRES_OK(context,
    173                    context->allocate_output(0, TensorShape({}), &output));
    174     auto output_tensor = output->tensor<bool, 0>();
    175     bool result = input_tensor.IsInitialized();
    176     output_tensor() = result;
    177   }
    178 };
    179 
    180 REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
    181 REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
    182 REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
    183                         TemporaryVariableOp);
    184 REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
    185                         DestroyTemporaryVariableOp);
    186 REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
    187                         IsVariableInitializedOp);
    188 
    189 #ifdef TENSORFLOW_USE_SYCL
    190 #define REGISTER_SYCL_KERNEL(type)                                          \
    191   REGISTER_KERNEL_BUILDER(                                                  \
    192       Name("Variable").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"),   \
    193       VariableOp);                                                          \
    194   REGISTER_KERNEL_BUILDER(                                                  \
    195       Name("VariableV2").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"), \
    196       VariableOp);                                                          \
    197   REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                         \
    198                               .Device(DEVICE_SYCL)                          \
    199                               .TypeConstraint<type>("dtype"),               \
    200                           TemporaryVariableOp);                             \
    201   REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                  \
    202                               .Device(DEVICE_SYCL)                          \
    203                               .TypeConstraint<type>("T"),                   \
    204                           DestroyTemporaryVariableOp);                      \
    205   REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                     \
    206                               .Device(DEVICE_SYCL)                          \
    207                               .TypeConstraint<type>("dtype")                \
    208                               .HostMemory("is_initialized"),                \
    209                           IsVariableInitializedOp);
    210 
    211 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
    212 #undef REGISTER_SYCL_KERNEL
    213 #endif  // TENSORFLOW_USE_SYCL
    214 
    215 #if GOOGLE_CUDA
    216 // Only register 'Variable' on GPU for the subset of types also supported by
    217 // 'Assign' (see dense_update_ops.cc.)
    218 #define REGISTER_GPU_KERNELS(type)                                         \
    219   REGISTER_KERNEL_BUILDER(                                                 \
    220       Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"),   \
    221       VariableOp);                                                         \
    222   REGISTER_KERNEL_BUILDER(                                                 \
    223       Name("VariableV2").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
    224       VariableOp);                                                         \
    225   REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                        \
    226                               .Device(DEVICE_GPU)                          \
    227                               .TypeConstraint<type>("dtype"),              \
    228                           TemporaryVariableOp);                            \
    229   REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                 \
    230                               .Device(DEVICE_GPU)                          \
    231                               .TypeConstraint<type>("T"),                  \
    232                           DestroyTemporaryVariableOp);                     \
    233   REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                    \
    234                               .Device(DEVICE_GPU)                          \
    235                               .TypeConstraint<type>("dtype")               \
    236                               .HostMemory("is_initialized"),               \
    237                           IsVariableInitializedOp);
    238 
    239 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
    240 TF_CALL_int64(REGISTER_GPU_KERNELS);
    241 #undef REGISTER_GPU_KERNELS
    242 #endif  // GOOGLE_CUDA
    243 
    244 }  // namespace tensorflow
    245