// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace tensorflow {
namespace tensorforest {
namespace {

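// Returns the result of comparing |value| against the test's float threshold
// according to the test type. Unknown test types return false.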
bool DecideInequalityTest(const decision_trees::InequalityTest& test,
                          float value) {
  float bias = test.threshold().float_value();
  switch (test.type()) {
    case decision_trees::InequalityTest::LESS_OR_EQUAL:
      return value <= bias;

    case decision_trees::InequalityTest::LESS_THAN:
      return value < bias;

    case decision_trees::InequalityTest::GREATER_OR_EQUAL:
      return value >= bias;

    case decision_trees::InequalityTest::GREATER_THAN:
      return value > bias;

    default:
      return false;
  }
}

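// Returns true if |value| exactly matches any of the float values listed in
// the MatchingValuesTest proto.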
bool DecideMatchingValuesTest(const decision_trees::MatchingValuesTest& test,
                              float value) {
  for (const decision_trees::Value& test_value : test.value()) {
    if (test_value.float_value() == value) {
      return true;
    }
  }
  return false;
}

}  // namespace

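// Evaluates the node's left-child test (an inequality test or a custom
// MatchingValuesTest) on the given example.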
bool TensorDataSet::Decide(const decision_trees::BinaryNode& node,
                           int example) const {
  // TODO(gilberth): Support missing values.
  float val = 0;
  const auto& test = node.inequality_left_child_test();

  if (test.has_oblique()) {
    // Oblique split: the test value is a weighted sum over several features.
    for (int i = 0; i < test.oblique().features_size(); ++i) {
      val += test.oblique().weights(i) *
             GetExampleValue(example, test.oblique().features(i));
    }
  } else {
    // Axis-aligned split: the test value is a single feature.
    val = GetExampleValue(example, test.feature_id());
  }

  if (node.has_inequality_left_child_test()) {
    return DecideInequalityTest(node.inequality_left_child_test(), val);
  } else {
    // Fall back to the custom test; only MatchingValuesTest is supported.
    decision_trees::MatchingValuesTest test;
    if (node.custom_left_child_test().UnpackTo(&test)) {
      return DecideMatchingValuesTest(test, val);
    } else {
      return false;
    }
  }
}

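// Parses the feature id's string value into an int32 index and returns that
// feature's value for |example|, reading from dense or sparse storage as
// appropriate.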
float TensorDataSet::GetExampleValue(
    int example, const decision_trees::FeatureId& feature_id) const {
  int32 feature;
  safe_strto32(feature_id.id().value(), &feature);
  if (feature >= input_spec_.dense_features_size()) {
    return FindSparseValue(*sparse_indices_, *sparse_values_, example, feature);
  } else {
    return (*dense_data_)(example, feature);
  }
}

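// As above, but with the feature index already given as an int32.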
float TensorDataSet::GetExampleValue(int example, int32 feature_id) const {
  if (feature_id >= input_spec_.dense_features_size()) {
    return FindSparseValue(*sparse_indices_, *sparse_values_, example,
                           feature_id);
  } else {
    return (*dense_data_)(example, feature_id);
  }
}

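// Points the dataset at a new batch of input tensors: a 2-D dense float
// tensor plus a sparse (indices, values, shape) triple.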
void TensorDataSet::set_input_tensors(const Tensor& dense,
                                      const Tensor& sparse_indices,
                                      const Tensor& sparse_values,
                                      const Tensor& sparse_shape) {
  if (dense.shape().dims() == 2) {
    dense_data_.reset(new DenseStorageType(dense.tensor<float, 2>()));
  }
  if (sparse_indices.shape().dims() == 2) {
    sparse_indices_.reset(
        new SparseIndicesStorageType(sparse_indices.tensor<int64, 2>()));
    sparse_values_.reset(
        new SparseValuesStorageType(sparse_values.tensor<float, 1>()));
    sparse_batch_size_ = sparse_shape.tensor<int64, 1>()(0);
  }
  original_dense_tensor_ = dense;
}

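// Chooses a feature uniformly at random from the example's dense and sparse
// features and returns its id, its value for |example| in *bias, and its
// data type.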
void TensorDataSet::RandomSample(int example,
                                 decision_trees::FeatureId* feature_id,
                                 float* bias, int* type) const {
  int32 num_total_features = input_spec_.dense_features_size();
  int64 sparse_input_start;
  if (sparse_indices_ != nullptr) {
    const int32 num_sparse = tensorforest::GetNumSparseFeatures(
        *sparse_indices_, example, &sparse_input_start);
    if (sparse_input_start >= 0) {
      num_total_features += num_sparse;
    }
  }
  int rand_feature = rng_->Uniform(num_total_features);
  if (rand_feature < available_features_.size()) {  // it's dense.
    *feature_id = available_features_[rand_feature];
    *type = input_spec_.GetDenseFeatureType(rand_feature);
  } else {
    // It's sparse: map the random draw into this example's sparse entries and
    // recover the original feature index, offset past the dense features.
    const int32 sparse_index =
        sparse_input_start + rand_feature - input_spec_.dense_features_size();
    const int32 saved_index =
        (*sparse_indices_)(sparse_index, 1) + input_spec_.dense_features_size();
    *feature_id = decision_trees::FeatureId();
    feature_id->mutable_id()->set_value(strings::StrCat(saved_index));

    // TODO(gilberth): Remove this shortcut when different sparse types are
    // allowed.
    *type = input_spec_.sparse(0).original_type();
  }

  *bias = GetExampleValue(example, *feature_id);
}

}  // namespace tensorforest
}  // namespace tensorflow