/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// Variant-compatible type for a list of tensors. This is mutable, but
// instances should never be mutated after being stored in a variant tensor.
struct TensorList {
 public:
  TensorList() {}
  TensorList(const TensorList& other);

  static const char kTypeName[];
  string TypeName() const { return kTypeName; }

  void Encode(VariantTensorData* data) const;

  bool Decode(const VariantTensorData& data);

  // TODO(apassos) fill this out
  string DebugString() const { return "TensorList"; }

  std::vector<Tensor> tensors;
  PartialTensorShape element_shape;
  DataType element_dtype;
};

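// Converts a shape-describing tensor `t` into a PartialTensorShape for list
// elements (defined in the corresponding .cc file).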
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);

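// Stacks the tensors stored in a TensorList into a single tensor of shape
// [list_length] + element_shape. The list's element_shape must be fully
// defined and its element_dtype must match the op's element_dtype attribute.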
template <typename Device, typename T>
class TensorListStack : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
  }

  ~TensorListStack() {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
    OP_REQUIRES(c, l != nullptr,
                errors::InvalidArgument(
                    "Input handle is not a list. Saw: '",
                    c->input(0).scalar<Variant>()().DebugString(), "'"));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
                errors::InvalidArgument("Tried to stack elements from a list "
                                        "with non-fully-defined shape."));
    if (num_elements_ != -1) {
      OP_REQUIRES(c, l->tensors.size() == num_elements_,
                  errors::InvalidArgument("Operation expected a list with ",
                                          num_elements_,
                                          " elements but got a list with ",
                                          l->tensors.size(), " elements."));
    }
    TensorShape resulting_shape;
    resulting_shape.AddDim(l->tensors.size());
    for (TensorShapeDim s : l->element_shape) {
      resulting_shape.AddDim(s.size);
    }
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(l->tensors.size());
    for (const auto& t : l->tensors) {
      OP_REQUIRES(
          c, l->element_shape.IsCompatibleWith(t.shape()),
          errors::InvalidArgument(
              "Tensor with invalid shape in list. List element shape: ",
              l->element_shape.DebugString(),
              " and tensor shape: ", t.shape().DebugString()));
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          t.shaped<T, 2>({1, t.NumElements()})));
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA
    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
  }

 private:
  int num_elements_;
  DataType element_dtype_;
};

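// Unstacks a tensor along its first dimension into a TensorList: element i of
// the list is input[i, ...]. Unaligned slices (for dtypes where alignment
// matters) are copied into aligned temporaries.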
template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
 public:
  explicit TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    TensorList output_list;
    const Tensor& t = c->input(0);
    output_list.element_dtype = t.dtype();
    TensorShape output_shape(t.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;
    output_list.tensors.reserve(t.shape().dim_size(0));
    for (int i = 0; i < t.shape().dim_size(0); ++i) {
      Tensor tmp = t.Slice(i, i + 1);
      TensorShape tmp_shape = tmp.shape();
      tmp_shape.RemoveDim(0);
      OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
                  errors::Unknown("Unexpected shape error."));
      if (tmp.IsAligned() || !DataTypeCanUseMemcpy(DataTypeToEnum<T>::value)) {
        output_list.tensors.push_back(tmp);
      } else {
        // The slice is unaligned; copy it into an aligned temporary tensor.
        Tensor aligned;
        OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
        aligned.flat<T>().device(c->eigen_device<Device>()) =
            tmp.unaligned_flat<T>();
        output_list.tensors.push_back(aligned);
      }
    }
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

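// Adds two TensorLists elementwise: entry i of the output is
// a.tensors[i] + b.tensors[i]. The lists must have the same length and dtype
// and compatible element shapes; DT_INVALID entries are treated as zeros.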
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
                           const TensorList& b, TensorList* out) {
  if (a.element_dtype != b.element_dtype) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors of different dtypes. One is ",
        DataTypeString(a.element_dtype), " and the other is ",
        DataTypeString(b.element_dtype));
  }
  out->element_dtype = a.element_dtype;
  if (!a.element_shape.IsCompatibleWith(b.element_shape)) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors with incompatible element shapes. "
        "One is ",
        a.element_shape.DebugString(), " and the other is ",
        b.element_shape.DebugString());
  }

  TF_RETURN_IF_ERROR(
      a.element_shape.MergeWith(b.element_shape, &out->element_shape));
  if (a.tensors.size() != b.tensors.size()) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors with different lengths. One is ",
        a.tensors.size(), " and the other is ", b.tensors.size());
  }
  out->tensors.reserve(a.tensors.size());
  for (int i = 0; i < a.tensors.size(); ++i) {
    const Tensor& a_tensor = a.tensors[i];
    const Tensor& b_tensor = b.tensors[i];
    // An uninitialized (DT_INVALID) entry acts as an additive identity.
    if (a_tensor.dtype() == DT_INVALID) {
      out->tensors.push_back(b_tensor);
      continue;
    }
    if (b_tensor.dtype() == DT_INVALID) {
      out->tensors.push_back(a_tensor);
      continue;
    }
    if (a_tensor.shape() != b_tensor.shape()) {
      // TODO(apassos) support broadcasting additions here?
      return errors::InvalidArgument(
          "Trying to add two tensors with incompatible element shapes. "
          "One is ",
          a_tensor.shape().DebugString(), " and the other is ",
          b_tensor.shape().DebugString(), " in position ", i);
    }
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(
        c->allocate_temp(a_tensor.dtype(), a_tensor.shape(), &out_tensor));
    out->tensors.push_back(out_tensor);
    switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype)                                        \
  case DataTypeToEnum<dtype>::value:                             \
    out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
        a_tensor.flat<dtype>() + b_tensor.flat<dtype>();         \
    break;

      TF_CALL_NUMBER_TYPES(DTYPE_CASE)

#undef DTYPE_CASE
      default:
        return errors::InvalidArgument("Trying to add unsupported dtype ",
                                       DataTypeString(out_tensor.dtype()));
    }
  }
  return Status::OK();
}

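// Builds a TensorList with the same length, dtype, and element shape as `x`
// whose entries are all-zero tensors of the corresponding shapes.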
template <typename Device>
Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
                           TensorList* y) {
  y->element_dtype = x.element_dtype;
  y->element_shape = x.element_shape;
  y->tensors.reserve(x.tensors.size());
  for (const Tensor& t : x.tensors) {
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor));
    switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype)                                        \
  case DataTypeToEnum<dtype>::value:                             \
    out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
        out_tensor.flat<dtype>().constant(dtype(0));             \
    break;

      TF_CALL_NUMBER_TYPES(DTYPE_CASE)

#undef DTYPE_CASE
      default:
        return errors::InvalidArgument(
            "Trying to compute zeros_like for unsupported dtype ",
            DataTypeString(out_tensor.dtype()));
    }
    // Keep the zero-filled tensor in the output list.
    y->tensors.push_back(out_tensor);
  }
  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_