1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <limits> 17 18 #define EIGEN_USE_THREADS 19 #if GOOGLE_CUDA 20 #define EIGEN_USE_GPU 21 #endif // GOOGLE_CUDA 22 23 #include "tensorflow/core/kernels/list_kernels.h" 24 25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/register_types.h" 28 #include "tensorflow/core/framework/tensor_types.h" 29 #include "tensorflow/core/framework/variant.h" 30 #include "tensorflow/core/framework/variant_op_registry.h" 31 #include "tensorflow/core/kernels/concat_lib.h" 32 #include "tensorflow/core/lib/core/coding.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/util/util.h" 35 36 namespace tensorflow { 37 38 typedef Eigen::ThreadPoolDevice CPUDevice; 39 40 // Variant compatible type for a list of tensors. This is mutable but instances 41 // should never be mutated after stored in a variant tensor. 42 TensorList::TensorList(const TensorList& other) 43 : tensors(other.tensors), 44 element_shape(other.element_shape), 45 element_dtype(other.element_dtype) {} 46 47 void TensorList::Encode(VariantTensorData* data) const { 48 data->set_type_name(TypeName()); 49 for (const Tensor& t : tensors) { 50 *data->add_tensors() = t; 51 } 52 string metadata; 53 core::PutVarint64(&metadata, static_cast<uint64>(element_dtype)); 54 if (!element_shape.unknown_rank()) { 55 for (TensorShapeDim dim : element_shape) { 56 if (dim.size > 0) { 57 core::PutVarint64(&metadata, dim.size); 58 } else { 59 core::PutVarint64(&metadata, std::numeric_limits<uint64>::max()); 60 } 61 } 62 } 63 data->set_metadata(metadata); 64 } 65 66 static Status TensorListDeviceCopy( 67 const TensorList& from, TensorList* to, 68 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { 69 to->element_shape = from.element_shape; 70 to->element_dtype = from.element_dtype; 71 to->tensors.reserve(from.tensors.size()); 72 for (const Tensor& t : from.tensors) { 73 Tensor tmp(t.dtype()); 74 TF_RETURN_IF_ERROR(copy(t, &tmp)); 75 to->tensors.push_back(tmp); 76 } 77 return Status::OK(); 78 } 79 80 #define REGISTER_LIST_COPY(DIRECTION) \ 81 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ 82 TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy) 83 84 REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); 85 REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); 86 REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE); 87 88 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName); 89 90 Status TensorListShape(const TensorList& t, TensorShape* s) { 91 *s = TensorShape({}); 92 return Status::OK(); 93 } 94 95 REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName, 96 TensorListShape); 97 98 bool TensorList::Decode(const VariantTensorData& data) { 99 tensors = data.tensors(); 100 string metadata; 101 data.get_metadata(&metadata); 102 uint64 scratch; 103 StringPiece iter(metadata); 104 core::GetVarint64(&iter, &scratch); 105 element_dtype = static_cast<DataType>(scratch); 106 std::vector<int64> dims; 107 while (!iter.empty()) { 108 core::GetVarint64(&iter, &scratch); 109 if (scratch == std::numeric_limits<uint64>::max()) { 110 dims.push_back(-1); 111 } else { 112 dims.push_back(scratch); 113 } 114 } 115 return true; 116 } 117 118 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) { 119 if (t.shape() == TensorShape({})) { 120 if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) || 121 (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) { 122 return Status::OK(); 123 } 124 return errors::InvalidArgument( 125 "The only valid scalar shape tensor is the fully unknown shape " 126 "specified as -1."); 127 } 128 if (t.dtype() == DT_INT32) { 129 return PartialTensorShape::MakePartialShape(t.vec<int32>().data(), 130 t.NumElements(), out); 131 } else if (t.dtype() == DT_INT64) { 132 return PartialTensorShape::MakePartialShape(t.vec<int64>().data(), 133 t.NumElements(), out); 134 } 135 return errors::InvalidArgument( 136 "Expected an int32 or int64 shape tensor; found ", 137 DataTypeString(t.dtype())); 138 } 139 140 class EmptyTensorList : public OpKernel { 141 public: 142 explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) { 143 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_)); 144 } 145 146 void Compute(OpKernelContext* ctx) override { 147 Tensor* result; 148 AllocatorAttributes attr; 149 attr.set_on_host(true); 150 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr)); 151 TensorList empty; 152 empty.element_dtype = element_dtype_; 153 PartialTensorShape element_shape; 154 OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape)); 155 empty.element_shape = element_shape; 156 result->scalar<Variant>()() = std::move(empty); 157 } 158 159 private: 160 DataType element_dtype_; 161 }; 162 163 const char TensorList::kTypeName[] = "tensorflow::TensorList"; 164 165 REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU), 166 EmptyTensorList); 167 168 #if GOOGLE_CUDA 169 170 REGISTER_KERNEL_BUILDER( 171 Name("EmptyTensorList").Device(DEVICE_GPU).HostMemory("element_shape"), 172 EmptyTensorList); 173 174 #endif // GOOGLE_CUDA 175 176 class TensorListPushBack : public OpKernel { 177 public: 178 explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) { 179 OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); 180 } 181 182 ~TensorListPushBack() override {} 183 184 void Compute(OpKernelContext* c) override { 185 const Tensor& input = c->input(1); 186 OP_REQUIRES(c, element_dtype_ == input.dtype(), 187 errors::InvalidArgument("Invalid data types; list elements ", 188 DataTypeString(element_dtype_), 189 " but tried to append ", 190 DataTypeString(input.dtype()))); 191 192 const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); 193 OP_REQUIRES(c, l != nullptr, 194 errors::InvalidArgument( 195 "Input handle is not a list. Saw: '", 196 c->input(0).scalar<Variant>()().DebugString(), "'")); 197 OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()), 198 errors::InvalidArgument( 199 "Tried to append a tensor with incompatible shape to a " 200 "list. Op element shape: ", 201 input.shape().DebugString(), 202 " list shape: ", l->element_shape.DebugString())); 203 OP_REQUIRES(c, element_dtype_ == l->element_dtype, 204 errors::InvalidArgument("Invalid data types; op elements ", 205 DataTypeString(element_dtype_), 206 " but list elements ", 207 DataTypeString(l->element_dtype))); 208 209 TensorList output; 210 output = *l; 211 output.tensors.push_back(input); 212 Tensor* result; 213 AllocatorAttributes attr; 214 attr.set_on_host(true); 215 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); 216 result->scalar<Variant>()() = std::move(output); 217 } 218 219 private: 220 DataType element_dtype_; 221 }; 222 223 REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU), 224 TensorListPushBack); 225 226 #if GOOGLE_CUDA 227 228 REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU), 229 TensorListPushBack); 230 231 #endif // GOOGLE_CUDA 232 233 class TensorListLength : public OpKernel { 234 public: 235 explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {} 236 ~TensorListLength() override {} 237 238 void Compute(OpKernelContext* c) override { 239 const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); 240 OP_REQUIRES( 241 c, l != nullptr, 242 errors::InvalidArgument( 243 "TensorListLength received a variant which is not a list. Saw: '", 244 c->input(0).scalar<Variant>()().DebugString(), "'")); 245 Tensor* result; 246 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result)); 247 result->scalar<int32>()() = l->tensors.size(); 248 } 249 }; 250 251 REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU), 252 TensorListLength); 253 254 #if GOOGLE_CUDA 255 256 REGISTER_KERNEL_BUILDER( 257 Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"), 258 TensorListLength); 259 260 #endif // GOOGLE_CUDA 261 262 class TensorListElementShape : public OpKernel { 263 public: 264 explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {} 265 266 void Compute(OpKernelContext* c) override { 267 OP_REQUIRES( 268 c, c->input(0).shape().num_elements() == 1, 269 errors::InvalidArgument("List tensors are supposed to be scalars.")); 270 const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); 271 OP_REQUIRES(c, l != nullptr, 272 errors::InvalidArgument( 273 "TensorListElementShape received a variant which is not a " 274 "list. Saw: '", 275 c->input(0).scalar<Variant>()().DebugString(), "'")); 276 Tensor* result; 277 OP_REQUIRES_OK(c, c->allocate_output( 278 0, TensorShape{l->element_shape.dims()}, &result)); 279 for (int i = 0; i < l->element_shape.dims(); ++i) { 280 if (result->dtype() == DT_INT32) { 281 result->flat<int32>()(i) = l->element_shape.dim_size(i); 282 } else { 283 result->flat<int64>()(i) = l->element_shape.dim_size(i); 284 } 285 } 286 } 287 }; 288 289 REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU), 290 TensorListElementShape); 291 292 #if GOOGLE_CUDA 293 294 REGISTER_KERNEL_BUILDER(Name("TensorListElementShape") 295 .Device(DEVICE_GPU) 296 .HostMemory("element_shape"), 297 TensorListElementShape); 298 299 #endif // GOOGLE_CUDA 300 301 class TensorListPopBack : public OpKernel { 302 public: 303 explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) { 304 OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); 305 } 306 307 ~TensorListPopBack() override {} 308 309 void Compute(OpKernelContext* c) override { 310 const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); 311 OP_REQUIRES(c, l != nullptr, 312 errors::InvalidArgument( 313 "Input handle is not a list. Saw: '", 314 c->input(0).scalar<Variant>()().DebugString(), "'")); 315 OP_REQUIRES(c, element_dtype_ == l->element_dtype, 316 errors::InvalidArgument("Invalid data types; op elements ", 317 DataTypeString(element_dtype_), 318 " but list elements ", 319 DataTypeString(l->element_dtype))); 320 321 OP_REQUIRES(c, !l->tensors.empty(), 322 errors::InvalidArgument("Trying to pop from an empty list.")); 323 324 c->set_output(1, l->tensors.back()); 325 TensorList output; 326 output = *l; 327 output.tensors.pop_back(); 328 Tensor* result; 329 AllocatorAttributes attr; 330 attr.set_on_host(true); 331 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); 332 result->scalar<Variant>()() = std::move(output); 333 } 334 335 private: 336 DataType element_dtype_; 337 }; 338 339 REGISTER_KERNEL_BUILDER(Name("TensorListPopBack").Device(DEVICE_CPU), 340 TensorListPopBack); 341 342 #if GOOGLE_CUDA 343 344 REGISTER_KERNEL_BUILDER(Name("TensorListPopBack").Device(DEVICE_GPU), 345 TensorListPopBack); 346 347 #endif // GOOGLE_CUDA 348 349 class TensorListReserve : public OpKernel { 350 public: 351 explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) { 352 OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); 353 } 354 355 void Compute(OpKernelContext* c) override { 356 PartialTensorShape element_shape; 357 OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape)); 358 int32 num_elements = c->input(1).scalar<int32>()(); 359 TensorList output; 360 output.element_shape = element_shape; 361 output.element_dtype = element_dtype_; 362 output.tensors.resize(num_elements, Tensor(DT_INVALID)); 363 Tensor* result; 364 AllocatorAttributes attr; 365 attr.set_on_host(true); 366 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); 367 result->scalar<Variant>()() = std::move(output); 368 } 369 370 private: 371 DataType element_dtype_; 372 }; 373 374 REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU), 375 TensorListReserve); 376 377 #if GOOGLE_CUDA 378 379 REGISTER_KERNEL_BUILDER(Name("TensorListReserve") 380 .Device(DEVICE_GPU) 381 .HostMemory("element_shape") 382 .HostMemory("num_elements"), 383 TensorListReserve); 384 385 #endif // GOOGLE_CUDA 386 387 class TensorListGetItem : public OpKernel { 388 public: 389 explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) { 390 OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); 391 } 392 393 void Compute(OpKernelContext* c) override { 394 OP_REQUIRES( 395 c, c->input(0).shape().num_elements() == 1, 396 errors::InvalidArgument("List tensors are supposed to be scalars.")); 397 const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); 398 OP_REQUIRES(c, l != nullptr, 399 errors::InvalidArgument( 400 "Input handle is not a list. Saw: '", 401 c->input(0).scalar<Variant>()().DebugString(), "'")); 402 OP_REQUIRES(c, element_dtype_ == l->element_dtype, 403 errors::InvalidArgument("Invalid data types; op elements ", 404 DataTypeString(element_dtype_), 405 " but list elements ", 406 DataTypeString(l->element_dtype))); 407 int32 index = c->input(1).scalar<int32>()(); 408 OP_REQUIRES(c, index < l->tensors.size(), 409 errors::InvalidArgument("Trying to access element ", index, 410 " in a list with ", l->tensors.size(), 411 " elements.")); 412 c->set_output(0, l->tensors[index]); 413 } 414 415 private: 416 DataType element_dtype_; 417 }; 418 419 REGISTER_KERNEL_BUILDER(Name("TensorListGetItem").Device(DEVICE_CPU), 420 TensorListGetItem); 421 422 #if GOOGLE_CUDA 423 424 REGISTER_KERNEL_BUILDER( 425 Name("TensorListGetItem").Device(DEVICE_GPU).HostMemory("index"), 426 TensorListGetItem); 427 428 #endif // GOOGLE_CUDA 429 430 class TensorListSetItem : public OpKernel { 431 public: 432 explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) { 433 OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); 434 } 435 436 void Compute(OpKernelContext* c) override { 437 const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); 438 OP_REQUIRES(c, l != nullptr, 439 errors::InvalidArgument( 440 "Input handle is not a list. Saw: '", 441 c->input(0).scalar<Variant>()().DebugString(), "'")); 442 OP_REQUIRES(c, element_dtype_ == l->element_dtype, 443 errors::InvalidArgument("Invalid data types; op elements ", 444 DataTypeString(element_dtype_), 445 " but list elements ", 446 DataTypeString(l->element_dtype))); 447 int32 index = c->input(1).scalar<int32>()(); 448 OP_REQUIRES(c, index < l->tensors.size(), 449 errors::InvalidArgument("Trying to modify element ", index, 450 " in a list with ", l->tensors.size(), 451 " elements.")); 452 TensorList output; 453 output = *l; 454 output.tensors[index] = c->input(2); 455 Tensor* result; 456 AllocatorAttributes attr; 457 attr.set_on_host(true); 458 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); 459 result->scalar<Variant>()() = std::move(output); 460 } 461 462 private: 463 DataType element_dtype_; 464 }; 465 466 REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU), 467 TensorListSetItem); 468 469 #if GOOGLE_CUDA 470 471 REGISTER_KERNEL_BUILDER( 472 Name("TensorListSetItem").Device(DEVICE_GPU).HostMemory("index"), 473 TensorListSetItem); 474 475 #endif // GOOGLE_CUDA 476 477 #define REGISTER_TENSOR_LIST_STACK_CPU(T) \ 478 REGISTER_KERNEL_BUILDER(Name("TensorListStack") \ 479 .TypeConstraint<T>("element_dtype") \ 480 .Device(DEVICE_CPU), \ 481 TensorListStack<CPUDevice, T>) 482 483 TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_STACK_CPU); 484 REGISTER_TENSOR_LIST_STACK_CPU(quint8); 485 REGISTER_TENSOR_LIST_STACK_CPU(qint8); 486 REGISTER_TENSOR_LIST_STACK_CPU(quint16); 487 REGISTER_TENSOR_LIST_STACK_CPU(qint16); 488 REGISTER_TENSOR_LIST_STACK_CPU(qint32); 489 REGISTER_TENSOR_LIST_STACK_CPU(bfloat16); 490 491 #undef REGISTER_TENSOR_LIST_STACK_CPU 492 493 #define REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(T) \ 494 REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \ 495 .TypeConstraint<T>("element_dtype") \ 496 .Device(DEVICE_CPU), \ 497 TensorListFromTensor<CPUDevice, T>) 498 499 TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_CPU); 500 REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint8); 501 REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(qint8); 502 REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint16); 503 REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(qint16); 504 REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(qint32); 505 REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16); 506 507 #undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU 508 509 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, 510 TensorList, TensorList::kTypeName, 511 TensorListBinaryAdd<CPUDevice>); 512 513 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, 514 DEVICE_CPU, TensorList, 515 TensorList::kTypeName, 516 TensorListZerosLike<CPUDevice>); 517 518 } // namespace tensorflow 519