Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """For seeding individual ops based on a graph-level seed.
     17 """
     18 
     19 from __future__ import absolute_import
     20 from __future__ import division
     21 from __future__ import print_function
     22 
     23 from tensorflow.python.eager import context
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.util.tf_export import tf_export
     26 
     27 
     28 DEFAULT_GRAPH_SEED = 87654321
     29 _MAXINT32 = 2**31 - 1
     30 
     31 
     32 def _truncate_seed(seed):
     33   return seed % _MAXINT32  # Truncate to fit into 32-bit integer
     34 
     35 
     36 @tf_export('get_seed')
     37 def get_seed(op_seed):
     38   """Returns the local seeds an operation should use given an op-specific seed.
     39 
     40   Given operation-specific seed, `op_seed`, this helper function returns two
     41   seeds derived from graph-level and op-level seeds. Many random operations
     42   internally use the two seeds to allow user to change the seed globally for a
     43   graph, or for only specific operations.
     44 
     45   For details on how the graph-level seed interacts with op seeds, see
     46   @{tf.set_random_seed}.
     47 
     48   Args:
     49     op_seed: integer.
     50 
     51   Returns:
     52     A tuple of two integers that should be used for the local seed of this
     53     operation.
     54   """
     55   is_graph_mode = context.in_graph_mode()
     56 
     57   if is_graph_mode:
     58     global_seed = ops.get_default_graph().seed
     59   else:
     60     global_seed = context.global_seed()
     61 
     62   if global_seed is not None:
     63     if op_seed is None:
     64       # pylint: disable=protected-access
     65       if is_graph_mode:
     66         op_seed = ops.get_default_graph()._last_id
     67       else:
     68         op_seed = context.internal_operation_seed()
     69 
     70     seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
     71   else:
     72     if op_seed is not None:
     73       seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
     74     else:
     75       seeds = None, None
     76   # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
     77   # be unexpected since Python docs say nondeterminism is (None, None).
     78   if seeds == (0, 0):
     79     return (0, _MAXINT32)
     80   return seeds
     81 
     82 
     83 @tf_export('set_random_seed')
     84 def set_random_seed(seed):
     85   """Sets the graph-level random seed.
     86 
     87   Operations that rely on a random seed actually derive it from two seeds:
     88   the graph-level and operation-level seeds. This sets the graph-level seed.
     89 
     90   Its interactions with operation-level seeds is as follows:
     91 
     92     1. If neither the graph-level nor the operation seed is set:
     93       A random seed is used for this op.
     94     2. If the graph-level seed is set, but the operation seed is not:
     95       The system deterministically picks an operation seed in conjunction
     96       with the graph-level seed so that it gets a unique random sequence.
     97     3. If the graph-level seed is not set, but the operation seed is set:
     98       A default graph-level seed and the specified operation seed are used to
     99       determine the random sequence.
    100     4. If both the graph-level and the operation seed are set:
    101       Both seeds are used in conjunction to determine the random sequence.
    102 
    103   To illustrate the user-visible effects, consider these examples:
    104 
    105   To generate different sequences across sessions, set neither
    106   graph-level nor op-level seeds:
    107 
    108   ```python
    109   a = tf.random_uniform([1])
    110   b = tf.random_normal([1])
    111 
    112   print("Session 1")
    113   with tf.Session() as sess1:
    114     print(sess1.run(a))  # generates 'A1'
    115     print(sess1.run(a))  # generates 'A2'
    116     print(sess1.run(b))  # generates 'B1'
    117     print(sess1.run(b))  # generates 'B2'
    118 
    119   print("Session 2")
    120   with tf.Session() as sess2:
    121     print(sess2.run(a))  # generates 'A3'
    122     print(sess2.run(a))  # generates 'A4'
    123     print(sess2.run(b))  # generates 'B3'
    124     print(sess2.run(b))  # generates 'B4'
    125   ```
    126 
    127   To generate the same repeatable sequence for an op across sessions, set the
    128   seed for the op:
    129 
    130   ```python
    131   a = tf.random_uniform([1], seed=1)
    132   b = tf.random_normal([1])
    133 
    134   # Repeatedly running this block with the same graph will generate the same
    135   # sequence of values for 'a', but different sequences of values for 'b'.
    136   print("Session 1")
    137   with tf.Session() as sess1:
    138     print(sess1.run(a))  # generates 'A1'
    139     print(sess1.run(a))  # generates 'A2'
    140     print(sess1.run(b))  # generates 'B1'
    141     print(sess1.run(b))  # generates 'B2'
    142 
    143   print("Session 2")
    144   with tf.Session() as sess2:
    145     print(sess2.run(a))  # generates 'A1'
    146     print(sess2.run(a))  # generates 'A2'
    147     print(sess2.run(b))  # generates 'B3'
    148     print(sess2.run(b))  # generates 'B4'
    149   ```
    150 
    151   To make the random sequences generated by all ops be repeatable across
    152   sessions, set a graph-level seed:
    153 
    154   ```python
    155   tf.set_random_seed(1234)
    156   a = tf.random_uniform([1])
    157   b = tf.random_normal([1])
    158 
    159   # Repeatedly running this block with the same graph will generate the same
    160   # sequences of 'a' and 'b'.
    161   print("Session 1")
    162   with tf.Session() as sess1:
    163     print(sess1.run(a))  # generates 'A1'
    164     print(sess1.run(a))  # generates 'A2'
    165     print(sess1.run(b))  # generates 'B1'
    166     print(sess1.run(b))  # generates 'B2'
    167 
    168   print("Session 2")
    169   with tf.Session() as sess2:
    170     print(sess2.run(a))  # generates 'A1'
    171     print(sess2.run(a))  # generates 'A2'
    172     print(sess2.run(b))  # generates 'B1'
    173     print(sess2.run(b))  # generates 'B2'
    174   ```
    175 
    176   Args:
    177     seed: integer.
    178   """
    179   if context.in_graph_mode():
    180     ops.get_default_graph().seed = seed
    181   else:
    182     context.set_global_seed(seed)
    183