# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of common shape functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six.moves

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def scalar_shape(unused_op):
  """Shape function for ops that output a scalar value."""
  return [tensor_shape.scalar()]


def unchanged_shape(op):
  """Shape function for ops that output a tensor like their first input."""
  return [op.inputs[0].get_shape()]

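# A sketch of how shape functions like the ones above were attached to ops
# in this era of TensorFlow, via ops.RegisterShape, either as a decorator
# or called directly (the op names "MyScalarOp" and "MyIdentityOp" are
# hypothetical, for illustration only):
#
#   @ops.RegisterShape("MyScalarOp")
#   def _my_scalar_op_shape(op):
#     return scalar_shape(op)
#
#   ops.RegisterShape("MyIdentityOp")(unchanged_shape)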

def unchanged_shape_with_rank(rank):
  """Returns a shape function for ops that constrain the rank of their input.

  Args:
    rank: The exact rank of the input and output.

  Returns:
    A shape function for ops that output a tensor of the same size as their
    input, with a particular rank.
  """

  def _ShapeFunction(op):
    return [op.inputs[0].get_shape().with_rank(rank)]

  return _ShapeFunction


def unchanged_shape_with_rank_at_least(rank):
  """Returns a shape function for ops that constrain the rank of their input.

  Args:
    rank: A lower bound on the rank of the input and output.

  Returns:
    A shape function for ops that output a tensor of the same size as their
    input, with a rank of at least `rank`.
  """

  def _ShapeFunction(op):
    return [op.inputs[0].get_shape().with_rank_at_least(rank)]

  return _ShapeFunction


def unchanged_shape_with_rank_at_most(rank):
  """Returns a shape function for ops that constrain the rank of their input.

  Args:
    rank: An upper bound on the rank of the input and output.

  Returns:
    A shape function for ops that output a tensor of the same size as their
    input, with a rank of at most `rank`.
  """

  def _ShapeFunction(op):
    return [op.inputs[0].get_shape().with_rank_at_most(rank)]

  return _ShapeFunction

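# Usage sketch for the factories above (the op name "MyMatrixOp" is
# hypothetical): each call returns a closure suitable for registration, so
#
#   ops.RegisterShape("MyMatrixOp")(unchanged_shape_with_rank(2))
#
# both asserts that the input has rank 2 and propagates its shape, while
# unchanged_shape_with_rank_at_least(1) would only require rank >= 1.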

def matmul_shape(op):
  """Shape function for a MatMul op."""
  a_shape = op.inputs[0].get_shape().with_rank(2)
  transpose_a = op.get_attr("transpose_a")
  b_shape = op.inputs[1].get_shape().with_rank(2)
  transpose_b = op.get_attr("transpose_b")
  output_rows = a_shape[1] if transpose_a else a_shape[0]
  output_cols = b_shape[0] if transpose_b else b_shape[1]
  inner_a = a_shape[0] if transpose_a else a_shape[1]
  inner_b = b_shape[1] if transpose_b else b_shape[0]
  inner_a.assert_is_compatible_with(inner_b)
  return [tensor_shape.TensorShape([output_rows, output_cols])]

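# Worked example for matmul_shape (shapes illustrative): with a: [100, 10],
# b: [10, 50], transpose_a=False and transpose_b=False, the inner dimensions
# a_shape[1] = 10 and b_shape[0] = 10 are compatible, and the result is
# [TensorShape([100, 50])]. With transpose_a=True the same output comes
# from a: [10, 100], since output_rows is then a_shape[1].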

def get_conv_output_size(input_size, filter_size, strides, padding_type):
  """Returns the spatial size of an n-d convolution/pooling output."""
  input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
  filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size])
  strides = [int(x) for x in strides]

  if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size):
    return input_size

  if any(x is not None and y is not None and x > y for x, y in
         zip(filter_size, input_size)):
    raise ValueError("Filter must not be larger than the input: "
                     "Filter: %r Input: %r" % (filter_size, input_size))

  if padding_type == b"VALID":

    def _valid(in_dim, k_dim, s_dim):
      if in_dim is not None and k_dim is not None:
        return (in_dim - k_dim + s_dim) // s_dim
      else:
        return None

    output_size = [
        _valid(in_dim, k_dim, s_dim)
        for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides)
    ]
  elif padding_type == b"SAME":

    def _same(in_dim, s_dim):
      if in_dim is not None:
        return (in_dim + s_dim - 1) // s_dim
      else:
        return None

    output_size = [_same(in_dim, s_dim)
                   for in_dim, s_dim in zip(input_size, strides)]
  else:
    raise ValueError("Invalid padding: %r" % padding_type)

  return tuple(output_size)

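# The per-dimension arithmetic above, spelled out (numbers illustrative):
#
#   VALID: out = (in - k + s) // s, e.g. in=10, k=3, s=2 -> (10-3+2)//2 = 4
#   SAME:  out = (in + s - 1) // s, e.g. in=10,      s=2 -> (10+2-1)//2 = 5
#
# VALID counts only filter windows that fit entirely inside the input;
# SAME assumes enough zero padding that every stride offset produces an
# output element, so the filter size drops out of the formula.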

def get2d_conv_output_size(input_height, input_width, filter_height,
                           filter_width, row_stride, col_stride, padding_type):
  """Returns the number of rows and columns in a convolution/pooling output."""
  return get_conv_output_size((input_height, input_width),
                              (filter_height, filter_width),
                              (row_stride, col_stride), padding_type)


def conv2d_shape(op):
  """Shape function for a Conv2D op.

  This op has two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
  * filter, a 4D tensor with shape = [filter_rows, filter_cols,
    depth_in, depth_out]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_out], where out_rows and out_cols depend on the
  value of the op's "padding" and "strides" attrs.

  Args:
    op: A Conv2D Operation.

  Returns:
    A list containing the Shape of the Conv2D output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  if data_format == b"NCHW":
    # Convert input shape to the default NHWC for inference.
    input_shape = [input_shape[0], input_shape[2], input_shape[3],
                   input_shape[1]]

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]

  filter_rows = filter_shape[0]
  filter_cols = filter_shape[1]
  depth_out = filter_shape[3]
  # Check that the input depths are compatible.
  input_shape[3].assert_is_compatible_with(filter_shape[2])

  if data_format == b"NCHW":
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
                                              filter_cols, stride_r, stride_c,
                                              padding)

  output_shape = [batch_size, out_rows, out_cols, depth_out]
  if data_format == b"NCHW":
    # Convert output shape back to NCHW.
    output_shape = [output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]]
  return [tensor_shape.TensorShape(output_shape)]

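# Worked example for conv2d_shape (shapes illustrative): an NHWC input of
# shape [32, 28, 28, 3] with a [5, 5, 3, 64] filter, strides [1, 2, 2, 1]
# and padding "SAME" gives out_rows = out_cols = (28 + 2 - 1) // 2 = 14,
# so the result is [TensorShape([32, 14, 14, 64])].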

def depthwise_conv2d_native_shape(op):
  """Shape function for a DepthwiseConv2D op.

  This op has two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
  * filter, a 4D tensor with shape = [filter_rows, filter_cols,
    depth_in, depthwise_multiplier]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_in*depthwise_multiplier], where out_rows and out_cols depend
  on the value of the op's "padding" and "strides" attrs.

  Args:
    op: A DepthwiseConv2dNative Operation.

  Returns:
    A list containing the Shape of the DepthwiseConv2DNative output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]

  filter_rows = filter_shape[0]
  filter_cols = filter_shape[1]
  depth_out = filter_shape[3] * filter_shape[2]
  # Check that the input depths are compatible.
  input_shape[3].assert_is_compatible_with(filter_shape[2])

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  stride = stride_r
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
                                              filter_cols, stride, stride,
                                              padding)

  return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]


def separable_conv2d_shape(op):
  """Shape function for a SeparableConv2D op.

  This op has three inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]

  * depthwise_filter, a 4D tensor with shape = [filter_rows,
    filter_cols, depth_in, depth_multiplier]

  * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in *
    depth_multiplier, depth_out]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_out], where out_rows and out_cols depend on the
  value of the op's "padding" and "strides" attrs.

  Args:
    op: A SeparableConv2D Operation.

  Returns:
    A list containing the Shape of the SeparableConv2D output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  depthwise_filter_shape = op.inputs[1].get_shape().merge_with(
      tensor_shape.TensorShape([None, None, input_shape[3], None]))
  pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3]

  pointwise_filter_shape = op.inputs[2].get_shape().merge_with(
      tensor_shape.TensorShape([1, 1, pointwise_depth_in, None]))

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]

  filter_rows = depthwise_filter_shape[0]
  filter_cols = depthwise_filter_shape[1]
  depth_out = pointwise_filter_shape[3]

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  stride = stride_r
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
                                              filter_cols, stride, stride,
                                              padding)

  return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]


def avg_pool_shape(op):
  """Shape function for an AvgPool op.

  This op has one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth], where out_rows and out_cols depend on the
  value of the op's "ksize", "strides", and "padding" attrs.

  Args:
    op: An AvgPool Operation.

  Returns:
    A single-element list containing the Shape of the AvgPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  if data_format == b"NCHW":
    # Convert input shape to the default NHWC for inference.
    input_shape = [input_shape[0], input_shape[2], input_shape[3],
                   input_shape[1]]

  if data_format == b"NCHW":
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]
  depth = input_shape[3]

  if ksize_b != 1 or ksize_d != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch and depth dimensions.")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch and depth dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")

  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
                                              ksize_c, stride_r, stride_c,
                                              padding)

  output_shape = [batch_size, out_rows, out_cols, depth]
  if data_format == b"NCHW":
    # Convert output shape back to NCHW.
    output_shape = [output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]]
  return [tensor_shape.TensorShape(output_shape)]


def max_pool_shape(op):
  """Shape function for a MaxPool op.

  This op has one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_out], where out_rows, out_cols, and depth_out depend
  on the value of the op's "ksize", "strides", and "padding" attrs.

  Args:
    op: A MaxPool Operation.

  Returns:
    A single-element list containing the Shape of the MaxPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  if data_format == b"NCHW":
    # Convert input shape to the default NHWC for inference.
    input_shape = [input_shape[0], input_shape[2], input_shape[3],
                   input_shape[1]]

  if data_format == b"NCHW":
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]
  depth = input_shape[3]

  if ksize_b != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch dimension.")
  if stride_b != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch dimension.")

  if not ((ksize_r == 1 and ksize_c == 1) or ksize_d == 1):
    raise ValueError("MaxPooling supports exactly one of pooling across depth "
                     "or pooling across width/height.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  if ksize_d == 1:
    padding = op.get_attr("padding")
    out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
                                                ksize_c, stride_r, stride_c,
                                                padding)
    output_shape = [batch_size, out_rows, out_cols, depth]
  else:
    if depth % ksize_d > 0:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to evenly divide the input depth.")
    if stride_d != ksize_d:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to equal the depth stride.")
    output_shape = [batch_size, in_rows, in_cols, depth // ksize_d]

  if data_format == b"NCHW":
    # Convert output shape back to NCHW.
    output_shape = [output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]]
  return [tensor_shape.TensorShape(output_shape)]

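# Worked example for the depthwise branch of max_pool_shape (shapes
# illustrative): an NHWC input of shape [8, 7, 7, 64] with
# ksize = [1, 1, 1, 4] and strides = [1, 1, 1, 4] pools across depth only,
# leaving rows and columns untouched and dividing the depth by the window:
# the result is [TensorShape([8, 7, 7, 16])].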

def no_outputs(unused_op):
  """Shape function for use with ops that have no outputs."""
  return []


def unknown_shape(op):
  """Shape function for use with ops whose output shapes are unknown."""
  return [tensor_shape.unknown_shape() for _ in op.outputs]


def _broadcast_shape_helper(shape_x, shape_y):
  """Helper function for is_broadcast_compatible and broadcast_shape.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None if the shapes are not broadcast compatible; a list of the
    broadcast dimensions otherwise.
  """
  # To compute the broadcasted dimensions, we zip together shape_x and shape_y,
  # and pad with 1 to make them the same length.
  broadcasted_dims = reversed(list(six.moves.zip_longest(
      reversed(shape_x.dims),
      reversed(shape_y.dims),
      fillvalue=tensor_shape.Dimension(1))))
  # Next we combine the dimensions according to the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  return_dims = []
  for (dim_x, dim_y) in broadcasted_dims:
    if dim_x.value is None or dim_y.value is None:
      # One or both dimensions is unknown. If either dimension is greater than
      # 1, we assume that the program is correct, and the other dimension will
      # be broadcast to match it.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if dim_x.value is not None and dim_x.value > 1:
        return_dims.append(dim_x)
      elif dim_y.value is not None and dim_y.value > 1:
        return_dims.append(dim_y)
      else:
        return_dims.append(None)
    elif dim_x.value == 1:
      # We will broadcast dim_x to dim_y.
      return_dims.append(dim_y)
    elif dim_y.value == 1:
      # We will broadcast dim_y to dim_x.
      return_dims.append(dim_x)
    elif dim_x.value == dim_y.value:
      # The dimensions are compatible, so output is the same size in that
      # dimension.
      return_dims.append(dim_x.merge_with(dim_y))
    else:
      return None
  return return_dims


def is_broadcast_compatible(shape_x, shape_y):
  """Returns True if `shape_x` and `shape_y` are broadcast compatible.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be broadcast
    to.  False otherwise.
  """
  if shape_x.ndims is None or shape_y.ndims is None:
    return False
  return _broadcast_shape_helper(shape_x, shape_y) is not None


def broadcast_shape(shape_x, shape_y):
  """Returns the broadcasted shape between `shape_x` and `shape_y`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes cannot be broadcast.
  """
  if shape_x.ndims is None or shape_y.ndims is None:
    return tensor_shape.unknown_shape()
  return_dims = _broadcast_shape_helper(shape_x, shape_y)
  if return_dims is None:
    raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                     % (shape_x, shape_y))
  return tensor_shape.TensorShape(return_dims)

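# Worked example (shapes illustrative): broadcasting [5, 1, 3] against
# [4, 3] right-aligns the two shapes, pads the shorter with 1s to give
# [5, 1, 3] vs [1, 4, 3], and merges dimensions pairwise:
#
#   broadcast_shape(tensor_shape.TensorShape([5, 1, 3]),
#                   tensor_shape.TensorShape([4, 3]))
#   # => TensorShape([5, 4, 3])
#
# By contrast, [5, 2, 3] against [4, 3] raises ValueError, because 2 and 4
# are neither equal nor 1.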

def call_cpp_shape_fn(op, require_shape_fn=True):
  """A shape function that delegates to the registered C++ shape function.

  Args:
    op: the node in the graph for which to compute output shapes.
    require_shape_fn: If True, and the C++ shape function is not registered
      in the current binary then an exception is raised; otherwise, if the
      C++ shape function is not registered then unknown_shape is used.

  Returns:
    A dictionary with the following keys:
      shapes: A TensorShape list of the output shapes of the op, as computed
        using the C++ shape inference function registered for the op.
      handle_data: A list of handle data for the op's outputs, with None for
        outputs that carry no handle data.
      inputs_needed: A serialized CppShapeInferenceInputsNeeded proto listing
        any inputs whose values or partial shapes are required to complete
        shape inference (absent for the Const special case below).

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function).
    RuntimeError: If the C++ shape function is not registered and
      `require_shape_fn` is True.
  """
  if op.type == "Const":
    # To avoid serializing large constants, we special-case constants
    # here, even though they have a C++ shape function.  When Python
    # calls the C API directly, we should be able to remove this.
    return {
        "shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)],
        "handle_data": [None]
    }

  input_tensors_needed = []
  input_tensors_as_shapes_needed = []

  while True:
    res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
                                  input_tensors_as_shapes_needed,
                                  require_shape_fn)
    if not isinstance(res, dict):
      # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
      return res

    # See if we need to evaluate some inputs.
    if not res["inputs_needed"]:
      return res
    p = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded()
    p = p.FromString(res["inputs_needed"])
    changed = False
    for idx in p.input_tensors_needed:
      if idx not in input_tensors_needed:
        input_tensors_needed.append(idx)
        changed = True
    for idx in p.input_tensors_as_shapes_needed:
      if idx not in input_tensors_as_shapes_needed:
        input_tensors_as_shapes_needed.append(idx)
        changed = True
    if not changed:
      return res

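# Why the loop above can iterate more than once: some C++ shape functions
# need the values of particular inputs, not just their shapes (for example,
# Reshape needs the value of its "shape" input to compute its output
# shape). The first call reports such inputs via "inputs_needed"; we then
# try to evaluate them as constants and call the C++ function again,
# repeating until no new inputs are requested.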

def _call_cpp_shape_fn_impl(
    op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
  """Core implementation of call_cpp_shape_fn."""
  graph_def_version = op.graph.graph_def_versions.producer
  node_def_str = op.node_def.SerializeToString()

  def tensor_to_inference_result(t):
    r = cpp_shape_inference_pb2.CppShapeInferenceResult()
    r.shape.CopyFrom(t.get_shape().as_proto())
    # pylint: disable=protected-access
    if t._handle_data is not None:
      r.handle_data.CopyFrom(t._handle_data)
    # pylint: enable=protected-access
    return r.SerializeToString()
  input_shapes = [tensor_to_inference_result(i) for i in op.inputs]

  input_tensors = [None for i in input_shapes]
  for idx in input_tensors_needed:
    v = tensor_util.constant_value(op.inputs[idx])
    if v is not None:
      input_tensors[idx] = np.asarray(v)

  serialized_unknown_shape = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  arr = [serialized_unknown_shape for i in input_shapes]
  for idx in input_tensors_as_shapes_needed:
    s = tensor_util.constant_value_as_shape(op.inputs[idx])
    if s is not None:
      arr[idx] = s.as_proto().SerializeToString()
  input_tensors_as_shapes = arr

  missing_shape_fn = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output = pywrap_tensorflow.RunCppShapeInference(
          graph_def_version, node_def_str, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if err.message.startswith("No shape inference function exists for op"):
      missing_shape_fn = True
    else:
      raise ValueError(err.message)

  if missing_shape_fn:
    if require_shape_fn:
      raise RuntimeError(
          "No C++ shape function registered for standard op: %s" % op.type)
    return unknown_shape(op)

  output_shapes = output[:-1]

  # Convert TensorShapeProto values in output_shapes.
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
      for s in output_shapes
  ]
  result = [r.shape for r in result_protos]
  result_handle_data = [
      r.handle_data if r.handle_data.is_set else None for r in result_protos
  ]

  return {
      "shapes": result,
      "handle_data": result_handle_data,
      "inputs_needed": output[-1]
  }


# pylint: disable=protected-access
ops._set_call_cpp_shape_fn(call_cpp_shape_fn)
# pylint: enable=protected-access