Home | History | Annotate | Download | only in ragged

Lines Matching refs:lengths

46       lengths.  `RaggedTensorDynamicShape` records the size of each ragged
47 dimension using an integer vector containing the slice lengths for all
267 def broadcast_dimension(self, axis, lengths):
268 """Returns a shape that is broadcast-compatible with self & lengths.
270 * If dimension[axis] is uniform and lengths is a scalar, the check
271 that either lengths==1 or axis==1 or lengths==axis, and tile
272 dimension[axis] with tf.where(lengths==axis, 1, axis) repeats.
274 * If dimension[axis] is uniform and lengths is a vector, then check
276 lengths repeats. (we can skip tiling if we statically know that
279 * If dimension[axis] is ragged and lengths is a scalar, then check
280 that lengths==1.
282 * If dimension[axis] is ragged and lengths is a vector, then check
283 that self.dimension_size(axis) == lengths.
287 lengths: 0-D or 1-D integer `Tensor`.
292 lengths = ragged_util.convert_to_int_tensor(
293 lengths, name='lengths', dtype=dtypes.int64)
294 # Check whether lengths is a scalar (for uniform dimensions) or
296 if lengths.shape.ndims is None:
297 raise ValueError('lengths must have a known rank.')
298 elif lengths.shape.ndims > 1:
299 raise ValueError('lengths must be a scalar or vector')
301 lengths_is_scalar = (lengths.shape.ndims == 0)
306 condition = math_ops.equal(lengths, 1)
309 math_ops.equal(lengths, self.dimension_size(axis)))
314 math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
315 | math_ops.equal(axis_dim_size, lengths))
320 'lengths=', lengths, 'dim_size=',
335 return self._broadcast_uniform_partitioned_dimension(axis, lengths)
340 return self._broadcast_inner_dimension_to_uniform(axis, lengths)
345 return self._broadcast_inner_dimension_to_ragged(axis, lengths)
356 def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
357 """Broadcasts the partitioned dimension `axis` to match `lengths`."""
361 if lengths.shape.ndims == 0:
362 lengths = array_ops.where(
363 math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
364 repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
368 array_ops.size(lengths, out_type=dtypes.int64) + 1)
369 repeats = lengths
371 partitioned_sizes.append(lengths)
386 """Broadcasts the inner dimension `axis` to match `lengths`."""
398 def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
403 ]) + (lengths,))