Home | History | Annotate | Download | only in v4
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
     16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
     17 #include <ctime>
     18 #include <unordered_map>
     19 #include "google/protobuf/any.pb.h"
     20 #include "google/protobuf/wrappers.pb.h"
     21 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
     22 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/framework/tensor_types.h"
     25 #include "tensorflow/core/lib/random/philox_random.h"
     26 #include "tensorflow/core/lib/random/simple_philox.h"
     27 
     28 namespace tensorflow {
     29 namespace tensorforest {
     30 
     31 typedef TTypes<const float, 2>::ConstTensor DenseStorageType;
     32 typedef TTypes<const int64, 2>::ConstTensor SparseIndicesStorageType;
     33 typedef TTypes<const float, 1>::ConstTensor SparseValuesStorageType;
     34 
     35 class TensorDataSet {
     36  public:
     37   TensorDataSet(const tensorforest::TensorForestDataSpec& input_spec,
     38                 int32 seed)
     39       : dense_data_(nullptr),
     40         sparse_indices_(nullptr),
     41         sparse_values_(nullptr),
     42         input_spec_(input_spec),
     43         split_sampling_random_seed_(seed) {
     44     int column_count = 0;
     45     for (int i = 0; i < input_spec_.dense_size(); ++i) {
     46       for (int j = 0; j < input_spec_.dense(i).size(); ++j) {
     47         decision_trees::FeatureId id;
     48         id.mutable_id()->set_value(strings::StrCat(column_count));
     49         available_features_.push_back(id);
     50         ++column_count;
     51       }
     52     }
     53 
     54     // Set up the random number generator.
     55     if (split_sampling_random_seed_ == 0) {
     56       uint64 time_seed = static_cast<uint64>(std::clock());
     57       single_rand_ = std::unique_ptr<random::PhiloxRandom>(
     58           new random::PhiloxRandom(time_seed));
     59     } else {
     60       single_rand_ = std::unique_ptr<random::PhiloxRandom>(
     61           new random::PhiloxRandom(split_sampling_random_seed_));
     62     }
     63 
     64     rng_ = std::unique_ptr<random::SimplePhilox>(
     65         new random::SimplePhilox(single_rand_.get()));
     66   }
     67   virtual ~TensorDataSet() {}
     68 
     69   void set_input_tensors(const Tensor& dense, const Tensor& sparse_indices,
     70                          const Tensor& sparse_values,
     71                          const Tensor& sparse_shape);
     72 
     73   float get_input_value(int offset, int col) {
     74     return (*dense_data_)(offset, col);
     75   }
     76 
     77   int NumItems() const {
     78     if (dense_data_ != nullptr) {
     79       return dense_data_->dimensions()[0];
     80     } else if (sparse_indices_ != nullptr) {
     81       return sparse_batch_size_;
     82     } else {
     83       return 0;
     84     }
     85   }
     86 
     87   // This looks up a value by example and int32_id, which is much faster than
     88   // GetFeature.
     89   float GetExampleValue(int example,
     90                         const decision_trees::FeatureId& feature_id) const;
     91 
     92   // Same as overload with FeatureId, but if you already have the feature as
     93   // an int32 you can avoid the atoi32.
     94   virtual float GetExampleValue(int example, int32 feature_id) const;
     95 
     96   int num_features() { return available_features_.size(); }
     97 
     98   const Tensor& original_tensor() const { return original_dense_tensor_; }
     99 
    100   bool Decide(const decision_trees::BinaryNode& node, int example) const;
    101 
    102   // Randomly samples a feature from example, returns its id in feature_name,
    103   // the value in bias, and it's type from input_spec in type.
    104   void RandomSample(int example, decision_trees::FeatureId* feature_name,
    105                     float* bias, int* type) const;
    106 
    107  private:
    108   std::unique_ptr<DenseStorageType> dense_data_;
    109   std::unique_ptr<SparseIndicesStorageType> sparse_indices_;
    110   std::unique_ptr<SparseValuesStorageType> sparse_values_;
    111   int sparse_batch_size_;
    112 
    113   Tensor original_dense_tensor_;
    114   const tensorforest::TensorForestDataSpec input_spec_;
    115   std::vector<decision_trees::FeatureId> available_features_;
    116 
    117   int32 split_sampling_random_seed_;
    118   std::unique_ptr<random::PhiloxRandom> single_rand_;
    119   std::unique_ptr<random::SimplePhilox> rng_;
    120 };
    121 }  // namespace tensorforest
    122 }  // namespace tensorflow
    123 
    124 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
    125