// tensorflow/core/kernels/assign_op.h
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_ASSIGN_OP_H_
     17 #define TENSORFLOW_KERNELS_ASSIGN_OP_H_
     18 
     19 #define EIGEN_USE_THREADS
     20 
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/tensor_types.h"
     24 
     25 namespace tensorflow {
     26 
     27 // TODO(jeff): Get rid of use_exclusive_lock_ option
     28 
     29 // Computes *input[0] = input[1]
// Computes *input[0] = input[1].
//
// Abstract base kernel for the "Assign" op family. Subclasses provide the
// device-specific Copy() that moves the rhs data into the lhs buffer; this
// class handles the ref bookkeeping, locking, shape validation, and buffer
// reuse/allocation logic that is common to all devices.
class AssignOp : public OpKernel {
 public:
  explicit AssignOp(OpKernelConstruction* context) : OpKernel(context) {
    // "use_locking": when true, the data copy into the lhs buffer is done
    // while still holding the variable's ref mutex (exclusive access).
    OP_REQUIRES_OK(context,
                   context->GetAttr("use_locking", &use_exclusive_lock_));
    // "validate_shape": when true, assignment requires the lhs and rhs
    // shapes to match exactly; when false, the lhs may be reshaped or
    // reallocated to take on the rhs shape.
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_shape", &validate_shape_));
    // Input 0 is the variable being assigned and must be a reference type.
    OP_REQUIRES(context, IsRefType(context->input_type(0)),
                errors::InvalidArgument("lhs input needs to be a ref type"));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& rhs = context->input(1);

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // We can't always know how this value will be used downstream,
    // so make conservative assumptions in specifying constraints on
    // the memory allocation attributes.
    // TODO(rmlarsen): These conservative constraints make buffer
    // forwarding unlikely to happen very often. Try to use graph analysis
    // (possibly the InferAllocAttr pass in the executer) to improve the
    // situation.
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);

    {
      // Hold the ref mutex while inspecting and/or replacing the lhs tensor
      // so the variable cannot change underneath us.
      mutex_lock l(*context->input_ref_mutex(0));
      const Tensor& old_lhs = context->mutable_input(0, /* lock_held */ true);
      const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
      if (validate_shape_) {
        OP_REQUIRES(
            context, same_shape,
            errors::InvalidArgument(
                "Assign requires shapes of both tensors to match. lhs shape= ",
                old_lhs.shape().DebugString(),
                " rhs shape= ", rhs.shape().DebugString()));
      }

      // In the code below we try to minimize the amount of memory allocation
      // and copying by trying the following two shortcuts:
      // 1. If we can reuse the rhs buffer we avoid both a memory allocation
      //    and copying.
      // 2. If the lhs is initialized and has the same number of elements as
      //    the rhs we can avoid a memory allocation.

      // 1. Try to reuse the rhs: if the rhs buffer is exclusively owned by
      // this op it can be stolen and pointed at by the ref directly, with no
      // allocation and no copy.
      std::unique_ptr<Tensor> input_alias = context->forward_input(
          1, old_lhs.dtype(), old_lhs.shape(), DEVICE_MEMORY, attr);
      if (input_alias != nullptr) {
        // Transfer ownership to the ref.
        context->replace_ref_input(0, *input_alias.release(),
                                   /* lock_held */ true);
        return;
      }

      // 2. Try to copy into an existing buffer.
      if (old_lhs.IsInitialized() &&
          old_lhs.shape().num_elements() == rhs.shape().num_elements()) {
        // The existing lhs tensor has already been initialized and the right
        // hand side can fit in the underlying buffer.
        Tensor reshaped_old_lhs;
        if (same_shape) {
          reshaped_old_lhs = old_lhs;
        } else {
          // Same element count but a different shape: point the ref at a
          // view of the same buffer carrying the rhs shape.
          CHECK(reshaped_old_lhs.CopyFrom(old_lhs, rhs.shape()));
          context->replace_ref_input(0, reshaped_old_lhs, /* lock_held */ true);
        }
        if (use_exclusive_lock_) {
          // Perform the data copy while still holding the ref mutex.
          Copy(context, &reshaped_old_lhs, rhs);
          return;
        }
      } else {
        // Create a new persistent tensor whose shape matches the right hand
        // side, hand off to lhs and copy the rhs into it.
        PersistentTensor copy;
        Tensor* copyTensor = nullptr;
        OP_REQUIRES_OK(
            context, context->allocate_persistent(old_lhs.dtype(), rhs.shape(),
                                                  &copy, &copyTensor, attr));
        // We track memory of variables in variable ops instead of in this
        // assign op.
        context->clear_recorded_memory();
        context->replace_ref_input(0, *copyTensor, /* lock_held */ true);
        if (use_exclusive_lock_) {
          Copy(context, copyTensor, rhs);
          return;
        }
      }
    }  // Ref mutex released here.

    // use_exclusive_lock_ is false: the branches above have already pointed
    // the ref at a buffer whose shape matches the rhs, so we have been told
    // to do the copy outside the lock.
    Tensor old_unlocked_lhs = context->mutable_input(0, /* lock_held */ false);
    Copy(context, &old_unlocked_lhs, rhs);
  }

  // Device-specific data copy: copies rhs's contents into lhs's buffer.
  // Implemented by per-device subclasses (CPU/GPU/etc.).
  virtual void Copy(OpKernelContext* context, Tensor* lhs,
                    const Tensor& rhs) = 0;

  bool use_exclusive_lock_;  // Value of the "use_locking" attr.
  bool validate_shape_;      // Value of the "validate_shape" attr.
};
    136 
    137 }  // end namespace tensorflow
    138 
    139 #endif  // TENSORFLOW_KERNELS_ASSIGN_OP_H_
    140