# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Datasets for random number generators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops


class RandomDataset(dataset_ops.Dataset):
  """A `Dataset` whose elements are pseudorandom scalar `tf.int64` values."""

  def __init__(self, seed=None):
    """Builds the dataset and materializes its two seed tensors.

    Args:
      seed: (Optional.) Op-level seed. It is combined with the graph-level
        seed through `random_seed.get_seed` to obtain the pair of seeds
        passed to the underlying `RandomDataset` op. When that helper yields
        `None`, the corresponding seed tensor is the int64 constant 0.
    """
    super(RandomDataset, self).__init__()

    def to_seed_tensor(value, name):
      # A missing seed is encoded as the int64 constant 0.
      if value is None:
        return constant_op.constant(0, dtype=dtypes.int64, name=name)
      return ops.convert_to_tensor(value, dtype=dtypes.int64, name=name)

    op_seed, op_seed2 = random_seed.get_seed(seed)
    # Build "seed" before "seed2" so the graph ops are created in a fixed order.
    self._seed = to_seed_tensor(op_seed, "seed")
    self._seed2 = to_seed_tensor(op_seed2, "seed2")

  def _as_variant_tensor(self):
    """Returns the variant tensor produced by the `random_dataset` kernel."""
    element_classes = self.output_classes
    # The kernel expects flat lists of dense shapes and dtypes.
    dense_shapes = sparse.as_dense_shapes(self.output_shapes, element_classes)
    dense_types = sparse.as_dense_types(self.output_types, element_classes)
    return gen_dataset_ops.random_dataset(
        seed=self._seed,
        seed2=self._seed2,
        output_shapes=nest.flatten(dense_shapes),
        output_types=nest.flatten(dense_types))

  @property
  def output_classes(self):
    """Each element is a plain dense `tf.Tensor`."""
    return ops.Tensor

  @property
  def output_shapes(self):
    """Each element is a scalar (rank-0 shape)."""
    return tensor_shape.scalar()

  @property
  def output_types(self):
    """Each element has dtype `tf.int64`."""
    return dtypes.int64