/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/list_kernels.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Variant-compatible type for a list of tensors. This is mutable, but
// instances should never be mutated after being stored in a variant tensor.
TensorList::TensorList(const TensorList& other)
    : tensors(other.tensors),
      element_shape(other.element_shape),
      element_dtype(other.element_dtype) {}

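// Serializes this TensorList into `data`. The element tensors are appended to
// the variant's tensor list, and the metadata string packs the element dtype
// followed by one varint per element_shape dimension, with
// std::numeric_limits<uint64>::max() used as a sentinel for non-positive
// (unknown) dimension sizes. Decode below reverses this encoding.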
void TensorList::Encode(VariantTensorData* data) const {
  data->set_type_name(TypeName());
  for (const Tensor& t : tensors) {
    *data->add_tensors() = t;
  }
  string metadata;
  core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
  if (!element_shape.unknown_rank()) {
    for (TensorShapeDim dim : element_shape) {
      if (dim.size > 0) {
        core::PutVarint64(&metadata, dim.size);
      } else {
        core::PutVarint64(&metadata, std::numeric_limits<uint64>::max());
      }
    }
  }
  data->set_metadata(metadata);
}

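// Copies a TensorList between devices: the shape and dtype metadata are
// copied directly, and every element tensor is copied through the
// registry-provided async device copy function.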
static Status TensorListDeviceCopy(
    const TensorList& from, TensorList* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  to->element_shape = from.element_shape;
  to->element_dtype = from.element_dtype;
  to->tensors.reserve(from.tensors.size());
  for (const Tensor& t : from.tensors) {
    Tensor tmp(t.dtype());
    TF_RETURN_IF_ERROR(copy(t, &tmp));
    to->tensors.push_back(tmp);
  }
  return Status::OK();
}

#define REGISTER_LIST_COPY(DIRECTION)                   \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy)

REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

Status TensorListShape(const TensorList& t, TensorShape* s) {
  *s = TensorShape({});
  return Status::OK();
}

REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
                                      TensorListShape);

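// Deserializes a TensorList from `data`, reversing Encode above: the element
// tensors are taken verbatim, and the metadata varints are unpacked into the
// element dtype and the partially known element shape (the uint64 max
// sentinel becomes an unknown dimension of size -1).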
bool TensorList::Decode(const VariantTensorData& data) {
  tensors = data.tensors();
  string metadata;
  data.get_metadata(&metadata);
  uint64 scratch;
  StringPiece iter(metadata);
  core::GetVarint64(&iter, &scratch);
  element_dtype = static_cast<DataType>(scratch);
  std::vector<int64> dims;
  while (!iter.empty()) {
    core::GetVarint64(&iter, &scratch);
    if (scratch == std::numeric_limits<uint64>::max()) {
      dims.push_back(-1);
    } else {
      dims.push_back(scratch);
    }
  }
  // Restore the (possibly partially known) element shape from the decoded
  // dimension sizes.
  element_shape = PartialTensorShape(dims);
  return true;
}

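// Converts an int32 or int64 shape tensor into a PartialTensorShape. A scalar
// -1 is interpreted as the fully unknown shape; any other scalar is rejected.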
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) {
      return Status::OK();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}

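// Creates an empty TensorList with the requested element dtype and element
// shape and returns it as a scalar DT_VARIANT tensor allocated on the host.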
class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};

const char TensorList::kTypeName[] = "tensorflow::TensorList";

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("EmptyTensorList").Device(DEVICE_GPU).HostMemory("element_shape"),
    EmptyTensorList);

#endif  // GOOGLE_CUDA

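// Appends a tensor to the end of a list. The input list is not modified;
// a copy with the extra element is produced instead, so the variant keeps
// value semantics.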
class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
    OP_REQUIRES(c, l != nullptr,
                errors::InvalidArgument(
                    "Input handle is not a list. Saw: '",
                    c->input(0).scalar<Variant>()().DebugString(), "'"));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    TensorList output;
    output = *l;
    output.tensors.push_back(input);
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA

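// Returns the number of tensors currently stored in the list as a scalar
// int32.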
class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
    OP_REQUIRES(
        c, l != nullptr,
        errors::InvalidArgument(
            "TensorListLength received a variant which is not a list. Saw: '",
            c->input(0).scalar<Variant>()().DebugString(), "'"));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors.size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA

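// Emits the list's element shape as a rank-1 shape tensor (int32 or int64,
// depending on the allocated output dtype), with one entry per dimension and
// -1 for unknown dimensions.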
class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    OP_REQUIRES(
        c, c->input(0).shape().num_elements() == 1,
        errors::InvalidArgument("List tensors are supposed to be scalars."));
    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
    OP_REQUIRES(c, l != nullptr,
                errors::InvalidArgument(
                    "TensorListElementShape received a variant which is not a "
                    "list. Saw: '",
                    c->input(0).scalar<Variant>()().DebugString(), "'"));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(
                          0, TensorShape{l->element_shape.dims()}, &result));
    for (int i = 0; i < l->element_shape.dims(); ++i) {
      if (result->dtype() == DT_INT32) {
        result->flat<int32>()(i) = l->element_shape.dim_size(i);
      } else {
        result->flat<int64>()(i) = l->element_shape.dim_size(i);
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

#endif  // GOOGLE_CUDA

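// Removes the last tensor from a list, returning both the shortened list and
// the removed element. As with push_back, the input list itself is left
// untouched.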
class TensorListPopBack : public OpKernel {
 public:
  explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPopBack() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
    OP_REQUIRES(c, l != nullptr,
                errors::InvalidArgument(
                    "Input handle is not a list. Saw: '",
                    c->input(0).scalar<Variant>()().DebugString(), "'"));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    OP_REQUIRES(c, !l->tensors.empty(),
                errors::InvalidArgument("Trying to pop from an empty list."));

    c->set_output(1, l->tensors.back());
    TensorList output;
    output = *l;
    output.tensors.pop_back();
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPopBack").Device(DEVICE_CPU),
                        TensorListPopBack);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("TensorListPopBack").Device(DEVICE_GPU),
                        TensorListPopBack);

#endif  // GOOGLE_CUDA

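// Creates a list pre-sized to num_elements entries, each initialized to an
// invalid (DT_INVALID) placeholder tensor intended to be filled in later with
// TensorListSetItem.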
class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    int32 num_elements = c->input(1).scalar<int32>()();
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
    output.tensors.resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA

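// Reads the element at a given index out of the list. Fails if the index is
// out of range or the requested dtype does not match the list's element
// dtype.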
class TensorListGetItem : public OpKernel {
 public:
  explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    OP_REQUIRES(
        c, c->input(0).shape().num_elements() == 1,
        errors::InvalidArgument("List tensors are supposed to be scalars."));
    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
    OP_REQUIRES(c, l != nullptr,
                errors::InvalidArgument(
                    "Input handle is not a list. Saw: '",
                    c->input(0).scalar<Variant>()().DebugString(), "'"));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32 index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors.size(),
                errors::InvalidArgument("Trying to access element ", index,
                                        " in a list with ", l->tensors.size(),
                                        " elements."));
    c->set_output(0, l->tensors[index]);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListGetItem").Device(DEVICE_CPU),
                        TensorListGetItem);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("TensorListGetItem").Device(DEVICE_GPU).HostMemory("index"),
    TensorListGetItem);

#endif  // GOOGLE_CUDA

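// Overwrites the element at a given index, again producing a new list rather
// than mutating the input variant in place.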
class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
    OP_REQUIRES(c, l != nullptr,
                errors::InvalidArgument(
                    "Input handle is not a list. Saw: '",
                    c->input(0).scalar<Variant>()().DebugString(), "'"));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32 index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors.size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ", l->tensors.size(),
                                        " elements."));
    TensorList output;
    output = *l;
    output.tensors[index] = c->input(2);
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("TensorListSetItem").Device(DEVICE_GPU).HostMemory("index"),
    TensorListSetItem);

#endif  // GOOGLE_CUDA

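// CPU registrations for the templated TensorListStack and
// TensorListFromTensor kernels declared in list_kernels.h, covering the POD
// and string types plus the quantized and bfloat16 types.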
#define REGISTER_TENSOR_LIST_STACK_CPU(T)                         \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                 \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU),                \
                          TensorListStack<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_STACK_CPU);
REGISTER_TENSOR_LIST_STACK_CPU(quint8);
REGISTER_TENSOR_LIST_STACK_CPU(qint8);
REGISTER_TENSOR_LIST_STACK_CPU(quint16);
REGISTER_TENSOR_LIST_STACK_CPU(qint16);
REGISTER_TENSOR_LIST_STACK_CPU(qint32);
REGISTER_TENSOR_LIST_STACK_CPU(bfloat16);

#undef REGISTER_TENSOR_LIST_STACK_CPU

#define REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(T)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")            \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU),                \
                          TensorListFromTensor<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_CPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint8);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(qint8);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint16);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(qint16);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(qint32);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16);

#undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList, TensorList::kTypeName,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorList::kTypeName,
                                         TensorListZerosLike<CPUDevice>);

}  // namespace tensorflow