Home | History | Annotate | Download | only in ragged
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Value for RaggedTensor."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.util.tf_export import tf_export
     24 
     25 
     26 @tf_export(v1=["ragged.RaggedTensorValue"])
     27 class RaggedTensorValue(object):
     28   """Represents the value of a `RaggedTensor`.
     29 
     30   Warning: `RaggedTensorValue` should only be used in graph mode; in
     31   eager mode, the `tf.RaggedTensor` class contains its value directly.
     32 
     33   See `tf.RaggedTensor` for a description of ragged tensors.
     34   """
     35 
     36   def __init__(self, values, row_splits):
     37     """Creates a `RaggedTensorValue`.
     38 
     39     Args:
     40       values: A numpy array of any type and shape; or a RaggedTensorValue.
     41       row_splits: A 1-D int64 numpy array.
     42     """
     43     if not (isinstance(row_splits, (np.ndarray, np.generic)) and
     44             row_splits.dtype == np.int64 and row_splits.ndim == 1):
     45       raise TypeError("row_splits must be a 1D int64 numpy array")
     46     if not isinstance(values, (np.ndarray, np.generic, RaggedTensorValue)):
     47       raise TypeError("values must be a numpy array or a RaggedTensorValue")
     48     self._values = values
     49     self._row_splits = row_splits
     50 
     51   row_splits = property(
     52       lambda self: self._row_splits,
     53       doc="""The split indices for the ragged tensor value.""")
     54   values = property(
     55       lambda self: self._values,
     56       doc="""The concatenated values for all rows in this tensor.""")
     57   dtype = property(
     58       lambda self: self._values.dtype,
     59       doc="""The numpy dtype of values in this tensor.""")
     60 
     61   @property
     62   def flat_values(self):
     63     """The innermost `values` array for this ragged tensor value."""
     64     rt_values = self.values
     65     while isinstance(rt_values, RaggedTensorValue):
     66       rt_values = rt_values.values
     67     return rt_values
     68 
     69   @property
     70   def nested_row_splits(self):
     71     """The row_splits for all ragged dimensions in this ragged tensor value."""
     72     rt_nested_splits = [self.row_splits]
     73     rt_values = self.values
     74     while isinstance(rt_values, RaggedTensorValue):
     75       rt_nested_splits.append(rt_values.row_splits)
     76       rt_values = rt_values.values
     77     return tuple(rt_nested_splits)
     78 
     79   @property
     80   def ragged_rank(self):
     81     """The number of ragged dimensions in this ragged tensor value."""
     82     values_is_ragged = isinstance(self._values, RaggedTensorValue)
     83     return self._values.ragged_rank + 1 if values_is_ragged else 1
     84 
     85   @property
     86   def shape(self):
     87     """A tuple indicating the shape of this RaggedTensorValue."""
     88     return (self._row_splits.shape[0] - 1,) + (None,) + self._values.shape[1:]
     89 
     90   def __str__(self):
     91     return "<tf.RaggedTensorValue %s>" % self.to_list()
     92 
     93   def __repr__(self):
     94     return "tf.RaggedTensorValue(values=%r, row_splits=%r)" % (self._values,
     95                                                                self._row_splits)
     96 
     97   def to_list(self):
     98     """Returns this ragged tensor value as a nested Python list."""
     99     if isinstance(self._values, RaggedTensorValue):
    100       values_as_list = self._values.to_list()
    101     else:
    102       values_as_list = self._values.tolist()
    103     return [
    104         values_as_list[self._row_splits[i]:self._row_splits[i + 1]]
    105         for i in range(len(self._row_splits) - 1)
    106     ]
    107