1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """A library of common shape functions.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 import six.moves 22 23 from tensorflow.python import pywrap_tensorflow 24 from tensorflow.python.framework import cpp_shape_inference_pb2 25 from tensorflow.python.framework import errors 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import tensor_shape 28 from tensorflow.python.framework import tensor_util 29 30 31 def scalar_shape(unused_op): 32 """Shape function for ops that output a scalar value.""" 33 return [tensor_shape.scalar()] 34 35 36 def unchanged_shape(op): 37 """Shape function for ops that output a tensor like their first input.""" 38 return [op.inputs[0].get_shape()] 39 40 41 def unchanged_shape_with_rank(rank): 42 """Returns a shape function for ops that constrain the rank of their input. 43 44 Args: 45 rank: The exact rank of the input and output. 46 47 Returns: 48 A shape function for ops that output a tensor of the same size as their 49 input, with a particular rank. 50 """ 51 52 def _ShapeFunction(op): 53 return [op.inputs[0].get_shape().with_rank(rank)] 54 55 return _ShapeFunction 56 57 58 def unchanged_shape_with_rank_at_least(rank): 59 """Returns a shape function for ops that constrain the rank of their input. 60 61 Args: 62 rank: A lower bound on the rank of the input and output. 63 64 Returns: 65 A shape function for ops that output a tensor of the same size as their 66 input, with a particular rank. 67 """ 68 69 def _ShapeFunction(op): 70 return [op.inputs[0].get_shape().with_rank_at_least(rank)] 71 72 return _ShapeFunction 73 74 75 def unchanged_shape_with_rank_at_most(rank): 76 """Returns a shape function for ops that constrain the rank of their input. 77 78 Args: 79 rank: An upper bound on the rank of the input and output. 80 81 Returns: 82 A shape function for ops that output a tensor of the same size as their 83 input, with a particular rank. 84 """ 85 86 def _ShapeFunction(op): 87 return [op.inputs[0].get_shape().with_rank_at_most(rank)] 88 89 return _ShapeFunction 90 91 92 def matmul_shape(op): 93 """Shape function for a MatMul op.""" 94 a_shape = op.inputs[0].get_shape().with_rank(2) 95 transpose_a = op.get_attr("transpose_a") 96 b_shape = op.inputs[1].get_shape().with_rank(2) 97 transpose_b = op.get_attr("transpose_b") 98 output_rows = a_shape[1] if transpose_a else a_shape[0] 99 output_cols = b_shape[0] if transpose_b else b_shape[1] 100 inner_a = a_shape[0] if transpose_a else a_shape[1] 101 inner_b = b_shape[1] if transpose_b else b_shape[0] 102 inner_a.assert_is_compatible_with(inner_b) 103 return [tensor_shape.TensorShape([output_rows, output_cols])] 104 105 106 def get_conv_output_size(input_size, filter_size, strides, padding_type): 107 """Returns the spatial size of a n-d convolution/pooling output.""" 108 input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size]) 109 filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size]) 110 strides = [int(x) for x in strides] 111 112 if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size): 113 return input_size 114 115 if any(x is not None and y is not None and x > y for x, y in 116 zip(filter_size, input_size)): 117 raise ValueError("Filter must not be larger than the input: " 118 "Filter: %r Input: %r" % (filter_size, input_size)) 119 120 if padding_type == b"VALID": 121 122 def _valid(in_dim, k_dim, s_dim): 123 if in_dim is not None and k_dim is not None: 124 return (in_dim - k_dim + s_dim) // s_dim 125 else: 126 return None 127 128 output_size = [ 129 _valid(in_dim, k_dim, s_dim) 130 for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides) 131 ] 132 elif padding_type == b"SAME": 133 134 def _same(in_dim, s_dim): 135 if in_dim is not None: 136 return (in_dim + s_dim - 1) // s_dim 137 else: 138 return None 139 140 output_size = [_same(in_dim, s_dim) 141 for in_dim, s_dim in zip(input_size, strides)] 142 else: 143 raise ValueError("Invalid padding: %r" % padding_type) 144 145 return tuple(output_size) 146 147 148 def get2d_conv_output_size(input_height, input_width, filter_height, 149 filter_width, row_stride, col_stride, padding_type): 150 """Returns the number of rows and columns in a convolution/pooling output.""" 151 return get_conv_output_size((input_height, input_width), 152 (filter_height, filter_width), 153 (row_stride, col_stride), padding_type) 154 155 156 def conv2d_shape(op): 157 """Shape function for a Conv2D op. 158 159 This op has two inputs: 160 161 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 162 * filter, a 4D tensor with shape = [filter_rows, filter_cols, 163 depth_in, depth_out] 164 165 The output is a 4D tensor with shape = [batch_size, out_rows, 166 out_cols, depth_out], where out_rows and out_cols depend on the 167 value of the op's "padding" and "strides" attrs. 168 169 Args: 170 op: A Conv2D Operation. 171 172 Returns: 173 A list containing the Shape of the Conv2D output. 174 175 Raises: 176 ValueError: If the shapes of the input or filter are incompatible. 177 """ 178 input_shape = op.inputs[0].get_shape().with_rank(4) 179 filter_shape = op.inputs[1].get_shape().with_rank(4) 180 181 try: 182 data_format = op.get_attr("data_format") 183 except ValueError: 184 data_format = None 185 186 if data_format == b"NCHW": 187 # Convert input shape to the default NHWC for inference. 188 input_shape = [input_shape[0], input_shape[2], input_shape[3], 189 input_shape[1]] 190 191 batch_size = input_shape[0] 192 in_rows = input_shape[1] 193 in_cols = input_shape[2] 194 195 filter_rows = filter_shape[0] 196 filter_cols = filter_shape[1] 197 depth_out = filter_shape[3] 198 # Check that the input depths are compatible. 199 input_shape[3].assert_is_compatible_with(filter_shape[2]) 200 201 if data_format == b"NCHW": 202 stride_b, stride_d, stride_r, stride_c = op.get_attr("strides") 203 else: 204 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 205 206 if stride_b != 1 or stride_d != 1: 207 raise ValueError("Current implementation does not yet support " 208 "strides in the batch and depth dimensions.") 209 # TODO(mrry,shlens): Raise an error if the stride would cause 210 # information in the input to be ignored. This will require a change 211 # in the kernel implementation. 212 padding = op.get_attr("padding") 213 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, 214 filter_cols, stride_r, stride_c, 215 padding) 216 217 output_shape = [batch_size, out_rows, out_cols, depth_out] 218 if data_format == b"NCHW": 219 # Convert output shape back to NCHW. 220 output_shape = [output_shape[0], output_shape[3], output_shape[1], 221 output_shape[2]] 222 return [tensor_shape.TensorShape(output_shape)] 223 224 225 def depthwise_conv2d_native_shape(op): 226 """Shape function for a DepthwiseConv2D op. 227 228 This op has two inputs: 229 230 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 231 * filter, a 4D tensor with shape = [filter_rows, filter_cols, 232 depth_in, depthwise_multiplier] 233 234 The output is a 4D tensor with shape = [batch_size, out_rows, 235 out_cols, depth_in*depthwise_multiplier], where out_rows and out_cols depend 236 on the value of the op's "padding" and "strides" attrs. 237 238 Args: 239 op: A DepthwiseConv2dNative Operation. 240 241 Returns: 242 A list containing the Shape of the DepthwiseConv2DNative output. 243 244 Raises: 245 ValueError: If the shapes of the input or filter are incompatible. 246 """ 247 input_shape = op.inputs[0].get_shape().with_rank(4) 248 filter_shape = op.inputs[1].get_shape().with_rank(4) 249 250 batch_size = input_shape[0] 251 in_rows = input_shape[1] 252 in_cols = input_shape[2] 253 254 filter_rows = filter_shape[0] 255 filter_cols = filter_shape[1] 256 depth_out = filter_shape[3] * filter_shape[2] 257 # Check that the input depths are compatible. 258 input_shape[3].assert_is_compatible_with(filter_shape[2]) 259 260 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 261 if stride_b != 1 or stride_d != 1: 262 raise ValueError("Current implementation does not yet support " 263 "strides in the batch and depth dimensions.") 264 if stride_r != stride_c: 265 # TODO(shlens): Add support for this. 266 raise ValueError("Current implementation only supports equal length " 267 "strides in the row and column dimensions.") 268 269 # TODO(mrry,shlens): Raise an error if the stride would cause 270 # information in the input to be ignored. This will require a change 271 # in the kernel implementation. 272 stride = stride_r 273 padding = op.get_attr("padding") 274 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, 275 filter_cols, stride, stride, 276 padding) 277 278 return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] 279 280 281 def separable_conv2d_shape(op): 282 """Shape function for a SeparableConv2D op. 283 284 This op has three inputs: 285 286 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 287 288 * depthwise_filter, a 4D tensor with shape = [filter_rows, 289 filter_cols, depth_in, depth_multiplier] 290 291 * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in * 292 depth_multiplier, depth_out] 293 294 The output is a 4D tensor with shape = [batch_size, out_rows, 295 out_cols, depth_out], where out_rows and out_cols depend on the 296 value of the op's "padding" and "strides" attrs. 297 298 Args: 299 op: A SeparableConv2D Operation. 300 301 Returns: 302 A list containing the Shape of the SeparableConv2D output. 303 304 Raises: 305 ValueError: If the shapes of the input or filter are incompatible. 306 """ 307 input_shape = op.inputs[0].get_shape().with_rank(4) 308 depthwise_filter_shape = op.inputs[1].get_shape().merge_with( 309 tensor_shape.TensorShape([None, None, input_shape[3], None])) 310 pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3] 311 312 pointwise_filter_shape = op.inputs[2].get_shape().merge_with( 313 tensor_shape.TensorShape([1, 1, pointwise_depth_in, None])) 314 315 batch_size = input_shape[0] 316 in_rows = input_shape[1] 317 in_cols = input_shape[2] 318 319 filter_rows = depthwise_filter_shape[0] 320 filter_cols = depthwise_filter_shape[1] 321 depth_out = pointwise_filter_shape[3] 322 323 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 324 if stride_b != 1 or stride_d != 1: 325 raise ValueError("Current implementation does not yet support " 326 "strides in the batch and depth dimensions.") 327 if stride_r != stride_c: 328 # TODO(shlens): Add support for this. 329 raise ValueError("Current implementation only supports equal length " 330 "strides in the row and column dimensions.") 331 332 # TODO(mrry,shlens): Raise an error if the stride would cause 333 # information in the input to be ignored. This will require a change 334 # in the kernel implementation. 335 stride = stride_r 336 padding = op.get_attr("padding") 337 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, 338 filter_cols, stride, stride, 339 padding) 340 341 return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] 342 343 344 def avg_pool_shape(op): 345 """Shape function for an AvgPool op. 346 347 This op has one input: 348 349 * input, a 4D tensor with shape = [batch_size, rows, cols, depth] 350 351 The output is a 4D tensor with shape = [batch_size, out_rows, 352 out_cols, depth_out], where out_rows and out_cols depend on the 353 value of the op's "ksize", "strides", and "padding" attrs. 354 355 Args: 356 op: An AvgPool Operation. 357 358 Returns: 359 A single-element list containing the Shape of the AvgPool output. 360 361 Raises: 362 ValueError: If the shape of the input is invalid or incompatible with 363 the values of the attrs. 364 """ 365 input_shape = op.inputs[0].get_shape().with_rank(4) 366 try: 367 data_format = op.get_attr("data_format") 368 except ValueError: 369 data_format = None 370 371 if data_format == b"NCHW": 372 # Convert input shape to the default NHWC for inference. 373 input_shape = [input_shape[0], input_shape[2], input_shape[3], 374 input_shape[1]] 375 376 if data_format == b"NCHW": 377 ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize") 378 stride_b, stride_d, stride_r, stride_c = op.get_attr("strides") 379 else: 380 ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize") 381 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 382 383 batch_size = input_shape[0] 384 in_rows = input_shape[1] 385 in_cols = input_shape[2] 386 depth = input_shape[3] 387 388 if ksize_b != 1 or ksize_d != 1: 389 raise ValueError("Current implementation does not support pooling " 390 "in the batch and depth dimensions.") 391 if stride_b != 1 or stride_d != 1: 392 raise ValueError("Current implementation does not support strides " 393 "in the batch and depth dimensions.") 394 395 # TODO(mrry,shlens): Raise an error if the stride would cause 396 # information in the input to be ignored. This will require a change 397 # in the kernel implementation. 398 padding = op.get_attr("padding") 399 400 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r, 401 ksize_c, stride_r, stride_c, 402 padding) 403 404 output_shape = [batch_size, out_rows, out_cols, depth] 405 if data_format == b"NCHW": 406 # Convert output shape back to NCHW. 407 output_shape = [output_shape[0], output_shape[3], output_shape[1], 408 output_shape[2]] 409 return [tensor_shape.TensorShape(output_shape)] 410 411 412 def max_pool_shape(op): 413 """Shape function for a MaxPool op. 414 415 This op has one input: 416 417 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 418 419 The output is a 4D tensor with shape = [batch_size, out_rows, 420 out_cols, depth_out], where out_rows, out_cols, and depth_out depend 421 on the value of the op's "ksize", "strides", and "padding" attrs. 422 423 Args: 424 op: A MaxPool Operation. 425 426 Returns: 427 A single-element list containing the Shape of the MaxPool output. 428 429 Raises: 430 ValueError: If the shape of the input is invalid or incompatible with 431 the values of the attrs. 432 """ 433 input_shape = op.inputs[0].get_shape().with_rank(4) 434 try: 435 data_format = op.get_attr("data_format") 436 except ValueError: 437 data_format = None 438 439 if data_format == b"NCHW": 440 # Convert input shape to the default NHWC for inference. 441 input_shape = [input_shape[0], input_shape[2], input_shape[3], 442 input_shape[1]] 443 444 if data_format == b"NCHW": 445 ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize") 446 stride_b, stride_d, stride_r, stride_c = op.get_attr("strides") 447 else: 448 ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize") 449 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 450 451 batch_size = input_shape[0] 452 in_rows = input_shape[1] 453 in_cols = input_shape[2] 454 depth = input_shape[3] 455 456 if ksize_b != 1: 457 raise ValueError("Current implementation does not support pooling " 458 "in the batch dimension.") 459 if stride_b != 1: 460 raise ValueError("Current implementation does not support strides " 461 "in the batch dimension.") 462 463 if not ((ksize_r == 1 and ksize_c == 1) or ksize_d == 1): 464 raise ValueError("MaxPooling supports exactly one of pooling across depth " 465 "or pooling across width/height.") 466 467 # TODO(mrry,shlens): Raise an error if the stride would cause 468 # information in the input to be ignored. This will require a change 469 # in the kernel implementation. 470 if ksize_d == 1: 471 padding = op.get_attr("padding") 472 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r, 473 ksize_c, stride_r, stride_c, 474 padding) 475 output_shape = [batch_size, out_rows, out_cols, depth] 476 else: 477 if depth % ksize_d > 0: 478 raise ValueError("Depthwise max pooling requires the depth window " 479 "to evenly divide the input depth.") 480 if stride_d != ksize_d: 481 raise ValueError("Depthwise max pooling requires the depth window " 482 "to equal the depth stride.") 483 output_shape = [batch_size, in_rows, in_cols, depth // ksize_d] 484 485 if data_format == b"NCHW": 486 # Convert output shape back to NCHW. 487 output_shape = [output_shape[0], output_shape[3], output_shape[1], 488 output_shape[2]] 489 return [tensor_shape.TensorShape(output_shape)] 490 491 492 def no_outputs(unused_op): 493 """Shape function for use with ops that have no outputs.""" 494 return [] 495 496 497 def unknown_shape(op): 498 """Shape function for use with ops whose output shapes are unknown.""" 499 return [tensor_shape.unknown_shape() for _ in op.outputs] 500 501 502 def _broadcast_shape_helper(shape_x, shape_y): 503 """Helper functions for is_broadcast_compatible and broadcast_shape. 504 505 Args: 506 shape_x: A `TensorShape` 507 shape_y: A `TensorShape` 508 509 Returns: 510 Returns None if the shapes are not broadcast compatible, 511 a list of the broadcast dimensions otherwise. 512 """ 513 # To compute the broadcasted dimensions, we zip together shape_x and shape_y, 514 # and pad with 1 to make them the same length. 515 broadcasted_dims = reversed(list(six.moves.zip_longest( 516 reversed(shape_x.dims), 517 reversed(shape_y.dims), 518 fillvalue=tensor_shape.Dimension(1)))) 519 # Next we combine the dimensions according to the numpy broadcasting rules. 520 # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html 521 return_dims = [] 522 for (dim_x, dim_y) in broadcasted_dims: 523 if dim_x.value is None or dim_y.value is None: 524 # One or both dimensions is unknown. If either dimension is greater than 525 # 1, we assume that the program is correct, and the other dimension will 526 # be broadcast to match it. 527 # TODO(mrry): If we eliminate the shape checks in C++, we must still 528 # assert that the unknown dim is either 1 or the same as the known dim. 529 if dim_x.value is not None and dim_x.value > 1: 530 return_dims.append(dim_x) 531 elif dim_y.value is not None and dim_y.value > 1: 532 return_dims.append(dim_y) 533 else: 534 return_dims.append(None) 535 elif dim_x.value == 1: 536 # We will broadcast dim_x to dim_y. 537 return_dims.append(dim_y) 538 elif dim_y.value == 1: 539 # We will broadcast dim_y to dim_x. 540 return_dims.append(dim_x) 541 elif dim_x.value == dim_y.value: 542 # The dimensions are compatible, so output is the same size in that 543 # dimension. 544 return_dims.append(dim_x.merge_with(dim_y)) 545 else: 546 return None 547 return return_dims 548 549 550 def is_broadcast_compatible(shape_x, shape_y): 551 """Returns True if `shape_x` and `shape_y` are broadcast compatible. 552 553 Args: 554 shape_x: A `TensorShape` 555 shape_y: A `TensorShape` 556 557 Returns: 558 True if a shape exists that both `shape_x` and `shape_y` can be broadcasted 559 to. False otherwise. 560 """ 561 if shape_x.ndims is None or shape_y.ndims is None: 562 return False 563 return _broadcast_shape_helper(shape_x, shape_y) is not None 564 565 566 def broadcast_shape(shape_x, shape_y): 567 """Returns the broadcasted shape between `shape_x` and `shape_y`. 568 569 Args: 570 shape_x: A `TensorShape` 571 shape_y: A `TensorShape` 572 573 Returns: 574 A `TensorShape` representing the broadcasted shape. 575 576 Raises: 577 ValueError: If the two shapes can not be broadcasted. 578 """ 579 if shape_x.ndims is None or shape_y.ndims is None: 580 return tensor_shape.unknown_shape() 581 return_dims = _broadcast_shape_helper(shape_x, shape_y) 582 if return_dims is None: 583 raise ValueError("Incompatible shapes for broadcasting: %s and %s" 584 % (shape_x, shape_y)) 585 return tensor_shape.TensorShape(return_dims) 586 587 588 def call_cpp_shape_fn(op, require_shape_fn=True): 589 """A shape function that delegates to the registered C++ shape function. 590 591 Args: 592 op: the node in the graph for which to compute output shapes. 593 require_shape_fn: If true, and the C++ shape function is not registered 594 in the current binary then an exception is raised; otherwise, if the 595 C++ shape function is not registered then unknown_shape is used. 596 597 Returns: 598 A dictionary with the following keys: 599 shapes: A TensorShape list of the output shapes of the op, as computed 600 using the C++ shape inference function registered for the op. 601 handle_shapes: A TensorShape list of the shapes for handle outputs, if 602 any. 603 handle_dtypes: A list of DataType enums for the handle outputs, if any. 604 605 Raises: 606 ValueError: If the C++ shape function returned an error (e.g. because the 607 shapes of the inputs are of the wrong rank or otherwise incompatible 608 according to the shape function). 609 RuntimeError: If the C++ shape function is not registered and 610 <require_shape_fn> is True. 611 """ 612 if op.type == "Const": 613 # To avoid serializing large constants, we special-case constant 614 # here, even though it has a C++ shape function. When Python 615 # calls the C / C-API directly, we should be able to remove this. 616 return { 617 "shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)], 618 "handle_data": [None] 619 } 620 621 input_tensors_needed = [] 622 input_tensors_as_shapes_needed = [] 623 624 while True: 625 res = _call_cpp_shape_fn_impl(op, input_tensors_needed, 626 input_tensors_as_shapes_needed, 627 require_shape_fn) 628 if not isinstance(res, dict): 629 # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op). 630 return res 631 632 # See if we need to evaluate some inputs. 633 if not res["inputs_needed"]: 634 return res 635 p = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded() 636 p = p.FromString(res["inputs_needed"]) 637 changed = False 638 for idx in p.input_tensors_needed: 639 if idx not in input_tensors_needed: 640 input_tensors_needed.append(idx) 641 changed = True 642 for idx in p.input_tensors_as_shapes_needed: 643 if idx not in input_tensors_as_shapes_needed: 644 input_tensors_as_shapes_needed.append(idx) 645 changed = True 646 if not changed: 647 return res 648 649 650 def _call_cpp_shape_fn_impl( 651 op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn): 652 """Core implementation of call_cpp_shape_fn.""" 653 graph_def_version = op.graph.graph_def_versions.producer 654 node_def_str = op.node_def.SerializeToString() 655 656 def tensor_to_inference_result(t): 657 r = cpp_shape_inference_pb2.CppShapeInferenceResult() 658 r.shape.CopyFrom(t.get_shape().as_proto()) 659 # pylint: disable=protected-access 660 if t._handle_data is not None: 661 r.handle_data.CopyFrom(t._handle_data) 662 # pylint: enable=protected-access 663 return r.SerializeToString() 664 input_shapes = [tensor_to_inference_result(i) for i in op.inputs] 665 666 input_tensors = [None for i in input_shapes] 667 for idx in input_tensors_needed: 668 v = tensor_util.constant_value(op.inputs[idx]) 669 if v is not None: 670 input_tensors[idx] = np.asarray(v) 671 672 serialized_unknown_shape = ( 673 tensor_shape.TensorShape(None).as_proto().SerializeToString()) 674 arr = [serialized_unknown_shape for i in input_shapes] 675 for idx in input_tensors_as_shapes_needed: 676 s = tensor_util.constant_value_as_shape(op.inputs[idx]) 677 if s is not None: 678 arr[idx] = s.as_proto().SerializeToString() 679 input_tensors_as_shapes = arr 680 681 missing_shape_fn = False 682 try: 683 with errors.raise_exception_on_not_ok_status() as status: 684 output = pywrap_tensorflow.RunCppShapeInference( 685 graph_def_version, node_def_str, input_shapes, input_tensors, 686 input_tensors_as_shapes, status) 687 except errors.InvalidArgumentError as err: 688 if err.message.startswith("No shape inference function exists for op"): 689 missing_shape_fn = True 690 else: 691 raise ValueError(err.message) 692 693 if missing_shape_fn: 694 if require_shape_fn: 695 raise RuntimeError( 696 "No C++ shape function registered for standard op: %s" % op.type) 697 return unknown_shape(op) 698 699 output_shapes = output[:-1] 700 701 # Convert TensorShapeProto values in output_shapes. 702 result_protos = [ 703 cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s) 704 for s in output_shapes 705 ] 706 result = [r.shape for r in result_protos] 707 result_handle_data = [ 708 r.handle_data if r.handle_data.is_set else None for r in result_protos 709 ] 710 711 return { 712 "shapes": result, 713 "handle_data": result_handle_data, 714 "inputs_needed": output[-1] 715 } 716 717 # pylint: disable=protected-access 718 ops._set_call_cpp_shape_fn(call_cpp_shape_fn) 719 # pylint: enable=protected-access 720