1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ 17 #include <ctime> 18 #include <unordered_map> 19 #include "google/protobuf/any.pb.h" 20 #include "google/protobuf/wrappers.pb.h" 21 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 22 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_types.h" 25 #include "tensorflow/core/lib/random/philox_random.h" 26 #include "tensorflow/core/lib/random/simple_philox.h" 27 28 namespace tensorflow { 29 namespace tensorforest { 30 31 typedef TTypes<const float, 2>::ConstTensor DenseStorageType; 32 typedef TTypes<const int64, 2>::ConstTensor SparseIndicesStorageType; 33 typedef TTypes<const float, 1>::ConstTensor SparseValuesStorageType; 34 35 class TensorDataSet { 36 public: 37 TensorDataSet(const tensorforest::TensorForestDataSpec& input_spec, 38 int32 seed) 39 : dense_data_(nullptr), 40 sparse_indices_(nullptr), 41 sparse_values_(nullptr), 42 input_spec_(input_spec), 43 split_sampling_random_seed_(seed) { 44 int column_count = 0; 45 for (int i = 0; i < input_spec_.dense_size(); ++i) { 46 for (int j = 0; j < input_spec_.dense(i).size(); ++j) { 47 decision_trees::FeatureId id; 48 id.mutable_id()->set_value(strings::StrCat(column_count)); 49 available_features_.push_back(id); 50 ++column_count; 51 } 52 } 53 54 // Set up the random number generator. 55 if (split_sampling_random_seed_ == 0) { 56 uint64 time_seed = static_cast<uint64>(std::clock()); 57 single_rand_ = std::unique_ptr<random::PhiloxRandom>( 58 new random::PhiloxRandom(time_seed)); 59 } else { 60 single_rand_ = std::unique_ptr<random::PhiloxRandom>( 61 new random::PhiloxRandom(split_sampling_random_seed_)); 62 } 63 64 rng_ = std::unique_ptr<random::SimplePhilox>( 65 new random::SimplePhilox(single_rand_.get())); 66 } 67 virtual ~TensorDataSet() {} 68 69 void set_input_tensors(const Tensor& dense, const Tensor& sparse_indices, 70 const Tensor& sparse_values, 71 const Tensor& sparse_shape); 72 73 float get_input_value(int offset, int col) { 74 return (*dense_data_)(offset, col); 75 } 76 77 int NumItems() const { 78 if (dense_data_ != nullptr) { 79 return dense_data_->dimensions()[0]; 80 } else if (sparse_indices_ != nullptr) { 81 return sparse_batch_size_; 82 } else { 83 return 0; 84 } 85 } 86 87 // This looks up a value by example and int32_id, which is much faster than 88 // GetFeature. 89 float GetExampleValue(int example, 90 const decision_trees::FeatureId& feature_id) const; 91 92 // Same as overload with FeatureId, but if you already have the feature as 93 // an int32 you can avoid the atoi32. 94 virtual float GetExampleValue(int example, int32 feature_id) const; 95 96 int num_features() { return available_features_.size(); } 97 98 const Tensor& original_tensor() const { return original_dense_tensor_; } 99 100 bool Decide(const decision_trees::BinaryNode& node, int example) const; 101 102 // Randomly samples a feature from example, returns its id in feature_name, 103 // the value in bias, and it's type from input_spec in type. 104 void RandomSample(int example, decision_trees::FeatureId* feature_name, 105 float* bias, int* type) const; 106 107 private: 108 std::unique_ptr<DenseStorageType> dense_data_; 109 std::unique_ptr<SparseIndicesStorageType> sparse_indices_; 110 std::unique_ptr<SparseValuesStorageType> sparse_values_; 111 int sparse_batch_size_; 112 113 Tensor original_dense_tensor_; 114 const tensorforest::TensorForestDataSpec input_spec_; 115 std::vector<decision_trees::FeatureId> available_features_; 116 117 int32 split_sampling_random_seed_; 118 std::unique_ptr<random::PhiloxRandom> single_rand_; 119 std::unique_ptr<random::SimplePhilox> rng_; 120 }; 121 } // namespace tensorforest 122 } // namespace tensorflow 123 124 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ 125