// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace tensorforest {
namespace {

// Returns true if 'value' passes the inequality test against the test's
// float threshold.
bool DecideInequalityTest(const decision_trees::InequalityTest& test,
                          float value) {
  float bias = test.threshold().float_value();
  switch (test.type()) {
    case decision_trees::InequalityTest::LESS_OR_EQUAL:
      return value <= bias;

    case decision_trees::InequalityTest::LESS_THAN:
      return value < bias;

    case decision_trees::InequalityTest::GREATER_OR_EQUAL:
      return value >= bias;

    case decision_trees::InequalityTest::GREATER_THAN:
      return value > bias;

    default:
      return false;
  }
}

// Returns true if 'value' exactly matches one of the test's values.
bool DecideMatchingValuesTest(const decision_trees::MatchingValuesTest& test,
                              float value) {
  for (const decision_trees::Value& test_value : test.value()) {
    if (test_value.float_value() == value) {
      return true;
    }
  }
  return false;
}

}  // namespace

bool TensorDataSet::Decide(const decision_trees::BinaryNode& node,
                           int example) const {
  // TODO(gilberth): Support missing values.
  float val = 0;
  const auto& test = node.inequality_left_child_test();

  if (test.has_oblique()) {
    for (int i = 0; i < test.oblique().features_size(); ++i) {
      val += test.oblique().weights(i) *
             GetExampleValue(example, test.oblique().features(i));
    }
  } else {
    val = GetExampleValue(example, test.feature_id());
  }

  if (node.has_inequality_left_child_test()) {
    return DecideInequalityTest(node.inequality_left_child_test(), val);
  } else {
    decision_trees::MatchingValuesTest test;
    if (node.custom_left_child_test().UnpackTo(&test)) {
      return DecideMatchingValuesTest(test, val);
    } else {
      return false;
    }
  }
}

// Feature ids at or above the number of dense features refer to sparse
// features.
float TensorDataSet::GetExampleValue(
    int example, const decision_trees::FeatureId& feature_id) const {
  int32 feature;
  safe_strto32(feature_id.id().value(), &feature);
  if (feature >= input_spec_.dense_features_size()) {
    return FindSparseValue(*sparse_indices_, *sparse_values_, example, feature);
  } else {
    return (*dense_data_)(example, feature);
  }
}

float TensorDataSet::GetExampleValue(int example, int32 feature_id) const {
  if (feature_id >= input_spec_.dense_features_size()) {
    return FindSparseValue(*sparse_indices_, *sparse_values_, example,
                           feature_id);
  } else {
    return (*dense_data_)(example, feature_id);
  }
}

void TensorDataSet::set_input_tensors(const Tensor& dense,
                                      const Tensor& sparse_indices,
                                      const Tensor& sparse_values,
                                      const Tensor& sparse_shape) {
  if (dense.shape().dims() == 2) {
    dense_data_.reset(new DenseStorageType(dense.tensor<float, 2>()));
  }
  if (sparse_indices.shape().dims() == 2) {
    sparse_indices_.reset(
        new SparseIndicesStorageType(sparse_indices.tensor<int64, 2>()));
    sparse_values_.reset(
        new SparseValuesStorageType(sparse_values.tensor<float, 1>()));
    sparse_batch_size_ = sparse_shape.tensor<int64, 1>()(0);
  }
  original_dense_tensor_ = dense;
}

void TensorDataSet::RandomSample(int example,
                                 decision_trees::FeatureId* feature_id,
                                 float* bias, int* type) const {
  int32 num_total_features = input_spec_.dense_features_size();
  int64 sparse_input_start;
  if (sparse_indices_ != nullptr) {
    const int32 num_sparse = tensorforest::GetNumSparseFeatures(
        *sparse_indices_, example, &sparse_input_start);
    if (sparse_input_start >= 0) {
      num_total_features += num_sparse;
    }
  }
  int rand_feature = rng_->Uniform(num_total_features);
  if (rand_feature < available_features_.size()) {  // it's dense.
    *feature_id = available_features_[rand_feature];
    *type = input_spec_.GetDenseFeatureType(rand_feature);
  } else {
    const int32 sparse_index =
        sparse_input_start + rand_feature - input_spec_.dense_features_size();
    const int32 saved_index = (*sparse_indices_)(sparse_index, 1) +
                              input_spec_.dense_features_size();
    *feature_id = decision_trees::FeatureId();
    feature_id->mutable_id()->set_value(strings::StrCat(saved_index));

    // TODO(gilberth): Remove this shortcut when different sparse types are
    // allowed.
    *type = input_spec_.sparse(0).original_type();
  }

  *bias = GetExampleValue(example, *feature_id);
}

}  // namespace tensorforest
}  // namespace tensorflow