/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/variable_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Resource stored by variables in the resource manager
// (legacy, ref-style version).
class LegacyVar : public ResourceBase {
 public:
  explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
  // Not copyable or movable.
  LegacyVar(const LegacyVar&) = delete;
  LegacyVar& operator=(const LegacyVar&) = delete;

  mutex* mu() { return &mu_; }
  Tensor* tensor() { return &tensor_; }

  string DebugString() override {
    return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
                           tensor_.shape().DebugString());
  }

 private:
  mutex mu_;
  Tensor tensor_;

  ~LegacyVar() override {}
};

VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
  dtype_ = RemoveRefType(context->output_type(0));
}

void VariableOp::Compute(OpKernelContext* ctx) {
  mutex_lock l(init_mu_);
  if (!initialized_) {
    OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
                                    true /* use name() */));
    initialized_ = true;
  }
  auto creator = [this](LegacyVar** var) {
    *var = new LegacyVar(dtype_);
    (*var)->tensor()->set_shape(shape_);
    return Status::OK();
  };
  LegacyVar* var;
  OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
                          cinfo_.container(), cinfo_.name(), &var, creator));
  // Output a reference to our tensor, so it may be updated.
  //
  // As long as the resource manager hasn't been cleared the ref we return
  // here is valid because it owns a ref on var.
  ctx->set_output_ref(0, var->mu(), var->tensor());
  if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
  }
  var->Unref();
}

class TemporaryVariableOp : public OpKernel {
 public:
  explicit TemporaryVariableOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
    // Variable name defaults to op name if not specified explicitly.
    if (var_name_.empty()) var_name_ = name();
  }

  void Compute(OpKernelContext* context) override {
    Status s;
    ResourceMgr* rm = context->resource_manager();
    OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
    auto* tmp_var = new TmpVar;
    OP_REQUIRES(context, tmp_var,
                errors::ResourceExhausted("Could not allocate TmpVar."));
    tmp_var->name = var_name_;
    s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
    // Release our reference if allocation failed so the resource is not
    // leaked before the error is reported.
    if (!s.ok()) tmp_var->Unref();
    OP_REQUIRES_OK(context, s);
    // Register the variable in the per-step container so it is cleaned up
    // automatically when the step finishes.
    OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
                                       var_name_, tmp_var));
    context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
    if (context->track_allocations()) {
      context->record_persistent_memory_allocation(
          tmp_var->val.AllocatedBytes());
    }
  }

 private:
  // Refcounted temporary variable resource.
  friend class DestroyTemporaryVariableOp;
  struct TmpVar : public ResourceBase {
    mutex mu;
    Tensor val;
    string name;
    string DebugString() override { return name; }
    ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
  };

  TensorShape shape_;
  DataType dtype_;
  string var_name_;
};

class DestroyTemporaryVariableOp : public OpKernel {
 public:
  explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES(context, IsRefType(context->input_type(0)),
                errors::InvalidArgument("lhs input needs to be a ref type"));
    OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
    OP_REQUIRES(context, !var_name_.empty(),
                errors::InvalidArgument("Missing var_name attribute"));
  }

  void Compute(OpKernelContext* context) override {
    // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
    // their execution before this DestroyTemporaryVariable op executes.
    // This is typically achieved using control dependencies.
    CHECK(IsRefType(context->input_dtype(0)));
    Tensor tmpvar = context->mutable_input(0, false);
    context->set_output(0, tmpvar);
    ResourceMgr* rm = context->resource_manager();
    OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
    OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
                                context->step_container()->name(), var_name_));
    if (context->track_allocations()) {
      context->record_persistent_memory_allocation(
          -static_cast<int64>(tmpvar.AllocatedBytes()));
    }
  }

 private:
  string var_name_;
};

class IsVariableInitializedOp : public OpKernel {
 public:
  explicit IsVariableInitializedOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Get a mutable input tensor of the Ref input.
    const Tensor& input_tensor = context->mutable_input(0, false);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    auto output_tensor = output->tensor<bool, 0>();
    bool result = input_tensor.IsInitialized();
    output_tensor() = result;
  }
};

REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
                        TemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
                        DestroyTemporaryVariableOp);
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
                        IsVariableInitializedOp);

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                                          \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("Variable").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                          \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("VariableV2").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"), \
      VariableOp);                                                          \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                         \
                              .Device(DEVICE_SYCL)                          \
                              .TypeConstraint<type>("dtype"),               \
                          TemporaryVariableOp);                             \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                  \
                              .Device(DEVICE_SYCL)                          \
                              .TypeConstraint<type>("T"),                   \
                          DestroyTemporaryVariableOp);                      \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                     \
                              .Device(DEVICE_SYCL)                          \
                              .TypeConstraint<type>("dtype")                \
                              .HostMemory("is_initialized"),                \
                          IsVariableInitializedOp);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA
// Only register 'Variable' on GPU for the subset of types also supported by
// 'Assign' (see dense_update_ops.cc.)
#define REGISTER_GPU_KERNELS(type)                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"),   \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("VariableV2").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
      VariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(Name("TemporaryVariable")                        \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype"),              \
                          TemporaryVariableOp);                            \
  REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable")                 \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T"),                  \
                          DestroyTemporaryVariableOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized")                    \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("dtype")               \
                              .HostMemory("is_initialized"),               \
                          IsVariableInitializedOp);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

}  // namespace tensorflow
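For reference, the snippet below is a minimal sketch of the ResourceMgr pattern that VariableOp::Compute relies on above: LookupOrCreate with a creator callback, followed by Unref once the caller is done with its reference. It is not part of this file; the MyCounter resource type, the BumpCounter function, and the container/name strings are hypothetical, and it assumes the same TensorFlow headers and ref-counted ResourceBase interface used by the kernels above.

#include <string>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace example {

// Hypothetical refcounted resource, analogous to LegacyVar/TmpVar above.
class MyCounter : public tensorflow::ResourceBase {
 public:
  std::string DebugString() override { return "MyCounter"; }
  tensorflow::int64 value = 0;
};

// Looks up (or lazily creates) the counter in `rm`, bumps it, and releases
// the reference this caller owns, mirroring var->Unref() in
// VariableOp::Compute.
tensorflow::Status BumpCounter(tensorflow::ResourceMgr* rm) {
  MyCounter* counter = nullptr;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<MyCounter>(
      "example_container", "example_counter", &counter,
      [](MyCounter** c) {
        *c = new MyCounter;
        return tensorflow::Status::OK();
      }));
  counter->value += 1;
  counter->Unref();
  return tensorflow::Status::OK();
}

}  // namespace example

The ResourceMgr retains its own reference to the created resource, which is why both this sketch and VariableOp can safely Unref after handing out the tensor ref; the resource stays alive until the manager (or, for TemporaryVariable, the per-step container) is cleared.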