Home | History | Annotate | Download | only in learn_io
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Methods to allow dask.DataFrame."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import numpy as np
     23 
     24 try:
     25   # pylint: disable=g-import-not-at-top
     26   import dask.dataframe as dd
     27   allowed_classes = (dd.Series, dd.DataFrame)
     28   HAS_DASK = True
     29 except ImportError:
     30   HAS_DASK = False
     31 
     32 
     33 def _add_to_index(df, start):
     34   """New dask.dataframe with values added to index of each subdataframe."""
     35   df = df.copy()
     36   df.index += start
     37   return df
     38 
     39 
     40 def _get_divisions(df):
     41   """Number of rows in each sub-dataframe."""
     42   lengths = df.map_partitions(len).compute()
     43   divisions = np.cumsum(lengths).tolist()
     44   divisions.insert(0, 0)
     45   return divisions
     46 
     47 
     48 def _construct_dask_df_with_divisions(df):
     49   """Construct the new task graph and make a new dask.dataframe around it."""
     50   divisions = _get_divisions(df)
     51   # pylint: disable=protected-access
     52   name = 'csv-index' + df._name
     53   dsk = {(name, i): (_add_to_index, (df._name, i), divisions[i])
     54          for i in range(df.npartitions)}
     55   # pylint: enable=protected-access
     56   from toolz import merge  # pylint: disable=g-import-not-at-top
     57   if isinstance(df, dd.DataFrame):
     58     return dd.DataFrame(merge(dsk, df.dask), name, df.columns, divisions)
     59   elif isinstance(df, dd.Series):
     60     return dd.Series(merge(dsk, df.dask), name, df.name, divisions)
     61 
     62 
     63 def extract_dask_data(data):
     64   """Extract data from dask.Series or dask.DataFrame for predictors.
     65 
     66   Given a distributed dask.DataFrame or dask.Series containing columns or names
     67   for one or more predictors, this operation returns a single dask.DataFrame or
     68   dask.Series that can be iterated over.
     69 
     70   Args:
     71     data: A distributed dask.DataFrame or dask.Series.
     72 
     73   Returns:
     74     A dask.DataFrame or dask.Series that can be iterated over.
     75     If the supplied argument is neither a dask.DataFrame nor a dask.Series this
     76     operation returns it without modification.
     77   """
     78   if isinstance(data, allowed_classes):
     79     return _construct_dask_df_with_divisions(data)
     80   else:
     81     return data
     82 
     83 
     84 def extract_dask_labels(labels):
     85   """Extract data from dask.Series or dask.DataFrame for labels.
     86 
     87   Given a distributed dask.DataFrame or dask.Series containing exactly one
     88   column or name, this operation returns a single dask.DataFrame or dask.Series
     89   that can be iterated over.
     90 
     91   Args:
     92     labels: A distributed dask.DataFrame or dask.Series with exactly one
     93             column or name.
     94 
     95   Returns:
     96     A dask.DataFrame or dask.Series that can be iterated over.
     97     If the supplied argument is neither a dask.DataFrame nor a dask.Series this
     98     operation returns it without modification.
     99 
    100   Raises:
    101     ValueError: If the supplied dask.DataFrame contains more than one
    102                 column or the supplied dask.Series contains more than
    103                 one name.
    104   """
    105   if isinstance(labels, dd.DataFrame):
    106     ncol = labels.columns
    107   elif isinstance(labels, dd.Series):
    108     ncol = labels.name
    109   if isinstance(labels, allowed_classes):
    110     if len(ncol) > 1:
    111       raise ValueError('Only one column for labels is allowed.')
    112     return _construct_dask_df_with_divisions(labels)
    113   else:
    114     return labels
    115