/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
                      int kheight, int stride_width, int stride_height,
                      PaddingType padding_type, Shape* output_shape,
                      FixedPadding* fixed_padding) {
  const int input_width = input_shape.dims(2);
  const int input_height = input_shape.dims(1);
  const int batch = input_shape.dims(0);

  int output_height = 0;
  int output_width = 0;
  if (padding_type == PaddingType::kValid) {
    output_height = (input_height + stride_height - kheight) / stride_height;
    output_width = (input_width + stride_width - kwidth) / stride_width;
  } else if (padding_type == PaddingType::kSame) {
    output_height = (input_height + stride_height - 1) / stride_height;
    output_width = (input_width + stride_width - 1) / stride_width;
  } else {
    LOG(FATAL) << "Only supporting SAME or VALID padding";
  }

  fixed_padding->height = std::max(
      0, ((output_height - 1) * stride_height + kheight - input_height) / 2);
  fixed_padding->width = std::max(
      0, ((output_width - 1) * stride_width + kwidth - input_width) / 2);

  // We once had to debug a situation where these were negative due to bad
  // propagation of placeholder -1 sizes in TensorFlowReshape.
  CHECK_GT(output_width, 0);
  CHECK_GT(output_height, 0);
  output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
}
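
// Illustrative example (hypothetical sizes, not from any real model): with
// input_height = 5, kheight = 3 and stride_height = 1, the formulas above
// give:
//   VALID: output_height = (5 + 1 - 3) / 1 = 3, fixed height padding
//          max(0, ((3 - 1) * 1 + 3 - 5) / 2) = 0;
//   SAME:  output_height = (5 + 1 - 1) / 1 = 5, fixed height padding
//          max(0, ((5 - 1) * 1 + 3 - 5) / 2) = 1.
// The width dimension is computed the same way.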

void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
                                     const Shape& input_shape_y,
                                     Array* output_array) {
  // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
  // It zips together the two input shapes and pads with 1 to make them the
  // same length. For each dimension we broadcast if either dimension is 1 and
  // otherwise expect them to match.
  int rank_x = input_shape_x.dimensions_count();
  int rank_y = input_shape_y.dimensions_count();
  int rank_out = std::max(rank_x, rank_y);
  std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
  dims_out->clear();
  dims_out->reserve(rank_out);
  for (int i = 0; i < rank_out; ++i) {
    int dim_x = i < (rank_out - rank_x)
                    ? 1
                    : input_shape_x.dims(i - (rank_out - rank_x));
    bool dim_y_is_one = i < (rank_out - rank_y);
    int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
    if (dim_x == -1 || dim_y == -1) {
      // One or both dimensions are unknown.
      QCHECK(false) << "Shapes must be specified";
    } else if (dim_x == 1 || dim_y == 1) {
      // Broadcast the dimension that is 1 to the other one.
      if (dim_x == 1 && !dim_y_is_one) {
        // Broadcast dim_x (1) up to dim_y.
        dims_out->push_back(dim_y);
      } else {
        // Broadcast dim_y (1) up to dim_x.
        DCHECK_EQ(dim_y, 1);
        dims_out->push_back(dim_x);
      }
    } else {
      // Expect the dimensions to match.
      CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
      dims_out->push_back(dim_x);
    }
  }
  CHECK(output_array->has_shape());
}
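
// Illustrative example (hypothetical shapes): broadcasting x = {8, 1, 3}
// against y = {4, 3} pads y to rank 3 as {1, 4, 3}. Dimension 0 keeps 8
// (y was padded with 1), dimension 1 broadcasts 1 against 4, and dimension 2
// requires the exact match 3 == 3, producing an output shape of {8, 4, 3}.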

int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
  const string& weights_name = op.inputs[1];
  const auto& weights_shape = model.GetArray(weights_name).shape();
  if (op.type == OperatorType::kConv ||
      op.type == OperatorType::kFullyConnected) {
    return weights_shape.dims(0);
  } else if (op.type == OperatorType::kDepthwiseConv) {
    return weights_shape.dims(3);
  } else {
    LOG(FATAL) << "Unhandled operator type";
  }
}

bool EnsureBiasVectorShape(Model* model, Operator* op) {
  const string& weights_name = op->inputs[1];
  const auto& weights_array = model->GetArray(weights_name);
  // Yield until the weights shape has been resolved.
  if (!weights_array.has_shape()) {
    return false;
  }

  if (op->inputs.size() < 3) {
    return false;
  }
  auto& bias_array = model->GetArray(op->inputs[2]);
  if (bias_array.has_shape()) {
    return true;
  }

  const int output_depth = GetOutputDepthFromWeights(*model, *op);
  bias_array.copy_shape(Shape({output_depth}));

  auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
  float_buffer.data.resize(output_depth, 0);

  return true;
}

void ProcessConvOperator(Model* model, ConvOperator* op) {
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  auto& output_array = model->GetArray(op->outputs[0]);
  const int output_depth = weights_shape.dims(0);
  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   output_array.mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
  CHECK_EQ(output_array.shape().dimensions_count(), 4);

  // Set im2col array dimensions if there is one.
  if (op->outputs.size() == 2) {
    const auto& output_shape = output_array.shape();
    const int input_depth = weights_shape.dims(3);
    auto& im2col_array = model->GetArray(op->outputs[1]);
    im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
                                  output_shape.dims(2),
                                  input_depth * kheight * kwidth});
  }
}

void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int input_depth = input_shape.dims(3);
  const int output_depth = weights_shape.dims(3);
  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv
  // ops; instead it has to be inferred from the weights dims. However, once
  // we are here, the weights dims have already been converted to our own
  // internal format, where the multiplier is no longer readily apparent. So
  // instead we get it as the quotient of the output and input depths. We only
  // do that when depth_multiplier still has the value zero: any other value
  // is checked by the QCHECK_EQ below.
  if (!op->depth_multiplier) {
    op->depth_multiplier = output_depth / input_depth;
  }
  QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
      << "input/output depths and depth_multiplier don't match";

  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(depth % (block_size * block_size), 0);

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height * block_size, width * block_size,
                         depth / block_size / block_size}));
}
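
// Illustrative example (hypothetical shapes): with block_size = 2, a
// {1, 4, 4, 16} input satisfies 16 % (2 * 2) == 0 and produces a
// {1, 8, 8, 4} output: height and width are multiplied by the block size
// while depth is divided by its square.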

void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(width % block_size, 0);
  QCHECK_EQ(height % block_size, 0);

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height / block_size, width / block_size,
                         depth * block_size * block_size}));
}

void ProcessFillOperator(Model* model, FillOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  auto& dims_array = model->GetArray(op->inputs[0]);
  if (!dims_array.has_shape()) {
    // Yield until the dims shape has been resolved.
    return;
  }
  if (!dims_array.buffer) {
    // Yield until the dims are constant.
    return;
  }
  CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
  CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 4)
      << "dims vector can be no larger than 4 values";

  std::vector<int32> const& dims =
      dims_array.GetBuffer<ArrayDataType::kInt32>().data;
  *(output_array.mutable_shape()->mutable_dims()) = dims;
}

void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 1);

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();

  const int weights_output_depth = weights_shape.dims(0);
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const int input_overall_size = RequiredBufferSizeForShape(input_shape);
  const int matmul_repeats = input_overall_size / weights_shape.dims(1);
  CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);

  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
}
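
// Illustrative example (hypothetical shapes): a {4, 10} input (overall size
// 40) with {8, 10} weights gives matmul_repeats = 40 / 10 = 4 and an output
// shape of {4, 8}. Note that the input only needs an overall size that is a
// multiple of the weights' second dimension; it need not be 2-D.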

void ProcessTensorFlowReshapeOperator(Model* model,
                                      TensorFlowReshapeOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();

  auto& shape_array = model->GetArray(op->inputs[1]);
  if (!shape_array.has_shape()) {
    // Yield until the target_shape's shape has been resolved.
    return;
  }
  if (!shape_array.buffer) {
    // Yield until the target_shape is constant.
    return;
  }
  CHECK(shape_array.data_type == ArrayDataType::kInt32)
      << "Reshape dims must be int32";

  // shape_data is the raw array of ints describing the shape
  // in the TensorFlow node. We intentionally make a copy here, rather than
  // modify wildcards in-place below, because in some graphs, the same shape
  // array with a wildcard may be referenced from multiple Reshape nodes, where
  // the wildcard needs to be resolved to distinct values.
  std::vector<int32> shape_data =
      shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  // The Reshape shape may have a wildcard dim, encoded as -1.
  bool has_wildcard = false;
  int wildcard_index = 0;
  int product_non_wildcard_dims = 1;
  for (int i = 0; i < shape_data.size(); i++) {
    if (shape_data[i] == -1) {
      CHECK(!has_wildcard);
      has_wildcard = true;
      wildcard_index = i;
    } else {
      product_non_wildcard_dims *= shape_data[i];
    }
  }
  const int input_flat_size = RequiredBufferSizeForShape(input_shape);
  if (has_wildcard) {
    CHECK_GE(input_flat_size, product_non_wildcard_dims)
        << "Array not large enough to fill the requested dimensions for "
           "Reshape op with output \""
        << op->outputs[0] << "\". Are your input shapes correct?";
    shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
  }
  auto& output_shape = *output_array.mutable_shape();
  *output_shape.mutable_dims() = shape_data;
  CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
      << "Input cannot be reshaped to requested dimensions for Reshape op with "
         "output \""
      << op->outputs[0] << "\". Are your input shapes correct?";
}
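
// Illustrative example (hypothetical shapes): reshaping a {2, 3, 4} input
// (flat size 24) with shape_data = {2, -1, 3} computes
// product_non_wildcard_dims = 2 * 3 = 6, resolves the wildcard to
// 24 / 6 = 4, and produces an output shape of {2, 4, 3}.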

void ProcessSimpleOperator(Model* model, Operator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  const string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  if (output_array.has_shape()) {
    return;
  }

  output_array.copy_shape(input_array.shape());
}

void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const auto& input0_array = model->GetArray(op->inputs[0]);
  const auto& input1_array = model->GetArray(op->inputs[1]);
  // Yield until input dims have been resolved.
  if (!input0_array.has_shape() || !input1_array.has_shape()) {
    return;
  }
  const string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
                                  &output_array);
}

void ProcessAddNOperator(Model* model, Operator* op) {
  // Yield until all input dims have been resolved.
  //
  // TODO(myenik): Since AddN does not support broadcasting, maybe we could
  // actually use this to improve shape propagation by propagating the shape
  // of one input to all other inputs once it is resolved instead of just the
  // output, since all inputs must be the same size and shape for a
  // well-formed graph.
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (!input_array.has_shape()) {
      return;
    }
  }

  // AddN does not support broadcasting: all inputs must have the same shape,
  // so we just take the first input shape and apply it to the output.
  const auto& input0_array = model->GetArray(op->inputs[0]);
  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(input0_array.shape());
}

bool KeepDims(const Operator& op) {
  switch (op.type) {
    case OperatorType::kTensorFlowMin:
      return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
    case OperatorType::kTensorFlowMax:
      return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
    case OperatorType::kTensorFlowSum:
      return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
    case OperatorType::kMean:
      return static_cast<const MeanOperator&>(op).keep_dims;
    default:
      LOG(FATAL) << "Not a reduction operator!";
      return false;
  }
}

void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
  CHECK_LE(op->inputs.size(), 2);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    return;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const bool keep_dims = KeepDims(*op);
  if (op->inputs.size() == 2) {
    // There is a reduction_indices input.
    const auto& reduction_array = model->GetArray(op->inputs[1]);
    if (!reduction_array.buffer) {
      return;
    }
    CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
    const auto& reduction_array_vals =
        reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
    auto& output_dims = *output_array.mutable_shape()->mutable_dims();
    output_dims.clear();
    for (int i = 0; i < input_shape.dimensions_count(); i++) {
      bool is_reduction_dim = false;
      for (int r : reduction_array_vals) {
        if (i == r) {
          is_reduction_dim = true;
        }
      }
      if (!is_reduction_dim) {
        output_dims.push_back(input_shape.dims(i));
      } else if (keep_dims) {
        output_dims.push_back(1);
      }
    }
  } else {
    // No reduction_indices means complete reduction to a single scalar.
    if (keep_dims) {
      output_array.copy_shape(input_shape);
    } else {
      output_array.copy_shape(Shape({}));
    }
  }
}
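
// Illustrative example (hypothetical shapes): reducing a {2, 3, 4} input
// over reduction_indices = {1} yields an output shape of {2, 1, 4} with
// keep_dims, and {2, 4} without it. With no reduction_indices input and no
// keep_dims, the output is a rank-0 scalar.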

void ProcessSliceOperator(Model* model, SliceOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  CHECK_EQ(op->outputs.size(), 1);

  // Yield until the Slice params have been resolved.
  if (op->begin.empty()) return;

  // Yield until input dims have been resolved.
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) return;
  const Shape& input_shape = input_array.shape();

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  CHECK_EQ(input_shape.dims().size(), op->size.size());
  CHECK_EQ(op->begin.size(), op->size.size());

  std::vector<int> output_dims;
  for (int i = 0; i < op->begin.size(); ++i) {
    int size = op->size[i];
    if (size == -1) {
      size = input_array.shape().dims(i) - op->begin[i];
    }
    output_dims.push_back(size);
  }

  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const string& output_name = op->outputs[0];
  Shape* output_shape = model->GetArray(output_name).mutable_shape();
  ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
              output_shape);
}

void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
  // Yield until input dims have been resolved.
  for (const auto& input_name : op->inputs) {
    auto& input_array = model->GetArray(input_name);
    if (!input_array.has_shape()) {
      return;
    }
  }
  auto& output_array = model->GetArray(op->outputs[0]);
  // Use the first input as the basis for the output dimensions.
  const auto& first_input_array = model->GetArray(op->inputs[0]);
  output_array.copy_shape(first_input_array.shape());
  // A negative axis means counting starts from the back of the dims().
  int axis = op->axis;
  if (axis < 0) axis += first_input_array.shape().dims().size();
  // Determine the concat size, and enforce that all inputs have
  // the same dimensions count.
  int concat_size = 0;
  for (const auto& input_name : op->inputs) {
    auto& input_array = model->GetArray(input_name);
    CHECK(input_array.has_shape());
    if (input_array.shape().dimensions_count() == 0) {
      continue;
    }
    CHECK_EQ(input_array.shape().dimensions_count(),
             output_array.shape().dimensions_count());
    const std::vector<int>& input_dims = input_array.shape().dims();
    CHECK_LT(axis, input_dims.size());
    concat_size += input_dims[axis];
  }
  // Write out the concat_size on the output array shape.
  auto& output_shape = *output_array.mutable_shape();
  auto& output_dims = *output_shape.mutable_dims();
  CHECK_LT(axis, output_shape.dimensions_count());
  output_dims[axis] = concat_size;
}
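
// Illustrative example (hypothetical shapes): concatenating {1, 8, 8, 3} and
// {1, 8, 8, 5} along axis = 3 (or equivalently axis = -1) sums the depths to
// concat_size = 8 and produces an output shape of {1, 8, 8, 8}; all other
// dimensions are taken from the first input.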

void ProcessRangeOperator(Model* model, RangeOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  const auto& start_array = model->GetArray(op->inputs[0]);
  if (!start_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& limit_array = model->GetArray(op->inputs[1]);
  if (!limit_array.has_shape()) {
    return;
  }
  const auto& delta_array = model->GetArray(op->inputs[2]);
  if (!delta_array.has_shape()) {
    return;
  }

  if (!IsConstantParameterArray(*model, op->inputs[0])) {
    // Yield until inputs are constant.
    return;
  }
  if (!IsConstantParameterArray(*model, op->inputs[1])) {
    return;
  }
  if (!IsConstantParameterArray(*model, op->inputs[2])) {
    return;
  }

  CHECK(start_array.data_type == ArrayDataType::kInt32)
      << "Range op inputs must be int32.";
  CHECK(limit_array.data_type == ArrayDataType::kInt32)
      << "Range op inputs must be int32.";
  CHECK(delta_array.data_type == ArrayDataType::kInt32)
      << "Range op inputs must be int32.";
  CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
      << "Range op inputs must be scalar.";
  CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
      << "Range op inputs must be scalar.";
  CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
      << "Range op inputs must be scalar.";
  int size = floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
                    start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
                   delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);

  // Only set the output shape. Contents are set by ResolveConstantRange.
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({size});
}

void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const string& input_name = op->inputs[1];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const Shape& input_shape = input_array.shape();

  // Yield until axis is constant.
  if (!IsConstantParameterArray(*model, op->inputs[0])) {
    return;
  }

  const auto& axis_array = model->GetArray(op->inputs[0]);

  // Yield until axis dims have been resolved.
  if (!axis_array.has_shape()) {
    return;
  }

  CHECK(axis_array.data_type == ArrayDataType::kInt32)
      << "Axis array must be int32.";
  CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
      << "Axis array must be scalar.";

  int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
  if (axis < 0) {
    axis += input_shape.dimensions_count();
  }

  const int split_dim = input_shape.dims(axis);
  CHECK_EQ(split_dim % op->num_split, 0);
  const int split_depth = split_dim / op->num_split;

  Shape output_shape = input_shape;
  (*output_shape.mutable_dims())[axis] = split_depth;

  CHECK_EQ(op->outputs.size(), op->num_split);
  for (const auto& output : op->outputs) {
    model->GetArray(output).copy_shape(output_shape);
  }
}
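
// Illustrative example (hypothetical shapes): splitting a {2, 12} input with
// a constant axis of 1 and num_split = 3 gives split_depth = 12 / 3 = 4, so
// each of the three outputs gets shape {2, 4}.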

void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  if (input_shape.dimensions_count() < 4) {
    LOG(FATAL) << "missing dimensions for " << input_name;
  }
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  if (!model->GetArray(op->inputs[0]).has_shape() ||
      !model->GetArray(op->inputs[1]).has_shape()) {
    return;
  }
  const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();

  const string& output_size_name = op->inputs[1];
  const auto& output_size_array = model->GetArray(output_size_name);
  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
  CHECK(output_size_array.has_shape());
  const auto& output_size_shape = output_size_array.shape();
  CHECK_EQ(output_size_shape.dimensions_count(), 1);
  CHECK_EQ(output_size_shape.dims(0), 2);
  if (!output_size_array.buffer) {
    return;
  }
  std::vector<int32> output_shape =
      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
                         output_shape[1], input_data_shape.dims(3)}));
}
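
// Illustrative example (hypothetical shapes): resizing a {1, 16, 16, 3}
// input with a constant output size of {32, 32} keeps the batch and depth
// dimensions and produces an output shape of {1, 32, 32, 3}.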

void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
  // Only required for compact LstmCell with default NUM_INPUTS of inputs.
  if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;

  const auto& input_array =
      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
  // Yield until all input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 2);

  const auto& prev_activ_array =
      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
  // Yield until all input dims have been resolved.
  if (!prev_activ_array.has_shape()) {
    return;
  }
  const auto& prev_activ_shape = prev_activ_array.shape();
  CHECK_GE(prev_activ_shape.dimensions_count(), 2);

  const auto& weights_array =
      model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const auto& bias_array =
      model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
  // Yield until bias dims have been resolved.
  if (!bias_array.has_shape()) {
    return;
  }
  const auto& bias_shape = bias_array.shape();
  CHECK_GE(bias_shape.dimensions_count(), 1);

  const auto& prev_state_array =
      model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
  // Yield until all input dims have been resolved.
  if (!prev_state_array.has_shape()) {
    return;
  }
  const auto& prev_state_shape = prev_state_array.shape();
  CHECK_GE(prev_state_shape.dimensions_count(), 2);

  const int fc_output_depth = weights_shape.dims(0);
  CHECK_EQ(fc_output_depth, bias_shape.dims(0));
  CHECK_EQ(fc_output_depth % 4, 0);
  const int depth = fc_output_depth / 4;

  const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
  const int fc_input_depth = weights_shape.dims(1);
  CHECK_EQ(input_depth + depth, fc_input_depth);
  Shape output_shape(input_shape);
  (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;

  // Set output dimensions.
  model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
      .copy_shape(output_shape);
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
      .copy_shape(output_shape);

  Shape concat_temp_shape(input_shape);
  (*concat_temp_shape
        .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
      fc_input_depth;
  model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
      .copy_shape(concat_temp_shape);

  Shape activ_temp_shape(input_shape);
  (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
      fc_output_depth;
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
      .copy_shape(activ_temp_shape);
}
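
// Illustrative example (hypothetical shapes): with a {1, 20} input, {64, 36}
// weights and a {64} bias, fc_output_depth = 64, so depth = 64 / 4 = 16 and
// fc_input_depth = 20 + 16 = 36 as required. The state and activation
// outputs get shape {1, 16}, CONCAT_TEMP gets {1, 36} and ACTIV_TEMP gets
// {1, 64}.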

void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  // This method only handles 4-dimensional inputs.
  if (input_shape.dimensions_count() != 4) {
    return;
  }
  const auto input_height = input_shape.dims(1);
  const auto input_width = input_shape.dims(2);

  const auto& block_shape_array = model->GetArray(op->inputs[1]);
  const auto& paddings_array = model->GetArray(op->inputs[2]);
  const auto& block_shape_array_shape = block_shape_array.shape();
  const auto& paddings_array_shape = paddings_array.shape();
  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
  QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);

  // We only support two block dimensions.
  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
  if (!block_shape_array.buffer) {
    return;
  }
  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
  const auto& block_shape_data =
      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto block_height = block_shape_data[0];
  auto block_width = block_shape_data[1];

  QCHECK_EQ(paddings_array_shape.dims(0), 2);  // Number of block dimensions.
  QCHECK_EQ(paddings_array_shape.dims(1), 2);  // Two parameters per dimension.
  if (!paddings_array.buffer) {
    return;
  }
  QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
  const auto& paddings_data =
      paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
  int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
  int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
  QCHECK_EQ(height_with_paddings % block_height, 0);
  QCHECK_EQ(width_with_paddings % block_width, 0);
  int output_height = height_with_paddings / block_height;
  int output_width = width_with_paddings / block_width;

  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
                         output_height, output_width, input_shape.dims(3)}));
}
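
// Illustrative example (hypothetical shapes): a {1, 10, 8, 3} input with
// block shape {2, 2} and paddings {{1, 1}, {0, 0}} gives a padded height of
// 12 and a padded width of 8, hence an output shape of
// {1 * 2 * 2, 12 / 2, 8 / 2, 3} = {4, 6, 4, 3}.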

void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const auto input_height = input_shape.dims(1);
  const auto input_width = input_shape.dims(2);

  const auto& block_shape_array = model->GetArray(op->inputs[1]);
  const auto& crops_array = model->GetArray(op->inputs[2]);
  const auto& block_shape_array_shape = block_shape_array.shape();
  const auto& crops_array_shape = crops_array.shape();
  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
  QCHECK_EQ(crops_array_shape.dimensions_count(), 2);

  // We only support two block dimensions.
  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
  if (!block_shape_array.buffer) {
    return;
  }
  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
  const auto& block_shape_data =
      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto block_height = block_shape_data[0];
  auto block_width = block_shape_data[1];

  QCHECK_EQ(crops_array_shape.dims(0), 2);  // Number of block dimensions.
  QCHECK_EQ(crops_array_shape.dims(1), 2);  // Two parameters per dimension.
  if (!crops_array.buffer) {
    return;
  }
  QCHECK(crops_array.data_type == ArrayDataType::kInt32);
  const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
  // We don't support crops yet.
  QCHECK_EQ(crops_data[0], 0);
  QCHECK_EQ(crops_data[1], 0);
  QCHECK_EQ(crops_data[2], 0);
  QCHECK_EQ(crops_data[3], 0);

  QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);

  int output_height = input_height * block_height;
  int output_width = input_width * block_width;

  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
                         output_height, output_width, input_shape.dims(3)}));
}

void ProcessGatherOperator(Model* model, GatherOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  const auto& indices_array = model->GetArray(op->inputs[1]);
  auto& output_array = model->GetArray(op->outputs[0]);

  // Bail if we already know the output shape.
  if (output_array.has_shape()) {
    return;
  }

  // Yield until input dims have been resolved.
  if (!input_array.has_shape() || !indices_array.has_shape()) {
    return;
  }

  const auto& input_shape = input_array.shape();
  const auto& indices_shape = indices_array.shape();
  QCHECK_GE(input_shape.dimensions_count(), 1);
  op->input_rank = input_shape.dimensions_count();

  // We only support 1-D indices.
  QCHECK_EQ(indices_shape.dimensions_count(), 1);

  // Copy the input dimensions to the output except for dimension 0,
  // where the dimension of indices_shape is used.
  // TODO(mgubin): if axis != 0 this is not true, change when it's supported.
  auto output_dims = output_array.mutable_shape()->mutable_dims();
  output_dims->push_back(indices_shape.dims(0));
  for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
    output_dims->push_back(input_shape.dims(dim));
  }
}

void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
  const auto& input_values = model->GetArray(op->inputs[0]);
  const auto& input_k = model->GetArray(op->inputs[1]);
  auto& output_indexes = model->GetArray(op->outputs[0]);
  auto& output_values = model->GetArray(op->outputs[1]);

  // Bail if we already know the output shape.
  if (output_indexes.has_shape()) {
    QCHECK(output_values.has_shape());
    return;
  }

  // Yield until input dims have been resolved.
  if (!input_values.has_shape()) {
    return;
  }

  const auto& input_values_shape = input_values.shape();
  auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
  auto output_values_dims = output_values.mutable_shape()->mutable_dims();
  for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
    output_indexes_dims->push_back(input_values_shape.dims(dim));
    output_values_dims->push_back(input_values_shape.dims(dim));
  }
  // If the k value is known at this point, we can fill in the last dimension;
  // otherwise it remains unknown until k becomes constant.
  if (input_k.buffer) {
    const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
    output_indexes_dims->push_back(k_value);
    output_values_dims->push_back(k_value);
  } else {
    output_indexes_dims->push_back(0);
    output_values_dims->push_back(0);
  }
}
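
// Illustrative example (hypothetical shapes): for {3, 50} input values with
// a constant k = 10, both outputs get shape {3, 10}. If k is not yet
// constant, the last dimension is left as 0 until it is known.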

void ProcessPadOperator(Model* model, PadOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = model->GetArray(op->inputs[0]);

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  if (op->left_padding.empty()) return;
  CHECK_EQ(op->left_padding.size(), op->right_padding.size());

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  Shape output_shape = input_array.shape();
  std::vector<int>& dims = *output_shape.mutable_dims();
  CHECK_EQ(op->left_padding.size(), dims.size());

  for (int i = 0; i < op->left_padding.size(); ++i) {
    dims[i] += op->left_padding[i] + op->right_padding[i];
  }

  output_array.copy_shape(output_shape);
}

void ProcessRankOperator(Model* model, RankOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }

  // Only set the output shape. Array contents are set by
  // ResolveConstantShapeOrRank.
  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({});
}

void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }

  // Only set the output shape. Array contents are set by
  // ResolveConstantShapeOrRank.
  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({input_array.shape().dimensions_count()});
}
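
// Illustrative example (hypothetical shapes): for a {3, 4, 5} input, Rank
// produces a rank-0 scalar output, while Shape produces a {3} output (one
// entry per input dimension); in both cases the contents are filled in later
// by ResolveConstantShapeOrRank.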

void ProcessStackOperator(Model* model, StackOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  std::unique_ptr<Shape> stacked_shape;
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (!input_array.has_shape()) {
      // Yield until all input dims have been resolved.
      return;
    }

    Shape shape = input_array.shape();
    if (shape.dimensions_count() == 0) {
      // Convert 0D scalars to 1D scalars of shape {1}.
      shape.mutable_dims()->push_back(1);
    }
    if (!stacked_shape) {
      stacked_shape.reset(new Shape(shape));
    } else {
      CHECK(*stacked_shape == shape) << "All input arrays to Stack operators "
                                        "must have the same shape. Input \""
                                     << input << "\" is different.";
    }
  }

  int axis = op->axis;
  if (axis < 0) {
    // Handle negative axis.
    axis += stacked_shape->dims().size() + 1;
  }
  stacked_shape->mutable_dims()->insert(
      stacked_shape->mutable_dims()->begin() + axis, op->inputs.size());
  output_array.copy_shape(*stacked_shape);
}
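
// Illustrative example (hypothetical shapes): stacking three {2, 3} inputs
// with axis = 0 produces a {3, 2, 3} output; with axis = -1 the axis
// resolves to 2 and the output shape is {2, 3, 3}.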

void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (op->start_indices.empty() || op->stop_indices.empty() ||
      op->strides.empty()) {
    // ResolveStridedSliceAttributes has not run yet.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }

  if (op->ellipsis_mask != 0) {
    // Something like LOG_FIRST_N(WARNING, 10) would be preferable to reduce
    // log noise. However, the TensorFlow logging library does not appear to
    // support this.
    LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
                 << "\". ellipsis_mask is not supported (mask="
                 << op->ellipsis_mask << ")";
    return;
  }
  if (op->new_axis_mask != 0) {
    LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
                 << "\". new_axis_mask is not supported (mask="
                 << op->new_axis_mask << ")";
    return;
  }

  int dim_count = input_array.shape().dimensions_count();
  CHECK(op->start_indices.size() == dim_count)
      << ": Incorrect number of start indices supplied to StridedSlice op "
         "with output \""
      << op->outputs[0] << "\". Op requires " << dim_count << " start indices";
  CHECK(op->stop_indices.size() == dim_count)
      << ": Incorrect number of stop indices supplied to StridedSlice op "
         "with output \""
      << op->outputs[0] << "\". Op requires " << dim_count << " stop indices";
  CHECK(op->strides.size() == dim_count)
      << ": Incorrect number of strides supplied to StridedSlice op with "
         "output \""
      << op->outputs[0] << "\". Op requires " << dim_count << " strides";

  // Create the output shape.
  std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();

  // Compute the output shape, one dimension at a time.
  for (int i = 0; i < dim_count; ++i) {
    const int mask = 1 << i;
    int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
    if (start < 0) {
      // Handle negative indices.
      start += input_array.shape().dims(i);
    }
    int stop = (op->end_mask & mask) ? input_array.shape().dims(i)
                                     : op->stop_indices[i];
    if (stop < 0) {
      // Handle negative indices.
      stop += input_array.shape().dims(i);
    }

    int dim_size = ceil((stop - start) / static_cast<float>(op->strides[i]));
    dim_size = dim_size < 0 ? 0 : dim_size;
    if (op->shrink_axis_mask & mask) {
      CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 "
                               "when shrinking that axis";
    } else {
      dims->push_back(dim_size);
    }
  }
}
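
// Illustrative example (hypothetical attributes): slicing a {10} input with
// start 2, stop 8 and stride 2 gives dim_size = ceil((8 - 2) / 2.0) = 3,
// hence a {3} output. If the begin_mask bit for that axis were set, start
// would be forced to 0 and dim_size would become ceil(8 / 2.0) = 4.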

void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
  CHECK_EQ(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = model->GetArray(op->inputs[0]);

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  const std::vector<int>& input_dims = input_array.shape().dims();
  std::vector<int> output_dims;

  for (int i = 0; i < input_dims.size(); ++i) {
    if (input_dims[i] != 1 ||
        (!op->squeeze_dims.empty() &&
         std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) ==
             op->squeeze_dims.end())) {
      output_dims.push_back(input_dims[i]);
    }
  }
  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
  CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) return;

  auto& weights_feature_array = model->GetArray(op->inputs[1]);
  if (!weights_feature_array.has_shape()) return;

  const auto& weights_time_array = model->GetArray(op->inputs[2]);
  if (!weights_time_array.has_shape()) return;

  const bool has_bias = (op->inputs.size() == 4);
  if (has_bias) {
    const auto& bias_array = model->GetArray(op->inputs[3]);
    if (!bias_array.has_shape()) return;
  }

  const int batch_size = input_array.shape().dims()[0];
  const int num_units = weights_feature_array.shape().dims()[0];
  const int memory_size = weights_time_array.shape().dims()[1];

  auto& state_array = model->GetArray(op->outputs[0]);
  state_array.mutable_shape()->ReplaceDims(
      {batch_size, memory_size * num_units});

  auto& output_array = model->GetArray(op->outputs[1]);
  output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
}

void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();

  auto& perm_array = model->GetArray(op->inputs[1]);
  if (!perm_array.has_shape()) {
    // Yield until the permutation shape has been resolved.
    return;
  }
  if (!perm_array.buffer) {
    // Yield until the permutation is constant.
    return;
  }
  CHECK(perm_array.data_type == ArrayDataType::kInt32)
      << "Transpose permutation input must be int32";

  std::vector<int32> const& perm =
      perm_array.GetBuffer<ArrayDataType::kInt32>().data;
  CHECK_EQ(perm.size(), input_shape.dimensions_count())
      << "Transpose permutation input " << op->inputs[1]
      << " must have the same length as the input has dimensions";
  std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
  for (int i = 0; i < perm.size(); i++) {
    int axis = perm[i];
    CHECK_GE(axis, 0);
    CHECK_LT(axis, input_shape.dimensions_count());
    output_dims->push_back(input_shape.dims(axis));
  }
}
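
// Illustrative example (hypothetical shapes): transposing a {1, 2, 3} input
// with a constant permutation {2, 0, 1} picks input dimensions 2, 0 and 1 in
// that order, producing an output shape of {3, 1, 2}.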

void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  // The current ArgMax implementation only supports 4-dimensional inputs,
  // with ArgMax performed over the last dimension.
  const std::vector<int>& input_dims = input_array.shape().dims();
  CHECK_EQ(input_dims.size(), 4);
  std::vector<int> output_dims;

  output_dims.reserve(input_dims.size() - 1);
  for (int i = 0; i < input_dims.size() - 1; ++i) {
    output_dims.push_back(input_dims[i]);
  }
  output_dims.push_back(1);
  const string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  if (output_array.has_shape()) {
    return;
  }
  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

}  // namespace

bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
  auto it = model->operators.begin() + op_index;
  auto* op = it->get();
  std::unordered_map<string, std::vector<int>> old_output_dims;
  for (const auto& output : op->outputs) {
    if (model->GetArray(output).has_shape()) {
      old_output_dims[output] = model->GetArray(output).shape().dims();
    }
  }

  switch (op->type) {
    case OperatorType::kBatchNormalization:
    case OperatorType::kL2Normalization:
    case OperatorType::kDequantize:
    case OperatorType::kRelu:
    case OperatorType::kRelu1:
    case OperatorType::kRelu6:
    case OperatorType::kSoftmax:
    case OperatorType::kLogSoftmax:
    case OperatorType::kLogistic:
    case OperatorType::kTanh:
    case OperatorType::kLocalResponseNormalization:
    case OperatorType::kTensorFlowIdentity:
    case OperatorType::kFakeQuant:
    case OperatorType::kNeg:
    case OperatorType::kTensorFlowRsqrt:
    case OperatorType::kTensorFlowSqrt:
    case OperatorType::kTensorFlowSquare:
    case OperatorType::kTensorFlowAll:
    case OperatorType::kTensorFlowAssert:
    case OperatorType::kCast:
    case OperatorType::kFloor:
    case OperatorType::kExp:
      ProcessSimpleOperator(model, op);
      break;
    case OperatorType::kGather:
      ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
      break;
    case OperatorType::kTopK_V2:
      ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
      break;
    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul:
    case OperatorType::kDiv:
    case OperatorType::kFloorDiv:
    case OperatorType::kFloorMod:
    case OperatorType::kTensorFlowLess:
    case OperatorType::kTensorFlowLessEqual:
    case OperatorType::kTensorFlowGreater:
    case OperatorType::kTensorFlowMaximum:
    case OperatorType::kTensorFlowMinimum:
    case OperatorType::kTensorFlowGreaterEqual:
      ProcessSimpleBinaryOperator(model, op);
      break;
    case OperatorType::kAddN:
      ProcessAddNOperator(model, op);
      break;
    case OperatorType::kConv:
      ProcessConvOperator(model, static_cast<ConvOperator*>(op));
      break;
    case OperatorType::kTransposeConv:
      // Unimplemented, hopefully another graph transformation will drop it or
      // rewrite it.
      break;
    case OperatorType::kDepthwiseConv:
      ProcessDepthwiseConvOperator(model,
                                   static_cast<DepthwiseConvOperator*>(op));
      break;
    case OperatorType::kDepthToSpace:
      ProcessDepthToSpaceOperator(model,
                                  static_cast<DepthToSpaceOperator*>(op));
      break;
    case OperatorType::kSpaceToDepth:
      ProcessSpaceToDepthOperator(model,
                                  static_cast<SpaceToDepthOperator*>(op));
      break;
    case OperatorType::kFill:
      ProcessFillOperator(model, static_cast<FillOperator*>(op));
      break;
    case OperatorType::kFullyConnected:
      ProcessFullyConnectedOperator(model,
                                    static_cast<FullyConnectedOperator*>(op));
      break;
    case OperatorType::kTensorFlowReshape:
      ProcessTensorFlowReshapeOperator(
          model, static_cast<TensorFlowReshapeOperator*>(op));
      break;
    case OperatorType::kAveragePool:
      ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
      break;
    case OperatorType::kMaxPool:
      ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
      break;
    case OperatorType::kL2Pool:
      ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
      break;
    case OperatorType::kTensorFlowMin:
    case OperatorType::kTensorFlowMax:
    case OperatorType::kTensorFlowSum:
    case OperatorType::kMean:
      ProcessTensorFlowReductionOperator(model, op);
      break;

    case OperatorType::kSlice:
      ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
      break;

    case OperatorType::kTensorFlowTile:
      // We don't currently implement the propagation of fixed sizes through
      // a TensorFlow Tile.
      //
      // Fortunately, we don't need to: so far, we have only dealt with Tile
      // or Slice ops in subgraphs that are identified as L2Normalization.
      // See IdentifyL2Normalization.
      break;
    case OperatorType::kTensorFlowSwitch:
      // We can't know the sizes of the outputs until we have resolved the
      // predicate, and once we have resolved the predicate, the whole
      // Switch node will get resolved away.
      // See ResolveTensorFlowSwitch.
      break;
    case OperatorType::kTensorFlowMerge:
      // No need to bother resolving TensorFlow Merge ops: other graph
      // transformations will remove them anyway.
      // See ResolveTensorFlowMerge.
      break;
    case OperatorType::kTensorFlowSplit:
      ProcessTensorFlowSplitOperator(model,
                                     static_cast<TensorFlowSplitOperator*>(op));
      break;
    case OperatorType::kSqueeze:
      ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
      break;
    case OperatorType::kTensorFlowConcat:
    case OperatorType::kTensorFlowConcatV2:
      // Unimplemented, hopefully another graph transformation will
      // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
      // will resolve this node to a DepthConcatenation, or else we have
      // a more general non-depth concatenation that will hopefully be dropped,
      // or else at the moment we will abort.
      break;
    case OperatorType::kExpandDims:
      // Yield until ExpandDims is converted to Reshape.
      break;
    case OperatorType::kRange:
      ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
      break;
    case OperatorType::kRank:
      ProcessRankOperator(model, static_cast<RankOperator*>(op));
      break;
    case OperatorType::kTensorFlowShape:
      ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
      break;
    case OperatorType::kStack:
      ProcessStackOperator(model, static_cast<StackOperator*>(op));
      break;
    case OperatorType::kReorderAxes:
      ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
      break;
    case OperatorType::kConcatenation:
      ProcessConcatenationOperator(model,
                                   static_cast<ConcatenationOperator*>(op));
      break;
    case OperatorType::kResizeBilinear:
      ProcessResizeBilinearOperator(model,
                                    static_cast<ResizeBilinearOperator*>(op));
      break;
    case OperatorType::kLstmCell:
      ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
      break;
    case OperatorType::kBatchMatMul:
    case OperatorType::kTensorFlowMatMul:
      // MatMul operators are converted to FullyConnected, after which their
      // shapes are propagated.
      break;
    case OperatorType::kSpaceToBatchND:
      ProcessSpaceToBatchNDOperator(model,
                                    static_cast<SpaceToBatchNDOperator*>(op));
      break;
    case OperatorType::kBatchToSpaceND:
      ProcessBatchToSpaceNDOperator(model,
                                    static_cast<BatchToSpaceNDOperator*>(op));
      break;
    case OperatorType::kPad:
      ProcessPadOperator(model, static_cast<PadOperator*>(op));
      break;
    case OperatorType::kStridedSlice:
      ProcessStridedSliceOperator(model,
                                  static_cast<StridedSliceOperator*>(op));
      break;
    case OperatorType::kArgMax:
      ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
      break;
    case OperatorType::kTensorFlowUnsupported:
      break;
    case OperatorType::kSvdf:
      ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
      break;
    case OperatorType::kTranspose:
      ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
      break;
    default:
      // Unimplemented, another graph transformation should drop it.
      LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
  }

  // Return true if any output dim changed, false if none changed.
  // Assumption: no transformation clears an output shape; they only add
  // shapes.
  for (const auto& output : op->outputs) {
    if (model->GetArray(output).has_shape() &&
        (old_output_dims[output] != model->GetArray(output).shape().dims())) {
      AddMessageF("Set shape of %s to [%s]", output,
                  absl::StrJoin(model->GetArray(output).shape().dims(), ","));
      return true;
    }
  }
  return false;
}

}  // namespace toco