1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Methods to allow dask.DataFrame.""" 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import numpy as np 23 24 try: 25 # pylint: disable=g-import-not-at-top 26 import dask.dataframe as dd 27 allowed_classes = (dd.Series, dd.DataFrame) 28 HAS_DASK = True 29 except ImportError: 30 HAS_DASK = False 31 32 33 def _add_to_index(df, start): 34 """New dask.dataframe with values added to index of each subdataframe.""" 35 df = df.copy() 36 df.index += start 37 return df 38 39 40 def _get_divisions(df): 41 """Number of rows in each sub-dataframe.""" 42 lengths = df.map_partitions(len).compute() 43 divisions = np.cumsum(lengths).tolist() 44 divisions.insert(0, 0) 45 return divisions 46 47 48 def _construct_dask_df_with_divisions(df): 49 """Construct the new task graph and make a new dask.dataframe around it.""" 50 divisions = _get_divisions(df) 51 # pylint: disable=protected-access 52 name = 'csv-index' + df._name 53 dsk = {(name, i): (_add_to_index, (df._name, i), divisions[i]) 54 for i in range(df.npartitions)} 55 # pylint: enable=protected-access 56 from toolz import merge # pylint: disable=g-import-not-at-top 57 if isinstance(df, dd.DataFrame): 58 return dd.DataFrame(merge(dsk, df.dask), name, df.columns, divisions) 59 elif isinstance(df, dd.Series): 60 return dd.Series(merge(dsk, df.dask), name, df.name, divisions) 61 62 63 def extract_dask_data(data): 64 """Extract data from dask.Series or dask.DataFrame for predictors. 65 66 Given a distributed dask.DataFrame or dask.Series containing columns or names 67 for one or more predictors, this operation returns a single dask.DataFrame or 68 dask.Series that can be iterated over. 69 70 Args: 71 data: A distributed dask.DataFrame or dask.Series. 72 73 Returns: 74 A dask.DataFrame or dask.Series that can be iterated over. 75 If the supplied argument is neither a dask.DataFrame nor a dask.Series this 76 operation returns it without modification. 77 """ 78 if isinstance(data, allowed_classes): 79 return _construct_dask_df_with_divisions(data) 80 else: 81 return data 82 83 84 def extract_dask_labels(labels): 85 """Extract data from dask.Series or dask.DataFrame for labels. 86 87 Given a distributed dask.DataFrame or dask.Series containing exactly one 88 column or name, this operation returns a single dask.DataFrame or dask.Series 89 that can be iterated over. 90 91 Args: 92 labels: A distributed dask.DataFrame or dask.Series with exactly one 93 column or name. 94 95 Returns: 96 A dask.DataFrame or dask.Series that can be iterated over. 97 If the supplied argument is neither a dask.DataFrame nor a dask.Series this 98 operation returns it without modification. 99 100 Raises: 101 ValueError: If the supplied dask.DataFrame contains more than one 102 column or the supplied dask.Series contains more than 103 one name. 104 """ 105 if isinstance(labels, dd.DataFrame): 106 ncol = labels.columns 107 elif isinstance(labels, dd.Series): 108 ncol = labels.name 109 if isinstance(labels, allowed_classes): 110 if len(ncol) > 1: 111 raise ValueError('Only one column for labels is allowed.') 112 return _construct_dask_df_with_divisions(labels) 113 else: 114 return labels 115