# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of common shape functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six.moves

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def has_fully_defined_shape(tensor):
  """Returns True if `tensor` has a fully defined shape."""
  return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined()


def rank(tensor):
  """Returns the rank of `tensor` if it is a tensor, else None."""
  if isinstance(tensor, ops.Tensor):
    return tensor._rank()  # pylint: disable=protected-access
  return None


def scalar_shape(unused_op):
  """Shape function for ops that output a scalar value."""
  return [tensor_shape.scalar()]


def unchanged_shape(op):
  """Shape function for ops that output a tensor like their first input."""
  return [op.inputs[0].get_shape()]


def unchanged_shape_with_rank(rank):
  """Returns a shape function for ops that constrain the rank of their input.

  Args:
    rank: The exact rank of the input and output.

  Returns:
    A shape function for ops that output a tensor of the same shape as their
    input, with the given rank.
  """

  def _ShapeFunction(op):
    return [op.inputs[0].get_shape().with_rank(rank)]

  return _ShapeFunction


def unchanged_shape_with_rank_at_least(rank):
  """Returns a shape function for ops that constrain the rank of their input.

  Args:
    rank: A lower bound on the rank of the input and output.

  Returns:
    A shape function for ops that output a tensor of the same shape as their
    input, with at least the given rank.
  """

  def _ShapeFunction(op):
    return [op.inputs[0].get_shape().with_rank_at_least(rank)]

  return _ShapeFunction


def unchanged_shape_with_rank_at_most(rank):
  """Returns a shape function for ops that constrain the rank of their input.

  Args:
    rank: An upper bound on the rank of the input and output.

  Returns:
    A shape function for ops that output a tensor of the same shape as their
    input, with at most the given rank.
  """

  def _ShapeFunction(op):
    return [op.inputs[0].get_shape().with_rank_at_most(rank)]

  return _ShapeFunction
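# A minimal usage sketch of the factories above (illustrative only; `my_op`
# stands for a hypothetical Operation whose first input has shape [2, 3]):
#
#   shape_fn = unchanged_shape_with_rank(2)
#   shape_fn(my_op)                       # => [TensorShape([2, 3])]
#   unchanged_shape_with_rank(3)(my_op)   # raises ValueError (rank mismatch)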


def matmul_shape(op):
  """Shape function for a MatMul op."""
  a_shape = op.inputs[0].get_shape().with_rank(2)
  transpose_a = op.get_attr("transpose_a")
  b_shape = op.inputs[1].get_shape().with_rank(2)
  transpose_b = op.get_attr("transpose_b")
  output_rows = a_shape[1] if transpose_a else a_shape[0]
  output_cols = b_shape[0] if transpose_b else b_shape[1]
  inner_a = a_shape[0] if transpose_a else a_shape[1]
  inner_b = b_shape[1] if transpose_b else b_shape[0]
  inner_a.assert_is_compatible_with(inner_b)
  return [tensor_shape.TensorShape([output_rows, output_cols])]
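# The arithmetic above, traced on concrete shapes (illustrative only): with
# a_shape = [3, 4], b_shape = [4, 5], and both transpose attrs False,
# output_rows = a_shape[0] = 3 and output_cols = b_shape[1] = 5, and the
# inner dimensions a_shape[1] = 4 and b_shape[0] = 4 are compatible, so the
# result is [TensorShape([3, 5])]. With transpose_a=True the same shapes
# would fail, since inner_a = a_shape[0] = 3 != inner_b = b_shape[0] = 4.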


def get_conv_output_size(input_size, filter_size, strides, padding_type):
  """Returns the spatial size of an n-d convolution/pooling output."""
  input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
  filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size])
  strides = [int(x) for x in strides]

  if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size):
    return input_size

  if any(x is not None and y is not None and x > y for x, y in
         zip(filter_size, input_size)):
    raise ValueError("Filter must not be larger than the input: "
                     "Filter: %r Input: %r" % (filter_size, input_size))

  if padding_type == b"VALID":

    def _valid(in_dim, k_dim, s_dim):
      if in_dim is not None and k_dim is not None:
        return (in_dim - k_dim + s_dim) // s_dim
      else:
        return None

    output_size = [
        _valid(in_dim, k_dim, s_dim)
        for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides)
    ]
  elif padding_type == b"SAME":

    def _same(in_dim, s_dim):
      if in_dim is not None:
        return (in_dim + s_dim - 1) // s_dim
      else:
        return None

    output_size = [_same(in_dim, s_dim)
                   for in_dim, s_dim in zip(input_size, strides)]
  else:
    raise ValueError("Invalid padding: %r" % padding_type)

  return tuple(output_size)


def get2d_conv_output_size(input_height, input_width, filter_height,
                           filter_width, row_stride, col_stride, padding_type):
  """Returns the number of rows and columns in a convolution/pooling output."""
  return get_conv_output_size((input_height, input_width),
                              (filter_height, filter_width),
                              (row_stride, col_stride), padding_type)
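# A worked example of the two padding modes above (illustrative only): for
# in_dim = 10, k_dim = 3, s_dim = 2,
#   VALID: (10 - 3 + 2) // 2 = 4
#   SAME:  (10 + 2 - 1) // 2 = 5  (i.e. ceil(10 / 2))
# so, for a 10x10 input with a 3x3 filter and stride 2:
#   get2d_conv_output_size(10, 10, 3, 3, 2, 2, b"VALID")  # => (4, 4)
#   get2d_conv_output_size(10, 10, 3, 3, 2, 2, b"SAME")   # => (5, 5)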


def conv2d_shape(op):
  """Shape function for a Conv2D op.

  This op has two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
  * filter, a 4D tensor with shape = [filter_rows, filter_cols,
    depth_in, depth_out]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_out], where out_rows and out_cols depend on the
  value of the op's "padding" and "strides" attrs.

  Args:
    op: A Conv2D Operation.

  Returns:
    A list containing the Shape of the Conv2D output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  if data_format == b"NCHW":
    # Convert input shape to the default NHWC for inference.
    input_shape = [input_shape[0], input_shape[2], input_shape[3],
                   input_shape[1]]

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]

  filter_rows = filter_shape[0]
  filter_cols = filter_shape[1]
  depth_out = filter_shape[3]
  # Check that the input depths are compatible.
  input_shape[3].assert_is_compatible_with(filter_shape[2])

  if data_format == b"NCHW":
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
                                              filter_cols, stride_r, stride_c,
                                              padding)

  output_shape = [batch_size, out_rows, out_cols, depth_out]
  if data_format == b"NCHW":
    # Convert output shape back to NCHW.
    output_shape = [output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]]
  return [tensor_shape.TensorShape(output_shape)]
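# Trace on concrete shapes (illustrative only): an NHWC input of shape
# [32, 28, 28, 3], a filter of shape [5, 5, 3, 64], strides [1, 2, 2, 1], and
# SAME padding give out_rows = out_cols = (28 + 2 - 1) // 2 = 14, so the
# function returns [TensorShape([32, 14, 14, 64])]. With data_format=b"NCHW"
# (input [32, 3, 28, 28], strides [1, 1, 2, 2]) the result would be
# [TensorShape([32, 64, 14, 14])].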


def depthwise_conv2d_native_shape(op):
  """Shape function for a DepthwiseConv2D op.

  This op has two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
  * filter, a 4D tensor with shape = [filter_rows, filter_cols,
    depth_in, depthwise_multiplier]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_in*depthwise_multiplier], where out_rows and out_cols depend
  on the value of the op's "padding" and "strides" attrs.

  Args:
    op: A DepthwiseConv2dNative Operation.

  Returns:
    A list containing the Shape of the DepthwiseConv2DNative output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]

  filter_rows = filter_shape[0]
  filter_cols = filter_shape[1]
  depth_out = filter_shape[3] * filter_shape[2]
  # Check that the input depths are compatible.
  input_shape[3].assert_is_compatible_with(filter_shape[2])

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  stride = stride_r
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
                                              filter_cols, stride, stride,
                                              padding)

  return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
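# Trace on concrete shapes (illustrative only): input [8, 32, 32, 3] and
# filter [3, 3, 3, 2] (a depth multiplier of 2) with strides [1, 1, 1, 1] and
# VALID padding give out_rows = out_cols = (32 - 3 + 1) // 1 = 30 and
# depth_out = 2 * 3 = 6, so the result is [TensorShape([8, 30, 30, 6])].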


def separable_conv2d_shape(op):
  """Shape function for a SeparableConv2D op.

  This op has three inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]

  * depthwise_filter, a 4D tensor with shape = [filter_rows,
    filter_cols, depth_in, depth_multiplier]

  * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in *
    depth_multiplier, depth_out]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_out], where out_rows and out_cols depend on the
  value of the op's "padding" and "strides" attrs.

  Args:
    op: A SeparableConv2D Operation.

  Returns:
    A list containing the Shape of the SeparableConv2D output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  depthwise_filter_shape = op.inputs[1].get_shape().merge_with(
      tensor_shape.TensorShape([None, None, input_shape[3], None]))
  pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3]

  pointwise_filter_shape = op.inputs[2].get_shape().merge_with(
      tensor_shape.TensorShape([1, 1, pointwise_depth_in, None]))

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]

  filter_rows = depthwise_filter_shape[0]
  filter_cols = depthwise_filter_shape[1]
  depth_out = pointwise_filter_shape[3]

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  stride = stride_r
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
                                              filter_cols, stride, stride,
                                              padding)

  return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
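# Trace on concrete shapes (illustrative only): input [8, 32, 32, 3],
# depthwise_filter [3, 3, 3, 2], and pointwise_filter [1, 1, 6, 16] (so
# pointwise_depth_in = 3 * 2 = 6) with strides [1, 1, 1, 1] and SAME padding
# give out_rows = out_cols = 32, so the result is
# [TensorShape([8, 32, 32, 16])].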


def avg_pool_shape(op):
  """Shape function for an AvgPool op.

  This op has one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth], where out_rows and out_cols depend on the
  value of the op's "ksize", "strides", and "padding" attrs.

  Args:
    op: An AvgPool Operation.

  Returns:
    A single-element list containing the Shape of the AvgPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  if data_format == b"NCHW":
    # Convert input shape to the default NHWC for inference.
    input_shape = [input_shape[0], input_shape[2], input_shape[3],
                   input_shape[1]]
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]
  depth = input_shape[3]

  if ksize_b != 1 or ksize_d != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch and depth dimensions.")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch and depth dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")

  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
                                              ksize_c, stride_r, stride_c,
                                              padding)

  output_shape = [batch_size, out_rows, out_cols, depth]
  if data_format == b"NCHW":
    # Convert output shape back to NCHW.
    output_shape = [output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]]
  return [tensor_shape.TensorShape(output_shape)]


def max_pool_shape(op):
  """Shape function for a MaxPool op.

  This op has one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]

  The output is a 4D tensor with shape = [batch_size, out_rows,
  out_cols, depth_out], where out_rows, out_cols, and depth_out depend
  on the value of the op's "ksize", "strides", and "padding" attrs.

  Args:
    op: A MaxPool Operation.

  Returns:
    A single-element list containing the Shape of the MaxPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  if data_format == b"NCHW":
    # Convert input shape to the default NHWC for inference.
    input_shape = [input_shape[0], input_shape[2], input_shape[3],
                   input_shape[1]]
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]
  depth = input_shape[3]

  if ksize_b != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch dimension.")
  if stride_b != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch dimension.")

  if not ((ksize_r == 1 and ksize_c == 1) or ksize_d == 1):
    raise ValueError("MaxPooling supports exactly one of pooling across depth "
                     "or pooling across width/height.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  if ksize_d == 1:
    padding = op.get_attr("padding")
    out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
                                                ksize_c, stride_r, stride_c,
                                                padding)
    output_shape = [batch_size, out_rows, out_cols, depth]
  else:
    if depth % ksize_d > 0:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to evenly divide the input depth.")
    if stride_d != ksize_d:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to equal the depth stride.")
    output_shape = [batch_size, in_rows, in_cols, depth // ksize_d]

  if data_format == b"NCHW":
    # Convert output shape back to NCHW.
    output_shape = [output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]]
  return [tensor_shape.TensorShape(output_shape)]
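# Trace of the depthwise branch above (illustrative only): an NHWC input of
# shape [8, 16, 16, 12] with ksize [1, 1, 1, 4] and strides [1, 1, 1, 4]
# pools across depth only, so the spatial dimensions pass through unchanged
# and the result is [TensorShape([8, 16, 16, 3])] (12 // 4 = 3).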


def no_outputs(unused_op):
  """Shape function for use with ops that have no outputs."""
  return []


def unknown_shape(op):
  """Shape function for use with ops whose output shapes are unknown."""
  return [tensor_shape.unknown_shape() for _ in op.outputs]


def _broadcast_shape_helper(shape_x, shape_y):
  """Helper function for is_broadcast_compatible and broadcast_shape.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None if the shapes are not broadcast compatible; a list of the
    broadcast dimensions otherwise.
  """
  # To compute the broadcasted dimensions, we zip together shape_x and shape_y,
  # and pad with 1 to make them the same length.
  broadcasted_dims = reversed(list(six.moves.zip_longest(
      reversed(shape_x.dims),
      reversed(shape_y.dims),
      fillvalue=tensor_shape.Dimension(1))))
  # Next we combine the dimensions according to the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  return_dims = []
  for (dim_x, dim_y) in broadcasted_dims:
    if dim_x.value is None or dim_y.value is None:
      # One or both dimensions is unknown. If either dimension is greater than
      # 1, we assume that the program is correct, and the other dimension will
      # be broadcast to match it.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if dim_x.value is not None and dim_x.value > 1:
        return_dims.append(dim_x)
      elif dim_y.value is not None and dim_y.value > 1:
        return_dims.append(dim_y)
      else:
        return_dims.append(None)
    elif dim_x.value == 1:
      # We will broadcast dim_x to dim_y.
      return_dims.append(dim_y)
    elif dim_y.value == 1:
      # We will broadcast dim_y to dim_x.
      return_dims.append(dim_x)
    elif dim_x.value == dim_y.value:
      # The dimensions are compatible, so output is the same size in that
      # dimension.
      return_dims.append(dim_x.merge_with(dim_y))
    else:
      return None
  return return_dims


def is_broadcast_compatible(shape_x, shape_y):
  """Returns True if `shape_x` and `shape_y` are broadcast compatible.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be broadcast
    to.  False otherwise.
  """
  if shape_x.ndims is None or shape_y.ndims is None:
    return False
  return _broadcast_shape_helper(shape_x, shape_y) is not None


def broadcast_shape(shape_x, shape_y):
  """Returns the broadcasted shape between `shape_x` and `shape_y`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes cannot be broadcast together.
  """
  if shape_x.ndims is None or shape_y.ndims is None:
    return tensor_shape.unknown_shape()
  return_dims = _broadcast_shape_helper(shape_x, shape_y)
  if return_dims is None:
    raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                     % (shape_x, shape_y))
  return tensor_shape.TensorShape(return_dims)
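# A minimal sketch of the broadcasting rules above (illustrative only):
#
#   x = tensor_shape.TensorShape([5, 1, 3])
#   y = tensor_shape.TensorShape([4, 3])
#   is_broadcast_compatible(x, y)  # => True (y is padded to [1, 4, 3])
#   broadcast_shape(x, y)          # => TensorShape([5, 4, 3])
#   broadcast_shape(tensor_shape.TensorShape([2, 3]),
#                   tensor_shape.TensorShape([3, 3]))  # raises ValueError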


def call_cpp_shape_fn(op, require_shape_fn=True):
  """A shape function that delegates to the registered C++ shape function.

  Args:
    op: the node in the graph for which to compute output shapes.
    require_shape_fn: If True and the C++ shape function is not registered
      in the current binary, an exception is raised; otherwise, if the
      C++ shape function is not registered, unknown_shape is used.

  Returns:
    A dictionary with the following keys:
      shapes: A TensorShape list of the output shapes of the op, as computed
        using the C++ shape inference function registered for the op.
      handle_data: A list with one entry per output, containing the handle
        shape and dtype information for handle outputs, and None otherwise.
      inputs_needed: A serialized CppShapeInferenceInputsNeeded proto
        describing which input tensors the C++ shape function requested as
        constant values (absent for the Const special case below).

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function).
    RuntimeError: If the C++ shape function is not registered and
      `require_shape_fn` is True.
  """
  if op.type == "Const":
    # To avoid serializing large constants, we special-case constants
    # here, even though the Const op has a C++ shape function.  When Python
    # calls the C API directly, we should be able to remove this.
    return {
        "shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)],
        "handle_data": [None]
    }

  input_tensors_needed = []
  input_tensors_as_shapes_needed = []

  while True:
    res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
                                  input_tensors_as_shapes_needed,
                                  require_shape_fn)
    if not isinstance(res, dict):
      # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
      return res

    # See if we need to evaluate some inputs.
    if not res["inputs_needed"]:
      return res
    p = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded()
    p = p.FromString(res["inputs_needed"])
    changed = False
    for idx in p.input_tensors_needed:
      if idx not in input_tensors_needed:
        input_tensors_needed.append(idx)
        changed = True
    for idx in p.input_tensors_as_shapes_needed:
      if idx not in input_tensors_as_shapes_needed:
        input_tensors_as_shapes_needed.append(idx)
        changed = True
    if not changed:
      return res
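# A minimal usage sketch (illustrative only; assumes graph mode and a
# hypothetical import of constant_op from tensorflow.python.framework):
#
#   with ops.Graph().as_default():
#     c = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
#     result = call_cpp_shape_fn(c.op)
#     result["shapes"]       # => [TensorShape([2, 2])]
#     result["handle_data"]  # => [None]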


def _call_cpp_shape_fn_impl(
    op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
  """Core implementation of call_cpp_shape_fn."""
  graph_def_version = op.graph.graph_def_versions.producer
  node_def_str = op.node_def.SerializeToString()

  def tensor_to_inference_result(t):
    r = cpp_shape_inference_pb2.CppShapeInferenceResult()
    r.shape.CopyFrom(t.get_shape().as_proto())
    # pylint: disable=protected-access
    if t._handle_data is not None:
      r.handle_data.CopyFrom(t._handle_data)
    # pylint: enable=protected-access
    return r.SerializeToString()

  input_shapes = [tensor_to_inference_result(i) for i in op.inputs]

  input_tensors = [None for i in input_shapes]
  for idx in input_tensors_needed:
    v = tensor_util.constant_value(op.inputs[idx])
    if v is not None:
      input_tensors[idx] = np.asarray(v)

  serialized_unknown_shape = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  arr = [serialized_unknown_shape for i in input_shapes]
  for idx in input_tensors_as_shapes_needed:
    s = tensor_util.constant_value_as_shape(op.inputs[idx])
    if s is not None:
      arr[idx] = s.as_proto().SerializeToString()
  input_tensors_as_shapes = arr

  missing_shape_fn = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output = pywrap_tensorflow.RunCppShapeInference(
          graph_def_version, node_def_str, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if err.message.startswith("No shape inference function exists for op"):
      missing_shape_fn = True
    else:
      raise ValueError(err.message)

  if missing_shape_fn:
    if require_shape_fn:
      raise RuntimeError(
          "No C++ shape function registered for standard op: %s" % op.type)
    return unknown_shape(op)

  output_shapes = output[:-1]

  # Convert the serialized CppShapeInferenceResult protos in output_shapes
  # into TensorShapes and per-output handle data.
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
      for s in output_shapes
  ]
  result = [r.shape for r in result_protos]
  result_handle_data = [
      r.handle_data if r.handle_data.is_set else None for r in result_protos
  ]

  return {
      "shapes": result,
      "handle_data": result_handle_data,
      "inputs_needed": output[-1]
  }


# pylint: disable=protected-access
ops._set_call_cpp_shape_fn(call_cpp_shape_fn)
# pylint: enable=protected-access