1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/op_def_builder.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 21 namespace tensorflow { 22 23 using shape_inference::DimensionHandle; 24 using shape_inference::InferenceContext; 25 using shape_inference::ShapeHandle; 26 27 namespace { 28 29 Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) { 30 auto* t = c->input_handle_shapes_and_types(0); 31 if (t != nullptr && t->size() == c->num_outputs()) { 32 for (int i = 0; i < c->num_outputs(); ++i) { 33 ShapeHandle combined_shape; 34 TF_RETURN_IF_ERROR( 35 c->Concatenate(n_shape, (*t)[i].shape, &combined_shape)); 36 c->set_output(i, combined_shape); 37 } 38 return Status::OK(); 39 } else { 40 return shape_inference::UnknownShape(c); 41 } 42 } 43 44 } // namespace 45 46 // -------------------------------------------------------------------------- 47 48 REGISTER_OP("DynamicPartition") 49 .Input("data: T") 50 .Input("partitions: int32") 51 .Output("outputs: num_partitions * T") 52 .Attr("num_partitions: int") 53 .Attr("T: type") 54 .SetShapeFn([](InferenceContext* c) { 55 int64 num_partitions; 56 TF_RETURN_IF_ERROR(c->GetAttr("num_partitions", &num_partitions)); 57 58 ShapeHandle data_shape = c->input(0); 59 ShapeHandle partitions_shape = c->input(1); 60 61 if (!c->RankKnown(partitions_shape)) { 62 return shape_inference::UnknownShape(c); 63 } 64 65 const int64 rank = c->Rank(partitions_shape); 66 67 // data shape must start with partitions_shape 68 ShapeHandle unused; 69 TF_RETURN_IF_ERROR( 70 c->MergePrefix(data_shape, partitions_shape, &unused, &unused)); 71 72 // The partition shape is dynamic in the 0th dimension, and matches 73 // data_shape in the remaining dimensions. 74 ShapeHandle unknown_dim0 = c->MakeShape({c->UnknownDim()}); 75 76 ShapeHandle data_suffix_shape; 77 TF_RETURN_IF_ERROR(c->Subshape(data_shape, rank, &data_suffix_shape)); 78 ShapeHandle result_shape; 79 TF_RETURN_IF_ERROR( 80 c->Concatenate(unknown_dim0, data_suffix_shape, &result_shape)); 81 82 for (int i = 0; i < c->num_outputs(); ++i) { 83 c->set_output(i, result_shape); 84 } 85 86 return Status::OK(); 87 }); 88 89 namespace { 90 91 Status DynamicStitchShapeFunction(InferenceContext* c) { 92 int32 num_partitions; 93 TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions)); 94 95 bool all_indices_constant = true; 96 int32 max_index = 0; 97 ShapeHandle extra_shape = c->UnknownShape(); 98 for (int i = 0; i < num_partitions; ++i) { 99 const Tensor* indices_t = c->input_tensor(i); 100 if (indices_t == nullptr) { 101 all_indices_constant = false; 102 } 103 104 ShapeHandle indices_shape = c->input(i); 105 ShapeHandle data_shape = c->input(i + num_partitions); 106 if (!c->RankKnown(indices_shape)) { 107 continue; 108 } 109 const int64 indices_rank = c->Rank(indices_shape); 110 111 // Assert that data_shape starts with indices_shape. 112 ShapeHandle unused; 113 TF_RETURN_IF_ERROR( 114 c->MergePrefix(data_shape, indices_shape, &unused, &unused)); 115 116 // The rest belongs to output. 117 ShapeHandle rest; 118 TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest)); 119 TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape)); 120 121 if (indices_t != nullptr) { 122 // The length is based on the highest index from flattened indices. 123 const int32* indices = indices_t->flat<int32>().data(); 124 int64 count = indices_t->NumElements(); 125 for (int64 i = 0; i < count; ++i) { 126 if (indices[i] > max_index) { 127 max_index = indices[i]; 128 } 129 } 130 } 131 } 132 133 ShapeHandle output_shape = c->Vector( 134 all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim()); 135 TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape)); 136 c->set_output(0, output_shape); 137 return Status::OK(); 138 } 139 140 } // namespace 141 142 REGISTER_OP("DynamicStitch") 143 .Input("indices: N * int32") 144 .Input("data: N * T") 145 .Output("merged: T") 146 .Attr("N : int >= 1") 147 .Attr("T : type") 148 .SetShapeFn(DynamicStitchShapeFunction); 149 150 REGISTER_OP("ParallelDynamicStitch") 151 .Input("indices: N * int32") 152 .Input("data: N * T") 153 .Output("merged: T") 154 .Attr("N : int >= 1") 155 .Attr("T : type") 156 .SetShapeFn(DynamicStitchShapeFunction); 157 158 // -------------------------------------------------------------------------- 159 160 namespace { 161 Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { 162 ShapeHandle handle; 163 DimensionHandle unused_handle; 164 for (int i = 0; i < c->num_inputs(); ++i) { 165 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); 166 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); 167 } 168 for (int i = 0; i < c->num_outputs(); ++i) { 169 c->set_output(i, c->Scalar()); 170 } 171 return Status::OK(); 172 } 173 174 Status TwoElementOutput(InferenceContext* c) { 175 c->set_output(0, c->Vector(2)); 176 return Status::OK(); 177 } 178 } // namespace 179 180 REGISTER_OP("RandomShuffleQueue") 181 .Output("handle: Ref(string)") 182 .Attr("component_types: list(type) >= 1") 183 .Attr("shapes: list(shape) >= 0 = []") 184 .Attr("capacity: int = -1") 185 .Attr("min_after_dequeue: int = 0") 186 .Attr("seed: int = 0") 187 .Attr("seed2: int = 0") 188 .Attr("container: string = ''") 189 .Attr("shared_name: string = ''") 190 .SetIsStateful() 191 .SetShapeFn(TwoElementOutput); 192 193 REGISTER_OP("RandomShuffleQueueV2") 194 .Output("handle: resource") 195 .Attr("component_types: list(type) >= 1") 196 .Attr("shapes: list(shape) >= 0 = []") 197 .Attr("capacity: int = -1") 198 .Attr("min_after_dequeue: int = 0") 199 .Attr("seed: int = 0") 200 .Attr("seed2: int = 0") 201 .Attr("container: string = ''") 202 .Attr("shared_name: string = ''") 203 .SetIsStateful() 204 .SetShapeFn(shape_inference::ScalarShape); 205 206 REGISTER_OP("FIFOQueue") 207 .Output("handle: Ref(string)") 208 .Attr("component_types: list(type) >= 1") 209 .Attr("shapes: list(shape) >= 0 = []") 210 .Attr("capacity: int = -1") 211 .Attr("container: string = ''") 212 .Attr("shared_name: string = ''") 213 .SetIsStateful() 214 .SetShapeFn(TwoElementOutput); 215 216 REGISTER_OP("FIFOQueueV2") 217 .Output("handle: resource") 218 .Attr("component_types: list(type) >= 1") 219 .Attr("shapes: list(shape) >= 0 = []") 220 .Attr("capacity: int = -1") 221 .Attr("container: string = ''") 222 .Attr("shared_name: string = ''") 223 .SetIsStateful() 224 .SetShapeFn(shape_inference::ScalarShape); 225 226 REGISTER_OP("PaddingFIFOQueue") 227 .Output("handle: Ref(string)") 228 .Attr("component_types: list(type) >= 1") 229 .Attr("shapes: list(shape) >= 0 = []") 230 .Attr("capacity: int = -1") 231 .Attr("container: string = ''") 232 .Attr("shared_name: string = ''") 233 .SetIsStateful() 234 .SetShapeFn(TwoElementOutput); 235 236 REGISTER_OP("PaddingFIFOQueueV2") 237 .Output("handle: resource") 238 .Attr("component_types: list(type) >= 1") 239 .Attr("shapes: list(shape) >= 0 = []") 240 .Attr("capacity: int = -1") 241 .Attr("container: string = ''") 242 .Attr("shared_name: string = ''") 243 .SetIsStateful() 244 .SetShapeFn(shape_inference::ScalarShape); 245 246 REGISTER_OP("PriorityQueue") 247 .Output("handle: Ref(string)") 248 .Attr("component_types: list(type) >= 0 = []") 249 .Attr("shapes: list(shape) >= 0") 250 .Attr("capacity: int = -1") 251 .Attr("container: string = ''") 252 .Attr("shared_name: string = ''") 253 .SetIsStateful() 254 .SetShapeFn(TwoElementOutput); 255 256 REGISTER_OP("PriorityQueueV2") 257 .Output("handle: resource") 258 .Attr("component_types: list(type) >= 0 = []") 259 .Attr("shapes: list(shape) >= 0") 260 .Attr("capacity: int = -1") 261 .Attr("container: string = ''") 262 .Attr("shared_name: string = ''") 263 .SetIsStateful() 264 .SetShapeFn(shape_inference::ScalarShape); 265 266 REGISTER_OP("FakeQueue") 267 .Input("resource: resource") 268 .Output("handle: Ref(string)") 269 .SetIsStateful() 270 .SetShapeFn(TwoElementOutput); 271 272 REGISTER_OP("QueueEnqueue") 273 .Input("handle: Ref(string)") 274 .Input("components: Tcomponents") 275 .Attr("Tcomponents: list(type) >= 1") 276 .Attr("timeout_ms: int = -1") 277 .SetShapeFn(shape_inference::UnknownShape); 278 279 REGISTER_OP("QueueEnqueueV2") 280 .Input("handle: resource") 281 .Input("components: Tcomponents") 282 .Attr("Tcomponents: list(type) >= 1") 283 .Attr("timeout_ms: int = -1") 284 .SetShapeFn(shape_inference::UnknownShape); 285 286 REGISTER_OP("QueueEnqueueMany") 287 .Input("handle: Ref(string)") 288 .Input("components: Tcomponents") 289 .Attr("Tcomponents: list(type) >= 1") 290 .Attr("timeout_ms: int = -1") 291 .SetShapeFn(shape_inference::UnknownShape); 292 293 REGISTER_OP("QueueEnqueueManyV2") 294 .Input("handle: resource") 295 .Input("components: Tcomponents") 296 .Attr("Tcomponents: list(type) >= 1") 297 .Attr("timeout_ms: int = -1") 298 .SetShapeFn(shape_inference::UnknownShape); 299 300 REGISTER_OP("QueueDequeue") 301 .Input("handle: Ref(string)") 302 .Output("components: component_types") 303 .Attr("component_types: list(type) >= 1") 304 .Attr("timeout_ms: int = -1") 305 .SetShapeFn(shape_inference::UnknownShape); 306 307 REGISTER_OP("QueueDequeueV2") 308 .Input("handle: resource") 309 .Output("components: component_types") 310 .Attr("component_types: list(type) >= 1") 311 .Attr("timeout_ms: int = -1") 312 .SetShapeFn([](InferenceContext* c) { 313 auto* t = c->input_handle_shapes_and_types(0); 314 if (t != nullptr && t->size() == c->num_outputs()) { 315 for (int i = 0; i < c->num_outputs(); ++i) { 316 c->set_output(i, (*t)[i].shape); 317 } 318 return Status::OK(); 319 } else { 320 return shape_inference::UnknownShape(c); 321 } 322 }); 323 324 REGISTER_OP("QueueDequeueMany") 325 .Input("handle: Ref(string)") 326 .Input("n: int32") 327 .Output("components: component_types") 328 .Attr("component_types: list(type) >= 1") 329 .Attr("timeout_ms: int = -1") 330 .SetShapeFn(shape_inference::UnknownShape); 331 332 REGISTER_OP("QueueDequeueManyV2") 333 .Input("handle: resource") 334 .Input("n: int32") 335 .Output("components: component_types") 336 .Attr("component_types: list(type) >= 1") 337 .Attr("timeout_ms: int = -1") 338 .SetShapeFn([](InferenceContext* c) { 339 ShapeHandle n_shape; 340 if (c->input_tensor(1) == nullptr) { 341 n_shape = c->Vector(InferenceContext::kUnknownDim); 342 } else { 343 const int32 n = c->input_tensor(1)->scalar<int32>()(); 344 if (n < 0) { 345 return errors::InvalidArgument("Input 'n' must be >= 0, but is ", n); 346 } 347 n_shape = c->Vector(n); 348 } 349 return DequeueManyV2Shape(c, n_shape); 350 }); 351 352 REGISTER_OP("QueueDequeueUpTo") 353 .Input("handle: Ref(string)") 354 .Input("n: int32") 355 .Output("components: component_types") 356 .Attr("component_types: list(type) >= 1") 357 .Attr("timeout_ms: int = -1") 358 .SetShapeFn(shape_inference::UnknownShape); 359 360 REGISTER_OP("QueueDequeueUpToV2") 361 .Input("handle: resource") 362 .Input("n: int32") 363 .Output("components: component_types") 364 .Attr("component_types: list(type) >= 1") 365 .Attr("timeout_ms: int = -1") 366 .SetShapeFn([](InferenceContext* c) { 367 return DequeueManyV2Shape(c, c->Vector(InferenceContext::kUnknownDim)); 368 }); 369 370 REGISTER_OP("QueueClose") 371 .Input("handle: Ref(string)") 372 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) 373 .Attr("cancel_pending_enqueues: bool = false"); 374 375 REGISTER_OP("QueueCloseV2") 376 .Input("handle: resource") 377 .SetShapeFn(shape_inference::NoOutputs) 378 .Attr("cancel_pending_enqueues: bool = false"); 379 380 REGISTER_OP("QueueIsClosed") 381 .Input("handle: Ref(string)") 382 .Output("is_closed: bool") 383 .SetShapeFn(shape_inference::ScalarShape); 384 385 REGISTER_OP("QueueIsClosedV2") 386 .Input("handle: resource") 387 .Output("is_closed: bool") 388 .SetShapeFn(shape_inference::ScalarShape); 389 390 REGISTER_OP("QueueSize") 391 .Input("handle: Ref(string)") 392 .Output("size: int32") 393 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs); 394 395 REGISTER_OP("QueueSizeV2") 396 .Input("handle: resource") 397 .Output("size: int32") 398 .SetShapeFn(shape_inference::UnchangedShape); 399 400 // -------------------------------------------------------------------------- 401 402 REGISTER_OP("AccumulatorNumAccumulated") 403 .Input("handle: Ref(string)") 404 .Output("num_accumulated: int32") 405 .SetShapeFn(shape_inference::ScalarShape); 406 407 REGISTER_OP("AccumulatorSetGlobalStep") 408 .Input("handle: Ref(string)") 409 .Input("new_global_step: int64") 410 .SetShapeFn([](InferenceContext* c) { 411 ShapeHandle unused; 412 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 413 return Status::OK(); 414 }); 415 416 REGISTER_OP("ConditionalAccumulator") 417 .Output("handle: Ref(string)") 418 .Attr("dtype: numbertype") 419 .Attr("shape: shape") 420 .Attr("container: string = ''") 421 .Attr("shared_name: string = ''") 422 .SetIsStateful() 423 .SetShapeFn([](InferenceContext* c) { 424 c->set_output(0, c->Vector(2)); 425 return Status::OK(); 426 }); 427 428 REGISTER_OP("AccumulatorApplyGradient") 429 .Input("handle: Ref(string)") 430 .Input("local_step: int64") 431 .Input("gradient: dtype") 432 .Attr("dtype: numbertype") 433 .SetShapeFn([](InferenceContext* c) { 434 ShapeHandle unused; 435 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 436 return Status::OK(); 437 }); 438 439 REGISTER_OP("AccumulatorTakeGradient") 440 .Input("handle: Ref(string)") 441 .Input("num_required: int32") 442 .Output("average: dtype") 443 .SetShapeFn([](InferenceContext* c) { 444 ShapeHandle unused; 445 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 446 // Shape of output is the shape of the accumulator referenced 447 // by 'handle', but which is not available here, so we lose 448 // shape information. 449 return shape_inference::UnknownShape(c); 450 }) 451 .Attr("dtype: numbertype"); 452 453 REGISTER_OP("SparseConditionalAccumulator") 454 .Output("handle: Ref(string)") 455 .Attr("dtype: numbertype") 456 .Attr("shape: shape") 457 .Attr("container: string = ''") 458 .Attr("shared_name: string = ''") 459 .SetIsStateful() 460 .SetShapeFn([](InferenceContext* c) { 461 c->set_output(0, c->Vector(2)); 462 return Status::OK(); 463 }); 464 465 REGISTER_OP("SparseAccumulatorApplyGradient") 466 .Input("handle: Ref(string)") 467 .Input("local_step: int64") 468 .Input("gradient_indices: int64") 469 .Input("gradient_values: dtype") 470 .Input("gradient_shape: int64") 471 .Attr("dtype: numbertype") 472 .Attr("has_known_shape: bool") 473 .SetShapeFn([](InferenceContext* c) { 474 ShapeHandle unused; 475 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 476 return Status::OK(); 477 }); 478 479 REGISTER_OP("SparseAccumulatorTakeGradient") 480 .Input("handle: Ref(string)") 481 .Input("num_required: int32") 482 .Output("indices: int64") 483 .Output("values: dtype") 484 .Output("shape: int64") 485 .Attr("dtype: numbertype") 486 .SetShapeFn([](InferenceContext* c) { 487 ShapeHandle unused; 488 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 489 // Shape of output is the shape of the accumulator referenced 490 // by 'handle', but which is not available here, so we lose 491 // shape information. 492 return shape_inference::UnknownShape(c); 493 }); 494 495 // -------------------------------------------------------------------------- 496 497 REGISTER_OP("StackV2") 498 .Input("max_size: int32") 499 .Output("handle: resource") 500 .Attr("elem_type: type") 501 .Attr("stack_name: string = ''") 502 .SetIsStateful() 503 .SetShapeFn(TwoElementOutput); 504 505 REGISTER_OP("StackPushV2") 506 .Input("handle: resource") 507 .Input("elem: T") 508 .Output("output: T") 509 .Attr("T: type") 510 .Attr("swap_memory: bool = false") 511 .SetShapeFn([](shape_inference::InferenceContext* c) { 512 c->set_output(0, c->input(1)); 513 return Status::OK(); 514 }); 515 516 REGISTER_OP("StackPopV2") 517 .Input("handle: resource") 518 .Output("elem: elem_type") 519 .Attr("elem_type: type") 520 .SetShapeFn(shape_inference::UnknownShape); 521 522 REGISTER_OP("StackCloseV2") 523 .Input("handle: resource") 524 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs); 525 526 // Deprecated ref-typed variants of stack. 527 528 REGISTER_OP("Stack") 529 .Output("handle: Ref(string)") 530 .Attr("elem_type: type") 531 .Attr("stack_name: string = ''") 532 .SetIsStateful() 533 .SetShapeFn(TwoElementOutput); 534 535 REGISTER_OP("StackPush") 536 .Input("handle: Ref(string)") 537 .Input("elem: T") 538 .Output("output: T") 539 .Attr("T: type") 540 .Attr("swap_memory: bool = false") 541 .SetShapeFn([](shape_inference::InferenceContext* c) { 542 c->set_output(0, c->input(1)); 543 return Status::OK(); 544 }); 545 546 REGISTER_OP("StackPop") 547 .Input("handle: Ref(string)") 548 .Output("elem: elem_type") 549 .Attr("elem_type: type") 550 .SetShapeFn(shape_inference::UnknownShape); 551 552 REGISTER_OP("StackClose") 553 .Input("handle: Ref(string)") 554 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs); 555 556 // -------------------------------------------------------------------------- 557 558 REGISTER_OP("TensorArrayV3") 559 .Input("size: int32") 560 .Attr("dtype: type") 561 .Attr("element_shape: shape = { unknown_rank: true }") 562 .Attr("dynamic_size: bool = false") 563 .Attr("clear_after_read: bool = true") 564 .Attr("identical_element_shapes: bool = false") 565 .Attr("tensor_array_name: string = ''") 566 .Output("handle: resource") 567 .Output("flow: float") 568 .SetIsStateful() 569 .SetShapeFn([](InferenceContext* c) { 570 ShapeHandle unused; 571 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 572 c->set_output(0, c->Vector(2)); 573 c->set_output(1, c->Scalar()); 574 bool identical_shapes; 575 TF_RETURN_IF_ERROR( 576 c->GetAttr("identical_element_shapes", &identical_shapes)); 577 DataType t; 578 TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t)); 579 PartialTensorShape p; 580 TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p)); 581 ShapeHandle s; 582 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); 583 if (c->FullyDefined(s) || identical_shapes) { 584 c->set_output_handle_shapes_and_types( 585 0, std::vector<shape_inference::ShapeAndType>{{s, t}}); 586 } 587 return Status::OK(); 588 }); 589 590 REGISTER_OP("TensorArrayGradV3") 591 .Input("handle: resource") 592 .Input("flow_in: float") 593 .Output("grad_handle: resource") 594 .Output("flow_out: float") 595 .Attr("source: string") 596 .SetIsStateful() 597 .SetShapeFn([](InferenceContext* c) { 598 ShapeHandle handle; 599 DimensionHandle unused_dim; 600 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 601 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 602 c->set_output(0, c->Vector(2)); 603 c->set_output(1, c->Scalar()); 604 if (c->input_handle_shapes_and_types(0)) { 605 c->set_output_handle_shapes_and_types( 606 0, *c->input_handle_shapes_and_types(0)); 607 } 608 return Status::OK(); 609 }); 610 611 REGISTER_OP("TensorArrayWriteV3") 612 .Input("handle: resource") 613 .Input("index: int32") 614 .Input("value: T") 615 .Input("flow_in: float") 616 .Output("flow_out: float") 617 .Attr("T: type") 618 .SetShapeFn([](InferenceContext* c) { 619 ShapeHandle handle; 620 DimensionHandle unused_dim; 621 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 622 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 623 624 ShapeHandle unused; 625 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 626 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 627 628 auto* handle_data = c->input_handle_shapes_and_types(0); 629 if (handle_data != nullptr && !handle_data->empty()) { 630 shape_inference::ShapeAndType shape_and_type = (*handle_data)[0]; 631 ShapeHandle value_shape = c->input(2); 632 TF_RETURN_IF_ERROR( 633 c->Merge(shape_and_type.shape, value_shape, &unused)); 634 } 635 636 return shape_inference::ScalarShape(c); 637 }); 638 639 REGISTER_OP("TensorArrayReadV3") 640 .Input("handle: resource") 641 .Input("index: int32") 642 .Input("flow_in: float") 643 .Output("value: dtype") 644 .Attr("dtype: type") 645 .SetShapeFn([](InferenceContext* c) { 646 ShapeHandle handle; 647 DimensionHandle unused_dim; 648 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 649 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 650 ShapeHandle unused; 651 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 652 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 653 auto shapes = c->input_handle_shapes_and_types(0); 654 if (shapes != nullptr && !shapes->empty()) { 655 ShapeHandle tensor_shape = shapes->at(0).shape; 656 c->set_output(0, tensor_shape); 657 return Status::OK(); 658 } else { 659 return shape_inference::UnknownShape(c); 660 } 661 }); 662 663 REGISTER_OP("TensorArrayGatherV3") 664 .Input("handle: resource") 665 .Input("indices: int32") 666 .Input("flow_in: float") 667 .Output("value: dtype") 668 .Attr("dtype: type") 669 .Attr("element_shape: shape = { unknown_rank: true }") 670 .SetShapeFn([](InferenceContext* c) { 671 ShapeHandle unused; 672 DimensionHandle unused_dim; 673 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 674 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 675 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); 676 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 677 return shape_inference::UnknownShape(c); 678 }); 679 680 REGISTER_OP("TensorArrayScatterV3") 681 .Input("handle: resource") 682 .Input("indices: int32") 683 .Input("value: T") 684 .Input("flow_in: float") 685 .Output("flow_out: float") 686 .Attr("T: type") 687 .SetShapeFn([](InferenceContext* c) { 688 ShapeHandle unused; 689 DimensionHandle unused_dim; 690 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 691 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 692 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); 693 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 694 return shape_inference::ScalarShape(c); 695 }); 696 697 REGISTER_OP("TensorArrayConcatV3") 698 .Input("handle: resource") 699 .Input("flow_in: float") 700 .Output("value: dtype") 701 .Output("lengths: int64") 702 .Attr("dtype: type") 703 .Attr("element_shape_except0: shape = { unknown_rank: true }") 704 .SetShapeFn([](InferenceContext* c) { 705 ShapeHandle handle; 706 DimensionHandle unused_dim; 707 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 708 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 709 ShapeHandle unused; 710 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 711 c->set_output(0, c->UnknownShape()); 712 c->set_output(1, c->Vector(c->UnknownDim())); 713 return Status::OK(); 714 }); 715 716 REGISTER_OP("TensorArraySplitV3") 717 .Input("handle: resource") 718 .Input("value: T") 719 .Input("lengths: int64") 720 .Input("flow_in: float") 721 .Output("flow_out: float") 722 .Attr("T: type") 723 .SetShapeFn([](InferenceContext* c) { 724 ShapeHandle handle; 725 DimensionHandle unused_dim; 726 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 727 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 728 ShapeHandle unused; 729 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); 730 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 731 return shape_inference::ScalarShape(c); 732 }); 733 734 REGISTER_OP("TensorArraySizeV3") 735 .Input("handle: resource") 736 .Input("flow_in: float") 737 .Output("size: int32") 738 .SetShapeFn([](InferenceContext* c) { 739 ShapeHandle handle; 740 DimensionHandle unused_dim; 741 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 742 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 743 return shape_inference::ScalarShape(c); 744 }); 745 746 REGISTER_OP("TensorArrayCloseV3") 747 .Input("handle: resource") 748 .SetShapeFn([](InferenceContext* c) { 749 ShapeHandle handle; 750 DimensionHandle unused_dim; 751 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 752 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 753 return Status::OK(); 754 }); 755 756 // -------------------------------------------------------------------------- 757 758 // Deprecated TensorArray methods 759 760 REGISTER_OP("TensorArray") 761 .Input("size: int32") 762 .Attr("dtype: type") 763 .Attr("dynamic_size: bool = false") 764 .Attr("clear_after_read: bool = true") 765 .Attr("tensor_array_name: string = ''") 766 .Attr("element_shape: shape = { unknown_rank: true }") 767 .Output("handle: Ref(string)") 768 .SetIsStateful() 769 .SetShapeFn(shape_inference::UnknownShape) 770 .Deprecated(16, "Use TensorArrayV3"); 771 REGISTER_OP("TensorArrayV2") 772 .Input("size: int32") 773 .Attr("dtype: type") 774 .Attr("element_shape: shape = { unknown_rank: true }") 775 .Attr("dynamic_size: bool = false") 776 .Attr("clear_after_read: bool = true") 777 .Attr("tensor_array_name: string = ''") 778 .Output("handle: string") 779 .SetIsStateful() 780 .SetShapeFn([](InferenceContext* c) { 781 ShapeHandle unused; 782 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 783 c->set_output(0, c->Vector(2)); 784 return Status::OK(); 785 }) 786 .Deprecated(26, "Use TensorArrayV3"); 787 REGISTER_OP("TensorArrayGrad") 788 .Input("handle: string") 789 .Input("flow_in: float") 790 .Output("grad_handle: Ref(string)") 791 .Attr("source: string") 792 .SetIsStateful() 793 .SetShapeFn(shape_inference::UnknownShape) 794 .Deprecated(16, "Use TensorArrayGradV3"); 795 REGISTER_OP("TensorArrayGradV2") 796 .Input("handle: string") 797 .Input("flow_in: float") 798 .Output("grad_handle: string") 799 .Attr("source: string") 800 .SetIsStateful() 801 .SetShapeFn([](InferenceContext* c) { 802 ShapeHandle handle; 803 DimensionHandle unused_dim; 804 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 805 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 806 c->set_output(0, c->Vector(2)); 807 return Status::OK(); 808 }) 809 .Deprecated(26, "Use TensorArrayGradV3"); 810 REGISTER_OP("TensorArrayWrite") 811 .Input("handle: Ref(string)") 812 .Input("index: int32") 813 .Input("value: T") 814 .Input("flow_in: float") 815 .Output("flow_out: float") 816 .Attr("T: type") 817 .SetShapeFn(shape_inference::UnknownShape) 818 .Deprecated(16, "Use TensorArrayWriteV3"); 819 REGISTER_OP("TensorArrayWriteV2") 820 .Input("handle: string") 821 .Input("index: int32") 822 .Input("value: T") 823 .Input("flow_in: float") 824 .Output("flow_out: float") 825 .Attr("T: type") 826 .SetShapeFn([](InferenceContext* c) { 827 ShapeHandle handle; 828 DimensionHandle unused_dim; 829 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 830 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 831 832 ShapeHandle unused; 833 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 834 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 835 return shape_inference::ScalarShape(c); 836 }) 837 .Deprecated(26, "Use TensorArrayWriteV3"); 838 REGISTER_OP("TensorArrayRead") 839 .Input("handle: Ref(string)") 840 .Input("index: int32") 841 .Input("flow_in: float") 842 .Output("value: dtype") 843 .Attr("dtype: type") 844 .SetShapeFn(shape_inference::UnknownShape) 845 .Deprecated(16, "Use TensorArrayReadV3"); 846 REGISTER_OP("TensorArrayReadV2") 847 .Input("handle: string") 848 .Input("index: int32") 849 .Input("flow_in: float") 850 .Output("value: dtype") 851 .Attr("dtype: type") 852 .SetShapeFn([](InferenceContext* c) { 853 ShapeHandle handle; 854 DimensionHandle unused_dim; 855 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 856 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 857 ShapeHandle unused; 858 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 859 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 860 return shape_inference::UnknownShape(c); 861 }) 862 .Deprecated(26, "Use TensorArrayReadV3"); 863 REGISTER_OP("TensorArrayPack") 864 .Input("handle: Ref(string)") 865 .Input("flow_in: float") 866 .Output("value: dtype") 867 .Attr("dtype: type") 868 .Attr("element_shape: shape = { unknown_rank: true }") 869 .SetShapeFn(shape_inference::UnknownShape) 870 .Deprecated(16, "Use TensorArrayGatherV3 with RangeOp"); 871 REGISTER_OP("TensorArrayUnpack") 872 .Input("handle: Ref(string)") 873 .Input("value: T") 874 .Input("flow_in: float") 875 .Output("flow_out: float") 876 .Attr("T: type") 877 .SetShapeFn(shape_inference::UnknownShape) 878 .Deprecated(20, "Use TensorArrayScatterV3 with RangeOp"); 879 REGISTER_OP("TensorArrayGather") 880 .Input("handle: Ref(string)") 881 .Input("indices: int32") 882 .Input("flow_in: float") 883 .Output("value: dtype") 884 .Attr("dtype: type") 885 .Attr("element_shape: shape = { unknown_rank: true }") 886 .SetShapeFn(shape_inference::UnknownShape) 887 .Deprecated(16, "Use TensorArrayGatherV3"); 888 REGISTER_OP("TensorArrayGatherV2") 889 .Input("handle: string") 890 .Input("indices: int32") 891 .Input("flow_in: float") 892 .Output("value: dtype") 893 .Attr("dtype: type") 894 .Attr("element_shape: shape = { unknown_rank: true }") 895 .SetShapeFn([](InferenceContext* c) { 896 ShapeHandle unused; 897 DimensionHandle unused_dim; 898 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 899 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 900 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); 901 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 902 return shape_inference::UnknownShape(c); 903 }) 904 .Deprecated(26, "Use TensorArrayGatherV3"); 905 REGISTER_OP("TensorArrayScatter") 906 .Input("handle: Ref(string)") 907 .Input("indices: int32") 908 .Input("value: T") 909 .Input("flow_in: float") 910 .Output("flow_out: float") 911 .Attr("T: type") 912 .SetShapeFn(shape_inference::UnknownShape) 913 .Deprecated(19, "Use TensorArrayGradV3"); 914 REGISTER_OP("TensorArrayScatterV2") 915 .Input("handle: string") 916 .Input("indices: int32") 917 .Input("value: T") 918 .Input("flow_in: float") 919 .Output("flow_out: float") 920 .Attr("T: type") 921 .SetShapeFn([](InferenceContext* c) { 922 ShapeHandle unused; 923 DimensionHandle unused_dim; 924 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 925 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); 926 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); 927 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 928 return shape_inference::ScalarShape(c); 929 }) 930 .Deprecated(26, "Use TensorArrayScatterV3"); 931 REGISTER_OP("TensorArrayConcat") 932 .Input("handle: Ref(string)") 933 .Input("flow_in: float") 934 .Output("value: dtype") 935 .Output("lengths: int64") 936 .Attr("dtype: type") 937 .Attr("element_shape_except0: shape = { unknown_rank: true }") 938 .SetShapeFn(shape_inference::UnknownShape) 939 .Deprecated(16, "Use TensorArrayGradV3"); 940 REGISTER_OP("TensorArrayConcatV2") 941 .Input("handle: string") 942 .Input("flow_in: float") 943 .Output("value: dtype") 944 .Output("lengths: int64") 945 .Attr("dtype: type") 946 .Attr("element_shape_except0: shape = { unknown_rank: true }") 947 .SetShapeFn([](InferenceContext* c) { 948 ShapeHandle handle; 949 DimensionHandle unused_dim; 950 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 951 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 952 ShapeHandle unused; 953 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 954 c->set_output(0, c->UnknownShape()); 955 c->set_output(1, c->Vector(c->UnknownDim())); 956 return Status::OK(); 957 }); 958 REGISTER_OP("TensorArraySplit") 959 .Input("handle: Ref(string)") 960 .Input("value: T") 961 .Input("lengths: int64") 962 .Input("flow_in: float") 963 .Output("flow_out: float") 964 .Attr("T: type") 965 .SetShapeFn(shape_inference::UnknownShape) 966 .Deprecated(16, "Use TensorArraySplitV3"); 967 REGISTER_OP("TensorArraySplitV2") 968 .Input("handle: string") 969 .Input("value: T") 970 .Input("lengths: int64") 971 .Input("flow_in: float") 972 .Output("flow_out: float") 973 .Attr("T: type") 974 .SetShapeFn([](InferenceContext* c) { 975 ShapeHandle handle; 976 DimensionHandle unused_dim; 977 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 978 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 979 ShapeHandle unused; 980 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); 981 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 982 return shape_inference::ScalarShape(c); 983 }) 984 .Deprecated(26, "Use TensorArraySplitV3"); 985 REGISTER_OP("TensorArraySize") 986 .Input("handle: Ref(string)") 987 .Input("flow_in: float") 988 .Output("size: int32") 989 .SetShapeFn(shape_inference::UnknownShape) 990 .Deprecated(16, "Use TensorArraySizeV3"); 991 REGISTER_OP("TensorArraySizeV2") 992 .Input("handle: string") 993 .Input("flow_in: float") 994 .Output("size: int32") 995 .SetShapeFn([](InferenceContext* c) { 996 ShapeHandle handle; 997 DimensionHandle unused_dim; 998 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 999 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 1000 return shape_inference::ScalarShape(c); 1001 }) 1002 .Deprecated(26, "Use TensorArraySizeV3"); 1003 REGISTER_OP("TensorArrayClose") 1004 .Input("handle: Ref(string)") 1005 .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) 1006 .Deprecated(16, "Use TensorArrayCloseV3"); 1007 REGISTER_OP("TensorArrayCloseV2") 1008 .Input("handle: string") 1009 .SetShapeFn([](InferenceContext* c) { 1010 ShapeHandle handle; 1011 DimensionHandle unused_dim; 1012 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 1013 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 1014 return Status::OK(); 1015 }) 1016 .Deprecated(26, "Use TensorArrayCloseV3"); 1017 1018 // -------------------------------------------------------------------------- 1019 1020 REGISTER_OP("Barrier") 1021 .SetIsStateful() 1022 .Output("handle: Ref(string)") 1023 .Attr("component_types: list(type) >= 1") 1024 .Attr("shapes: list(shape) >= 0 = []") 1025 .Attr("capacity: int = -1") 1026 .Attr("container: string = ''") 1027 .Attr("shared_name: string = ''") 1028 .SetShapeFn(TwoElementOutput); 1029 1030 REGISTER_OP("BarrierInsertMany") 1031 .Input("handle: Ref(string)") 1032 .Input("keys: string") 1033 .Input("values: T") 1034 .Attr("T: type") 1035 .Attr("component_index: int") 1036 .SetShapeFn([](InferenceContext* c) { 1037 ShapeHandle keys = c->input(1); 1038 ShapeHandle values = c->input(2); 1039 ShapeHandle handle; 1040 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); 1041 DimensionHandle unused_dim; 1042 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); 1043 TF_RETURN_IF_ERROR(c->WithRank(keys, 1, &keys)); 1044 TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); 1045 TF_RETURN_IF_ERROR(c->Merge(keys, c->Vector(c->Dim(values, 0)), &handle)); 1046 return Status::OK(); 1047 }); 1048 1049 REGISTER_OP("BarrierTakeMany") 1050 .Input("handle: Ref(string)") 1051 .Input("num_elements: int32") 1052 .Output("indices: int64") 1053 .Output("keys: string") 1054 .Output("values: component_types") 1055 .Attr("component_types: list(type) >= 1") 1056 .Attr("allow_small_batch: bool = false") 1057 .Attr("wait_for_incomplete: bool = false") 1058 .Attr("timeout_ms: int = -1") 1059 .SetShapeFn(shape_inference::UnknownShape); 1060 1061 REGISTER_OP("BarrierClose") 1062 .Input("handle: Ref(string)") 1063 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) 1064 .Attr("cancel_pending_enqueues: bool = false"); 1065 1066 REGISTER_OP("BarrierReadySize") 1067 .Input("handle: Ref(string)") 1068 .Output("size: int32") 1069 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs); 1070 1071 REGISTER_OP("BarrierIncompleteSize") 1072 .Input("handle: Ref(string)") 1073 .Output("size: int32") 1074 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs); 1075 1076 // -------------------------------------------------------------------------- 1077 1078 REGISTER_OP("GetSessionHandle") 1079 .Input("value: T") 1080 .Output("handle: string") 1081 .Attr("T: type") 1082 .SetIsStateful() 1083 .SetShapeFn(shape_inference::ScalarShape); 1084 1085 REGISTER_OP("GetSessionHandleV2") 1086 .Input("value: T") 1087 .Output("handle: resource") 1088 .Attr("T: type") 1089 .SetIsStateful() 1090 .SetShapeFn(shape_inference::ScalarShape); 1091 1092 REGISTER_OP("GetSessionTensor") 1093 .Input("handle: string") 1094 .Output("value: dtype") 1095 .Attr("dtype: type") 1096 .SetIsStateful() 1097 .SetShapeFn([](InferenceContext* c) { 1098 ShapeHandle unused; 1099 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 1100 return shape_inference::UnknownShape(c); 1101 }); 1102 1103 REGISTER_OP("DeleteSessionTensor") 1104 .Input("handle: string") 1105 .SetIsStateful() 1106 .SetShapeFn([](InferenceContext* c) { 1107 ShapeHandle unused; 1108 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 1109 return Status::OK(); 1110 }); 1111 1112 REGISTER_OP("Stage") 1113 .Input("values: dtypes") 1114 .Attr("capacity: int >= 0 = 0") 1115 .Attr("memory_limit: int >= 0 = 0") 1116 .Attr("dtypes: list(type)") 1117 .Attr("container: string = ''") 1118 .Attr("shared_name: string = ''") 1119 .SetShapeFn(shape_inference::UnknownShape) 1120 .SetIsStateful(); 1121 1122 REGISTER_OP("Unstage") 1123 .Output("values: dtypes") 1124 .Attr("capacity: int >= 0 = 0") 1125 .Attr("memory_limit: int >= 0 = 0") 1126 .Attr("dtypes: list(type)") 1127 .Attr("container: string = ''") 1128 .Attr("shared_name: string = ''") 1129 .SetShapeFn(shape_inference::UnknownShape) 1130 .SetIsStateful(); 1131 1132 REGISTER_OP("StagePeek") 1133 .Input("index: int32") 1134 .Output("values: dtypes") 1135 .Attr("capacity: int >= 0 = 0") 1136 .Attr("memory_limit: int >= 0 = 0") 1137 .Attr("dtypes: list(type)") 1138 .Attr("container: string = ''") 1139 .Attr("shared_name: string = ''") 1140 .SetShapeFn(shape_inference::UnknownShape) 1141 .SetIsStateful(); 1142 1143 REGISTER_OP("StageSize") 1144 .Output("size: int32") 1145 .Attr("capacity: int >= 0 = 0") 1146 .Attr("memory_limit: int >= 0 = 0") 1147 .Attr("dtypes: list(type)") 1148 .Attr("container: string = ''") 1149 .Attr("shared_name: string = ''") 1150 .SetShapeFn(shape_inference::ScalarShape) 1151 .SetIsStateful(); 1152 1153 REGISTER_OP("StageClear") 1154 .Attr("capacity: int >= 0 = 0") 1155 .Attr("memory_limit: int >= 0 = 0") 1156 .Attr("dtypes: list(type)") 1157 .Attr("container: string = ''") 1158 .Attr("shared_name: string = ''") 1159 .SetShapeFn(shape_inference::UnknownShape) 1160 .SetIsStateful(); 1161 1162 // UnorderedMap 1163 REGISTER_OP("MapStage") 1164 .Input("key: int64") 1165 .Input("indices: int32") 1166 .Input("values: fake_dtypes") 1167 .Attr("capacity: int >= 0 = 0") 1168 .Attr("memory_limit: int >= 0 = 0") 1169 .Attr("dtypes: list(type)") 1170 .Attr("fake_dtypes: list(type)") 1171 .Attr("container: string = ''") 1172 .Attr("shared_name: string = ''") 1173 .SetShapeFn(tensorflow::shape_inference::NoOutputs) 1174 .SetIsStateful(); 1175 1176 REGISTER_OP("MapPeek") 1177 .Input("key: int64") 1178 .Input("indices: int32") 1179 .Output("values: dtypes") 1180 .Attr("capacity: int >= 0 = 0") 1181 .Attr("memory_limit: int >= 0 = 0") 1182 .Attr("dtypes: list(type)") 1183 .Attr("container: string = ''") 1184 .Attr("shared_name: string = ''") 1185 .SetShapeFn(tensorflow::shape_inference::UnknownShape) 1186 .SetIsStateful(); 1187 1188 REGISTER_OP("MapUnstage") 1189 .Input("key: int64") 1190 .Input("indices: int32") 1191 .Output("values: dtypes") 1192 .Attr("capacity: int >= 0 = 0") 1193 .Attr("memory_limit: int >= 0 = 0") 1194 .Attr("dtypes: list(type)") 1195 .Attr("container: string = ''") 1196 .Attr("shared_name: string = ''") 1197 .SetShapeFn(tensorflow::shape_inference::UnknownShape) 1198 .SetIsStateful(); 1199 1200 REGISTER_OP("MapUnstageNoKey") 1201 .Input("indices: int32") 1202 .Output("key: int64") 1203 .Output("values: dtypes") 1204 .Attr("capacity: int >= 0 = 0") 1205 .Attr("memory_limit: int >= 0 = 0") 1206 .Attr("dtypes: list(type)") 1207 .Attr("container: string = ''") 1208 .Attr("shared_name: string = ''") 1209 .SetShapeFn(tensorflow::shape_inference::UnknownShape) 1210 .SetIsStateful(); 1211 1212 REGISTER_OP("MapSize") 1213 .Output("size: int32") 1214 .Attr("capacity: int >= 0 = 0") 1215 .Attr("memory_limit: int >= 0 = 0") 1216 .Attr("dtypes: list(type)") 1217 .Attr("container: string = ''") 1218 .Attr("shared_name: string = ''") 1219 .SetShapeFn(tensorflow::shape_inference::ScalarShape) 1220 .SetIsStateful(); 1221 1222 REGISTER_OP("MapIncompleteSize") 1223 .Output("size: int32") 1224 .Attr("capacity: int >= 0 = 0") 1225 .Attr("memory_limit: int >= 0 = 0") 1226 .Attr("dtypes: list(type)") 1227 .Attr("container: string = ''") 1228 .Attr("shared_name: string = ''") 1229 .SetShapeFn(tensorflow::shape_inference::ScalarShape) 1230 .SetIsStateful(); 1231 1232 REGISTER_OP("MapClear") 1233 .Attr("capacity: int >= 0 = 0") 1234 .Attr("memory_limit: int >= 0 = 0") 1235 .Attr("dtypes: list(type)") 1236 .Attr("container: string = ''") 1237 .Attr("shared_name: string = ''") 1238 .SetShapeFn(tensorflow::shape_inference::NoOutputs) 1239 .SetIsStateful(); 1240 1241 // OrderedMap 1242 REGISTER_OP("OrderedMapStage") 1243 .Input("key: int64") 1244 .Input("indices: int32") 1245 .Input("values: fake_dtypes") 1246 .Attr("capacity: int >= 0 = 0") 1247 .Attr("memory_limit: int >= 0 = 0") 1248 .Attr("dtypes: list(type)") 1249 .Attr("fake_dtypes: list(type)") 1250 .Attr("container: string = ''") 1251 .Attr("shared_name: string = ''") 1252 .SetShapeFn(tensorflow::shape_inference::NoOutputs) 1253 .SetIsStateful(); 1254 1255 REGISTER_OP("OrderedMapPeek") 1256 .Input("key: int64") 1257 .Input("indices: int32") 1258 .Output("values: dtypes") 1259 .Attr("capacity: int >= 0 = 0") 1260 .Attr("memory_limit: int >= 0 = 0") 1261 .Attr("dtypes: list(type)") 1262 .Attr("container: string = ''") 1263 .Attr("shared_name: string = ''") 1264 .SetShapeFn(tensorflow::shape_inference::UnknownShape) 1265 .SetIsStateful(); 1266 1267 REGISTER_OP("OrderedMapUnstage") 1268 .Input("key: int64") 1269 .Input("indices: int32") 1270 .Output("values: dtypes") 1271 .Attr("capacity: int >= 0 = 0") 1272 .Attr("memory_limit: int >= 0 = 0") 1273 .Attr("dtypes: list(type)") 1274 .Attr("container: string = ''") 1275 .Attr("shared_name: string = ''") 1276 .SetShapeFn(tensorflow::shape_inference::UnknownShape) 1277 .SetIsStateful(); 1278 1279 REGISTER_OP("OrderedMapUnstageNoKey") 1280 .Input("indices: int32") 1281 .Output("key: int64") 1282 .Output("values: dtypes") 1283 .Attr("capacity: int >= 0 = 0") 1284 .Attr("memory_limit: int >= 0 = 0") 1285 .Attr("dtypes: list(type)") 1286 .Attr("container: string = ''") 1287 .Attr("shared_name: string = ''") 1288 .SetShapeFn(tensorflow::shape_inference::UnknownShape) 1289 .SetIsStateful(); 1290 1291 REGISTER_OP("OrderedMapSize") 1292 .Output("size: int32") 1293 .Attr("capacity: int >= 0 = 0") 1294 .Attr("memory_limit: int >= 0 = 0") 1295 .Attr("dtypes: list(type)") 1296 .Attr("container: string = ''") 1297 .Attr("shared_name: string = ''") 1298 .SetShapeFn(tensorflow::shape_inference::ScalarShape) 1299 .SetIsStateful(); 1300 1301 REGISTER_OP("OrderedMapIncompleteSize") 1302 .Output("size: int32") 1303 .Attr("capacity: int >= 0 = 0") 1304 .Attr("memory_limit: int >= 0 = 0") 1305 .Attr("dtypes: list(type)") 1306 .Attr("container: string = ''") 1307 .Attr("shared_name: string = ''") 1308 .SetShapeFn(tensorflow::shape_inference::ScalarShape) 1309 .SetIsStateful(); 1310 1311 REGISTER_OP("OrderedMapClear") 1312 .Attr("capacity: int >= 0 = 0") 1313 .Attr("memory_limit: int >= 0 = 0") 1314 .Attr("dtypes: list(type)") 1315 .Attr("container: string = ''") 1316 .Attr("shared_name: string = ''") 1317 .SetShapeFn(tensorflow::shape_inference::NoOutputs) 1318 .SetIsStateful(); 1319 1320 REGISTER_OP("RecordInput") 1321 .Output("records: string") 1322 .Attr("file_pattern: string") 1323 .Attr("file_random_seed: int = 301") 1324 .Attr("file_shuffle_shift_ratio: float = 0") 1325 .Attr("file_buffer_size: int = 10000") 1326 .Attr("file_parallelism: int = 16") 1327 .Attr("batch_size: int = 32") 1328 .Attr("compression_type: string = ''") 1329 .SetIsStateful() 1330 .SetShapeFn(shape_inference::UnknownShape); 1331 1332 } // namespace tensorflow 1333