# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions for creating partitioned variables.

This is a convenient abstraction to partition a large variable across
multiple smaller variables that can be assigned to different devices.

The full variable can be reconstructed by concatenating the smaller variables.
Using partitioned variables instead of a single variable is mostly a
performance choice.  It does, however, also have an impact on:

1. Random initialization, as the random number generator is called once per
   slice
2. Updates, as they happen in parallel across slices

A key design goal is to allow a different graph to repartition a variable
with the same name but different slicings, including possibly no partitions.

TODO(touts): If an initializer provides a seed, the seed must be changed
deterministically for each slice, maybe by adding one to it, otherwise each
slice will use the same values.  Maybe this can be done by passing the
slice offsets to the initializer functions.

Typical usage:

```python
# Create a list of partitioned variables with:
vs = create_partitioned_variables(
    <shape>, <slicing>, <initializer>, name=<optional-name>)

# Pass the list as inputs to embedding_lookup for sharded, parallel lookup:
y = embedding_lookup(vs, ids, partition_strategy="div")

# Or fetch the variables in parallel to speed up large matmuls:
z = matmul(x, concat(slice_dim, vs))
```
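
For example, a concrete but purely illustrative instantiation of the above
(assuming `import tensorflow as tf` and an existing integer `ids` tensor; the
shape, slicing, and name below are placeholders, not requirements):

```python
vs = create_partitioned_variables(
    [1000000, 64], [8, 1], tf.random_uniform_initializer(),
    name="embedding")
y = tf.nn.embedding_lookup(vs, ids, partition_strategy="div")
```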
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "create_partitioned_variables",
    "variable_axis_size_partitioner",
    "min_max_variable_partitioner",
    "fixed_size_partitioner",
]


@tf_export("variable_axis_size_partitioner")
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below `max_shard_bytes`.  In practice, this is not
  always possible when sharding along only one axis.  When this happens,
  the axis is sharded as much as possible (i.e., every slice along that axis
  becomes its own shard).

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
  `64MB`, to keep below the protobuf byte limit.

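  For example, a hedged usage sketch (assuming `import tensorflow as tf`; the
  scope name, variable name, shape, and byte limit are illustrative only):

  ```python
  partitioner = variable_axis_size_partitioner(max_shard_bytes=(64 << 20) - 1)
  with tf.variable_scope("embeddings", partitioner=partitioner):
    # A float32 [1000000, 128] variable is ~512MB, so it is split along
    # axis 0 into 8 shards, each below the requested byte limit.
    embedding = tf.get_variable("embedding", shape=[1000000, 128])
  ```
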
  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along.  Default: outermost axis.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.
    max_shards: `int`, the maximum number of shards to create.  Takes
      precedence over `max_shard_bytes`.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  Raises:
    ValueError: If any of the byte counts are non-positive.
  """
  if max_shard_bytes < 1 or bytes_per_string_element < 1:
    raise ValueError(
        "Both max_shard_bytes and bytes_per_string_element must be positive.")
  if max_shards and max_shards < 1:
    raise ValueError(
        "max_shards must be positive.")

  def _partitioner(shape, dtype):
    """Partitioner that keeps each shard below `max_shard_bytes` in total size.

    Args:
      shape: A `TensorShape`.
      dtype: A `DType`.

    Returns:
      A list giving the number of partitions for each axis of `shape`.

    Raises:
      ValueError: If shape is not a fully defined `TensorShape` or dtype is not
        a `DType`.
    """
    if not isinstance(shape, tensor_shape.TensorShape):
      raise ValueError("shape is not a TensorShape: %s" % shape)
    if not shape.is_fully_defined():
      raise ValueError("shape is not fully defined: %s" % shape)
    if not isinstance(dtype, dtypes.DType):
      raise ValueError("dtype is not a DType: %s" % dtype)

    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis].value slices?
    axis_shards = int(math.ceil(1.0 * shape[axis].value / slices_per_shard))
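    # Worked example (assumed, illustrative values): a float32 variable of
    # shape [1024, 1024] with max_shard_bytes=2**20 and axis=0 gives
    # bytes_per_slice = 1024 * 4 = 4096, slices_per_shard = 256, and
    # axis_shards = ceil(1024 / 256) = 4.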
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner


@tf_export("min_max_variable_partitioner")
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition holds at least `min_slice_size` bytes of the
  variable. The maximum number of such partitions (upper bound) is given by
  `max_partitions`.

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size, in bytes, of the variable slice per
      partition. Defaults to 256KB.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

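  A hedged usage sketch (assuming `import tensorflow as tf`; the scope name,
  variable name, and shape are illustrative):

  ```python
  partitioner = min_max_variable_partitioner(max_partitions=8)
  with tf.variable_scope("layer", partitioner=partitioner):
    # A float32 [1024, 1024] variable (4MB) allows 4MB / 256KB = 16 slices,
    # but the count is capped at max_partitions, giving partitions [8, 1].
    weights = tf.get_variable("weights", shape=[1024, 1024])
  ```
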
  """
  def _partitioner(shape, dtype):
    """Partitioner that computes the partition list for a given shape and dtype.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If `max_partitions` >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If `max_partitions` < 16, this function would return
        [`max_partitions`, 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    # `max_partitions` allows.
    partitions_list[axis] = max(1, min(shape[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner


@tf_export("fixed_size_partitioner")
def fixed_size_partitioner(num_shards, axis=0):
  """Partitioner to specify a fixed number of shards along given axis.

  Args:
    num_shards: `int`, number of shards to partition the variable into.
    axis: `int`, axis to partition on.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
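
  For example, a minimal sketch (assuming `import tensorflow as tf`; the
  scope name, variable name, and shape are illustrative):

  ```python
  with tf.variable_scope("scope", partitioner=fixed_size_partitioner(3)):
    # A [10, 10] variable is split into 3 shards along axis 0, with shapes
    # [4, 10], [3, 10], and [3, 10].
    v = tf.get_variable("weights", shape=[10, 10])
  ```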
  """
  def _partitioner(shape, **unused_args):
    partitions_list = [1] * len(shape)
    partitions_list[axis] = min(num_shards, shape[axis].value)
    return partitions_list
  return _partitioner


@tf_export("create_partitioned_variables")
def create_partitioned_variables(
    shape, slicing, initializer, dtype=dtypes.float32,
    trainable=True, collections=None, name=None, reuse=None):
  """Create a list of partitioned variables according to the given `slicing`.

  Currently only one dimension of the full variable can be sliced, and the
  full variable can be reconstructed by the concatenation of the returned
  list along that dimension.

  Args:
    shape: List of integers.  The shape of the full variable.
    slicing: List of integers.  How to partition the variable.
      Must be of the same length as `shape`.  Each value
      indicates how many slices to create in the corresponding
      dimension.  Presently only one of the values can be more than 1;
      that is, the variable can only be sliced along one dimension.

      For convenience, the requested number of partitions does not have to
      divide the corresponding dimension evenly.  If it does not, the
      shapes of the partitions are incremented by 1 starting from partition
      0 until all slack is absorbed.  The adjustment rules may change in the
      future, but as you can save/restore these variables with different
      slicing specifications this should not be a problem.
    initializer: A `Tensor` of shape `shape` or a variable initializer
      function.  If a function, it will be called once for each slice,
      passing the shape and data type of the slice as parameters.  The
      function must return a tensor with the same shape as the slice.
    dtype: Type of the variables. Ignored if `initializer` is a `Tensor`.
    trainable: If `True`, also add all the variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES`.
    collections: List of graph collection keys to add the variables to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name for the full variable.  Defaults to
      `"PartitionedVariable"` and gets uniquified automatically.
    reuse: Boolean or `None`; if `True` and `name` is set, reuse previously
      created variables; if `False`, create new variables; if `None`, inherit
      the parent scope's reuse setting.

  Returns:
    A list of Variables corresponding to the slicing.

  Raises:
    ValueError: If any of the arguments is malformed.
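
  A hedged example of this legacy interface (the shape, slicing, and
  initializer are illustrative; assumes `import tensorflow as tf`):

  ```python
  # Split a [1024, 256] float32 variable into 4 pieces along dimension 0,
  # each of shape [256, 256].  A slicing of [3, 1] would instead yield
  # partition shapes [342, 256], [341, 256], and [341, 256].
  vs = create_partitioned_variables(
      shape=[1024, 256], slicing=[4, 1],
      initializer=tf.truncated_normal_initializer())
  ```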
  """
  logging.warn(
      "create_partitioned_variables is deprecated.  Use "
      "tf.get_variable with a partitioner set, or "
      "tf.get_partitioned_variable_list, instead.")

  if len(shape) != len(slicing):
    raise ValueError("The 'shape' and 'slicing' of a partitioned Variable "
                     "must have the same length: shape: %s, slicing: %s" %
                     (shape, slicing))
  if len(shape) < 1:
    raise ValueError("A partitioned Variable must have rank at least 1: "
                     "shape: %s" % shape)

  # Legacy: we are provided the slicing directly, so just pass it to
  # the partitioner.
  partitioner = lambda **unused_kwargs: slicing

  with variable_scope.variable_scope(
      name, "PartitionedVariable", reuse=reuse):
    # pylint: disable=protected-access
    partitioned_var = variable_scope._get_partitioned_variable(
        name=None,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        trainable=trainable,
        partitioner=partitioner,
        collections=collections)
    return list(partitioned_var)
    # pylint: enable=protected-access