# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""For seeding individual ops based on a graph-level seed.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export


# Graph-level seed substituted by `get_seed` when only an operation-level
# seed has been specified.
DEFAULT_GRAPH_SEED = 87654321
# Modulus applied by `_truncate_seed` so that seeds fit in a 32-bit integer.
_MAXINT32 = 2**31 - 1


def _truncate_seed(seed):
  """Reduces `seed` modulo `_MAXINT32` so it fits into a 32-bit integer.

  For an integer `seed`, Python's `%` with a positive modulus always yields a
  non-negative result, so the returned value lies in `[0, _MAXINT32 - 1]`
  even when `seed` is negative.
  """
  return seed % _MAXINT32  # Truncate to fit into 32-bit integer


@tf_export('get_seed')
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Given operation-specific seed, `op_seed`, this helper function returns two
  seeds derived from graph-level and op-level seeds. Many random operations
  internally use the two seeds to allow user to change the seed globally for a
  graph, or for only specific operations.

  In graph mode the global seed is read from the default graph; when eager
  execution is enabled it is read from the eager context instead.

  For details on how the graph-level seed interacts with op seeds, see
  @{tf.set_random_seed}.

  Args:
    op_seed: integer, or None if no operation-level seed was specified.

  Returns:
    A tuple of two seeds that should be used for the local seed of this
    operation. Both elements are integers, except when neither a global seed
    nor `op_seed` is set, in which case `(None, None)` is returned. The pair
    `(0, 0)` is never returned; it is replaced by `(0, _MAXINT32)`.
  """
  is_graph_mode = context.in_graph_mode()
  # Look the default graph up once; it is reused below to derive an op seed.
  graph = ops.get_default_graph() if is_graph_mode else None

  if is_graph_mode:
    global_seed = graph.seed
  else:
    global_seed = context.global_seed()

  if global_seed is not None:
    if op_seed is None:
      if is_graph_mode:
        # Derive the op seed from the graph's most recent op id (presumably
        # distinct for each op created in the graph -- confirm in ops.Graph).
        op_seed = graph._last_id  # pylint: disable=protected-access
      else:
        op_seed = context.internal_operation_seed()

    seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
  else:
    if op_seed is not None:
      seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
    else:
      seeds = None, None
  # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
  # be unexpected since Python docs say nondeterminism is (None, None).
  if seeds == (0, 0):
    return (0, _MAXINT32)
  return seeds


@tf_export('set_random_seed')
def set_random_seed(seed):
  """Sets the graph-level random seed.

  Operations that rely on a random seed actually derive it from two seeds:
  the graph-level and operation-level seeds. This sets the graph-level seed.
  In graph mode the seed is stored on the default graph; when eager execution
  is enabled it is stored on the eager context instead.

  Its interactions with operation-level seeds are as follows:

    1. If neither the graph-level nor the operation seed is set:
      A random seed is used for this op.
    2. If the graph-level seed is set, but the operation seed is not:
      The system deterministically picks an operation seed in conjunction
      with the graph-level seed so that it gets a unique random sequence.
    3. If the graph-level seed is not set, but the operation seed is set:
      A default graph-level seed and the specified operation seed are used to
      determine the random sequence.
    4. If both the graph-level and the operation seed are set:
      Both seeds are used in conjunction to determine the random sequence.

  To illustrate the user-visible effects, consider these examples:

  To generate different sequences across sessions, set neither
  graph-level nor op-level seeds:

  ```python
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A3'
    print(sess2.run(a))  # generates 'A4'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To generate the same repeatable sequence for an op across sessions, set the
  seed for the op:

  ```python
  a = tf.random_uniform([1], seed=1)
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequence of values for 'a', but different sequences of values for 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To make the random sequences generated by all ops be repeatable across
  sessions, set a graph-level seed:

  ```python
  tf.set_random_seed(1234)
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequences of 'a' and 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B1'
    print(sess2.run(b))  # generates 'B2'
  ```

  Args:
    seed: integer.
  """
  if context.in_graph_mode():
    ops.get_default_graph().seed = seed
  else:
    context.set_global_seed(seed)