# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions for creating partitioned variables.

This is a convenient abstraction to partition a large variable across
multiple smaller variables that can be assigned to different devices.

The full variable can be reconstructed by concatenating the smaller variables.
Using partitioned variables instead of a single variable is mostly a
performance choice. However, it also has an impact on:

1. Random initialization, as the random number generator is called once per
   slice.
2. Updates, as they happen in parallel across slices.

A key design goal is to allow a different graph to repartition a variable
with the same name but different slicings, including possibly no partitions.

TODO(touts): If an initializer provides a seed, the seed must be changed
deterministically for each slice, maybe by adding one to it, otherwise each
slice will use the same values. Maybe this can be done by passing the
slice offsets to the initializer functions.

Typical usage:

```python
# Create a list of partitioned variables with:
vs = create_partitioned_variables(
    <shape>, <slicing>, <initializer>, name=<optional-name>)

# Pass the list as inputs to embedding_lookup for sharded, parallel lookup:
y = embedding_lookup(vs, ids, partition_strategy="div")

# Or fetch the variables in parallel to speed up large matmuls:
z = matmul(x, concat(vs, slice_dim))
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "create_partitioned_variables",
    "variable_axis_size_partitioner",
    "min_max_variable_partitioner",
    "fixed_size_partitioner",
]


@tf_export("variable_axis_size_partitioner")
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below `max_shard_bytes`. In practice, this is not
  always possible when sharding along only one axis. When this happens,
  this axis is sharded as much as possible (i.e., every slice along this
  axis becomes a separate shard).

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
  `64MB`, to keep below the protobuf byte limit.
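
  As an illustrative sketch (the scope and variable names below are arbitrary
  examples, not part of this API), a `[2048, 1024]` float32 variable with
  `max_shard_bytes=1 << 20` is split along axis 0 into 8 shards of shape
  `[256, 1024]`, each exactly 1MB:

  ```python
  partitioner = tf.variable_axis_size_partitioner(max_shard_bytes=1 << 20)
  with tf.variable_scope("embedding", partitioner=partitioner):
    weights = tf.get_variable(
        "weights", shape=[2048, 1024], dtype=tf.float32)
  ```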

  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along. Default: outermost axis.
    bytes_per_string_element: If the `Variable` is of type string, this
      provides an estimate of how large each scalar in the `Variable` is.
    max_shards: An `int`, the maximum number of shards to create; takes
      precedence over `max_shard_bytes`.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  Raises:
    ValueError: If any of the byte counts are non-positive.
  """
  if max_shard_bytes < 1 or bytes_per_string_element < 1:
    raise ValueError(
        "Both max_shard_bytes and bytes_per_string_element must be positive.")
  if max_shards and max_shards < 1:
    raise ValueError(
        "max_shards must be positive.")

  def _partitioner(shape, dtype):
    """Partitioner that keeps each shard below `max_shard_bytes` in size.

    Args:
      shape: A `TensorShape`.
      dtype: A `DType`.

    Returns:
      A list representing how much to slice each axis in shape.

    Raises:
      ValueError: If shape is not a fully defined `TensorShape` or dtype is
        not a `DType`.
    """
    if not isinstance(shape, tensor_shape.TensorShape):
      raise ValueError("shape is not a TensorShape: %s" % shape)
    if not shape.is_fully_defined():
      raise ValueError("shape is not fully defined: %s" % shape)
    if not isinstance(dtype, dtypes.DType):
      raise ValueError("dtype is not a DType: %s" % dtype)

    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for this axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis].value slices?
    axis_shards = int(math.ceil(1.0 * shape[axis].value / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner


@tf_export("min_max_variable_partitioner")
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition has at least `min_slice_size` bytes of the
  variable. The maximum number of such partitions (upper bound) is given by
  `max_partitions`.
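
  For example (an illustrative sketch; the scope and variable names are
  arbitrary), a 4MB float32 variable of shape `[1024, 1024]` with
  `min_slice_size=256 << 10` would need 16 partitions to keep every slice at
  256KB, but `max_partitions=8` caps it at 8 partitions of shape
  `[128, 1024]`:

  ```python
  partitioner = tf.min_max_variable_partitioner(
      max_partitions=8, min_slice_size=256 << 10)
  with tf.variable_scope("softmax", partitioner=partitioner):
    weights = tf.get_variable(
        "weights", shape=[1024, 1024], dtype=tf.float32)
  ```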

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size of the variable slice per partition. Defaults
      to 256K.
    bytes_per_string_element: If the `Variable` is of type string, this
      provides an estimate of how large each scalar in the `Variable` is.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  """
  def _partitioner(shape, dtype):
    """Partitioner that returns the partitions list for a given shape and dtype.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If `max_partitions` >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If `max_partitions` < 16, this function would return
        [`max_partitions`, 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError(
          "Can not partition variable along axis %d when shape is only %s" %
          (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    # `max_partitions` allows.
    partitions_list[axis] = max(1, min(shape[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner


@tf_export("fixed_size_partitioner")
def fixed_size_partitioner(num_shards, axis=0):
  """Partitioner to specify a fixed number of shards along given axis.

  Args:
    num_shards: `int`, number of shards to partition the variable into.
    axis: `int`, axis to partition on.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
  """
  def _partitioner(shape, **unused_args):
    partitions_list = [1] * len(shape)
    partitions_list[axis] = min(num_shards, shape[axis].value)
    return partitions_list
  return _partitioner


@tf_export("create_partitioned_variables")
def create_partitioned_variables(
    shape, slicing, initializer, dtype=dtypes.float32,
    trainable=True, collections=None, name=None, reuse=None):
  """Create a list of partitioned variables according to the given `slicing`.

  Currently only one dimension of the full variable can be sliced, and the
  full variable can be reconstructed by the concatenation of the returned
  list along that dimension.
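
  For example (an illustrative sketch; the name and shapes are arbitrary),
  slicing a `[1024, 64]` variable into 4 pieces along dimension 0 yields four
  `[256, 64]` variables:

  ```python
  vs = tf.create_partitioned_variables(
      shape=[1024, 64], slicing=[4, 1],
      initializer=tf.zeros_initializer(), name="embeddings")
  full = tf.concat(vs, axis=0)  # Reconstructs the full [1024, 64] value.
  ```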

  Args:
    shape: List of integers. The shape of the full variable.
    slicing: List of integers. How to partition the variable.
      Must be of the same length as `shape`. Each value
      indicates how many slices to create in the corresponding
      dimension. Presently only one of the values can be more than 1;
      that is, the variable can only be sliced along one dimension.

      For convenience, the requested number of partitions does not have to
      divide the corresponding dimension evenly. If it does not, the
      shapes of the partitions are incremented by 1 starting from partition
      0 until all slack is absorbed. The adjustment rules may change in the
      future, but as you can save/restore these variables with different
      slicing specifications this should not be a problem.
    initializer: A `Tensor` of shape `shape` or a variable initializer
      function. If a function, it will be called once for each slice,
      passing the shape and data type of the slice as parameters. The
      function must return a tensor with the same shape as the slice.
    dtype: Type of the variables. Ignored if `initializer` is a `Tensor`.
    trainable: If True, also add all the variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES`.
    collections: List of graph collections keys to add the variables to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name for the full variable. Defaults to
      `"PartitionedVariable"` and gets uniquified automatically.
    reuse: Boolean or `None`; if `True` and `name` is set, it reuses
      previously created variables. If `False`, it creates new variables.
      If `None`, it inherits the parent scope's reuse setting.

  Returns:
    A list of Variables corresponding to the slicing.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  logging.warn(
      "create_partitioned_variables is deprecated. Use "
      "tf.get_variable with a partitioner set, or "
      "tf.get_partitioned_variable_list, instead.")

  if len(shape) != len(slicing):
    raise ValueError("The 'shape' and 'slicing' of a partitioned Variable "
                     "must have the same length: shape: %s, slicing: %s" %
                     (shape, slicing))
  if len(shape) < 1:
    raise ValueError("A partitioned Variable must have rank at least 1: "
                     "shape: %s" % shape)

  # Legacy: we are provided the slicing directly, so just pass it to
  # the partitioner.
  partitioner = lambda **unused_kwargs: slicing

  with variable_scope.variable_scope(
      name, "PartitionedVariable", reuse=reuse):
    # pylint: disable=protected-access
    partitioned_var = variable_scope._get_partitioned_variable(
        name=None,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        trainable=trainable,
        partitioner=partitioner,
        collections=collections)
    return list(partitioned_var)
  # pylint: enable=protected-access
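
# The legacy helper above is roughly equivalent to using `get_variable` with a
# partitioner. An illustrative sketch (names and shapes are arbitrary):
#
#   partitioner = fixed_size_partitioner(num_shards=4, axis=0)
#   with variable_scope.variable_scope("embeddings", partitioner=partitioner):
#     v = variable_scope.get_variable(
#         "weights", shape=[1024, 64], dtype=dtypes.float32)
#   # `v` is a `PartitionedVariable`; `list(v)` yields 4 variables of shape
#   # [256, 64] whose concatenation along axis 0 is the full [1024, 64] value.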