Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Datasets for random number generators."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import nest
     22 from tensorflow.python.data.util import sparse
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.framework import random_seed
     27 from tensorflow.python.framework import tensor_shape
     28 from tensorflow.python.ops import gen_dataset_ops
     29 
     30 
     31 class RandomDataset(dataset_ops.Dataset):
     32   """A `Dataset` of pseudorandom values."""
     33 
     34   def __init__(self, seed=None):
     35     """A `Dataset` of pseudorandom values."""
     36     super(RandomDataset, self).__init__()
     37     seed, seed2 = random_seed.get_seed(seed)
     38     if seed is None:
     39       self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed")
     40     else:
     41       self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed")
     42     if seed2 is None:
     43       self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2")
     44     else:
     45       self._seed2 = ops.convert_to_tensor(
     46           seed2, dtype=dtypes.int64, name="seed2")
     47 
     48   def _as_variant_tensor(self):
     49     return gen_dataset_ops.random_dataset(
     50         seed=self._seed,
     51         seed2=self._seed2,
     52         output_shapes=nest.flatten(
     53             sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
     54         output_types=nest.flatten(
     55             sparse.as_dense_types(self.output_types, self.output_classes)))
     56 
     57   @property
     58   def output_classes(self):
     59     return ops.Tensor
     60 
     61   @property
     62   def output_shapes(self):
     63     return tensor_shape.scalar()
     64 
     65   @property
     66   def output_types(self):
     67     return dtypes.int64
     68