/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_ASSIGN_OP_H_
#define TENSORFLOW_KERNELS_ASSIGN_OP_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

// TODO(jeff): Get rid of use_exclusive_lock_ option

// Computes *input[0] = input[1]
//
// Abstract base kernel for Assign: input 0 is a ref to the variable being
// assigned, input 1 is the value to assign to it. The input ref is forwarded
// as the output. Device-specific subclasses implement Copy() to perform the
// actual element transfer.
class AssignOp : public OpKernel {
 public:
  explicit AssignOp(OpKernelConstruction* context) : OpKernel(context) {
    // "use_locking": when true, the Copy() into the variable is performed
    // while the variable's ref mutex is held; otherwise the copy happens
    // after the mutex is released (see the tail of Compute()).
    OP_REQUIRES_OK(context,
                   context->GetAttr("use_locking", &use_exclusive_lock_));
    // "validate_shape": when true, lhs and rhs shapes must match exactly.
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_shape", &validate_shape_));
    OP_REQUIRES(context, IsRefType(context->input_type(0)),
                errors::InvalidArgument("lhs input needs to be a ref type"));
  }

  // Assigns input 1 (rhs) to the variable referenced by input 0 (lhs),
  // trying to avoid allocation/copying where possible (see the numbered
  // shortcuts below). Forwards the ref as output 0.
  void Compute(OpKernelContext* context) override {
    const Tensor& rhs = context->input(1);

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // We can't always know how this value will be used downstream,
    // so make conservative assumptions in specifying constraints on
    // the memory allocation attributes.
    // TODO(rmlarsen): These conservative constraints make buffer
    // forwarding unlikely to happen very often. Try to use graph analysis
    // (possibly the InferAllocAttr pass in the executer) to improve the
    // situation.
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);

    {
      // Hold the variable's ref mutex while inspecting and/or replacing the
      // underlying buffer. All replace_ref_input() calls below happen inside
      // this scope, hence /* lock_held */ true.
      mutex_lock l(*context->input_ref_mutex(0));
      const Tensor& old_lhs = context->mutable_input(0, /* lock_held */ true);
      const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
      if (validate_shape_) {
        OP_REQUIRES(
            context, same_shape,
            errors::InvalidArgument(
                "Assign requires shapes of both tensors to match. lhs shape= ",
                old_lhs.shape().DebugString(),
                " rhs shape= ", rhs.shape().DebugString()));
      }

      // In the code below we try to minimize the amount of memory allocation
      // and copying by trying the following two shortcuts:
      // 1. If we can reuse the rhs buffer we avoid both a memory allocation
      // and copying.
      // 2. If the lhs is initialized and has the same number of elements as the
      // rhs we can avoid a memory allocation.

      // 1. Try to reuse the rhs.
      std::unique_ptr<Tensor> input_alias = context->forward_input(
          1, old_lhs.dtype(), old_lhs.shape(), DEVICE_MEMORY, attr);
      if (input_alias != nullptr) {
        // Transfer ownership to the ref.
        context->replace_ref_input(0, *input_alias.release(),
                                   /* lock_held */ true);
        // No copy needed at all: the variable now aliases the rhs buffer.
        return;
      }

      // 2. Try to copy into an existing buffer.
      if (old_lhs.IsInitialized() &&
          old_lhs.shape().num_elements() == rhs.shape().num_elements()) {
        // The existing lhs tensor has already been initialized and the right
        // hand side can fit in the underlying buffer.
        Tensor reshaped_old_lhs;
        if (same_shape) {
          reshaped_old_lhs = old_lhs;
        } else {
          // Same element count but different shape (only reachable when
          // validate_shape_ is false): re-point the ref at the existing
          // buffer viewed with rhs's shape — no new allocation.
          CHECK(reshaped_old_lhs.CopyFrom(old_lhs, rhs.shape()));
          context->replace_ref_input(0, reshaped_old_lhs, /* lock_held */ true);
        }
        if (use_exclusive_lock_) {
          // Copy while still holding the ref mutex, then we're done.
          Copy(context, &reshaped_old_lhs, rhs);
          return;
        }
        // !use_exclusive_lock_: fall through; the copy happens below,
        // outside the mutex scope.
      } else {
        // Create a new persistent tensor whose shape matches the right hand
        // side, hand off to lhs and copy the rhs into it.
        PersistentTensor copy;
        Tensor* copyTensor = nullptr;
        OP_REQUIRES_OK(
            context, context->allocate_persistent(old_lhs.dtype(), rhs.shape(),
                                                  &copy, &copyTensor, attr));
        // We track memory of variables in variable ops instead of in this
        // assign op.
        context->clear_recorded_memory();
        context->replace_ref_input(0, *copyTensor, /* lock_held */ true);
        if (use_exclusive_lock_) {
          Copy(context, copyTensor, rhs);
          return;
        }
        // !use_exclusive_lock_: fall through to the unlocked copy below.
      }
    }  // ref mutex released here.

    // The tensor has already been initialized and the right hand side
    // matches the left hand side's shape. We have been told to do the
    // copy outside the lock.
    Tensor old_unlocked_lhs = context->mutable_input(0, /* lock_held */ false);
    Copy(context, &old_unlocked_lhs, rhs);
  }

  // Device-specific element copy of rhs into *lhs. By the time this is
  // called, every path above has arranged for lhs to hold a buffer with the
  // same number of elements as rhs (shaped like rhs, or identically shaped
  // when same_shape held). May run with or without the ref mutex held,
  // depending on use_exclusive_lock_.
  virtual void Copy(OpKernelContext* context, Tensor* lhs,
                    const Tensor& rhs) = 0;

  bool use_exclusive_lock_;  // value of the "use_locking" attr
  bool validate_shape_;      // value of the "validate_shape" attr
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_ASSIGN_OP_H_