1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" 16 17 #include "tensorflow/core/framework/graph.pb.h" 18 #include "tensorflow/core/lib/io/path.h" 19 #include "tensorflow/core/platform/env.h" 20 21 namespace tensorflow { 22 namespace tensorforest { 23 24 // Names of ops in the graph to run. 25 constexpr char kInitializeOp[] = "init"; 26 constexpr char kAddExampleOp[] = "add_example"; 27 constexpr char kSplitScoreName[] = "split_score"; 28 constexpr char kGetSplitName[] = "get_split"; 29 constexpr char kGetLeftStatsName[] = "get_left_stats"; 30 constexpr char kGetRightStatsName[] = "get_right_stats"; 31 32 // Names of files written by python graph builder. 33 constexpr char kGraphFilename[] = "graph"; 34 constexpr char kSaverDefFilename[] = "saver"; 35 constexpr char kMetaDefFilename[] = "meta"; 36 37 // Names of Tensor inputs. 38 constexpr char kFeaturesName[] = "features"; 39 constexpr char kInputDataName[] = "input_data"; 40 constexpr char kTargetsName[] = "targets"; 41 constexpr char kExamplesName[] = "examples"; 42 43 constexpr char kNoOp[] = "none"; 44 45 CandidateGraphRunner::CandidateGraphRunner( 46 const string& graph_dir, const decision_trees::BinaryNode& split) 47 : split_(split) { 48 // read graph from file. 49 GraphDef graph_def; 50 TF_CHECK_OK(ReadBinaryProto( 51 Env::Default(), io::JoinPath(graph_dir, kGraphFilename), &graph_def)) 52 << "Could not read graph def."; 53 54 // create session. 55 session_.reset(::tensorflow::NewSession(SessionOptions())); 56 TF_CHECK_OK(session_->Create(graph_def)) << "Failed to create session"; 57 58 // Features don't change, store them in a tensor. 59 const auto& oblique = split.inequality_left_child_test().oblique(); 60 const int32 feat_size = oblique.features_size(); 61 features_.reset(new Tensor(tensorflow::DT_INT32, TensorShape({feat_size}))); 62 auto feat = features_->flat<int32>(); 63 int i = 0; 64 for (const auto& id : oblique.features()) { 65 safe_strto32(id.id().value(), &feat(i++)); 66 } 67 } 68 69 void CandidateGraphRunner::RunOp(const string& name, 70 const TensorNameValueList& inputs, 71 const std::vector<string>& output_tensor_names, 72 std::vector<Tensor>* outputs) { 73 std::vector<string> op_name; 74 if (name != kNoOp) { 75 op_name.push_back(name); 76 } 77 TF_CHECK_OK(session_->Run(inputs, output_tensor_names, op_name, outputs)) 78 << "Failed to run: " << name; 79 } 80 81 void CandidateGraphRunner::Init() { 82 RunOp(kInitializeOp, TensorNameValueList(), std::vector<string>(), nullptr); 83 } 84 85 void CandidateGraphRunner::AddExample(const Tensor& input_data, 86 const Tensor& target, 87 const Tensor& examples) { 88 TensorNameValueList inputs; 89 inputs.emplace_back(kFeaturesName, *features_); 90 inputs.emplace_back(kExamplesName, examples); 91 inputs.emplace_back(kInputDataName, input_data); 92 inputs.emplace_back(kTargetsName, target); 93 94 RunOp(kAddExampleOp, inputs, std::vector<string>(), nullptr); 95 } 96 97 float CandidateGraphRunner::SplitScore() { 98 std::vector<Tensor> outputs; 99 RunOp(kNoOp, TensorNameValueList(), {kSplitScoreName}, &outputs); 100 return outputs[0].unaligned_flat<float>()(0); 101 } 102 103 void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) { 104 std::vector<Tensor> outputs; 105 RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs); 106 ParseProtoUnlimited(node, outputs[0].unaligned_flat<string>()(0)); 107 const auto& oblique = split_.inequality_left_child_test().oblique(); 108 auto* new_split = 109 node->mutable_inequality_left_child_test()->mutable_oblique(); 110 for (const auto& id : oblique.features()) { 111 *new_split->add_features() = id; 112 } 113 } 114 115 void CandidateGraphRunner::GetLeftStats(LeafStat* stats) { 116 std::vector<Tensor> outputs; 117 RunOp(kNoOp, TensorNameValueList(), {kGetLeftStatsName}, &outputs); 118 const auto& counts = outputs[0].unaligned_flat<float>(); 119 auto* dense = stats->mutable_classification()->mutable_dense_counts(); 120 for (int i = 0; i < counts.size(); ++i) { 121 dense->add_value()->set_float_value(counts(i)); 122 } 123 } 124 125 void CandidateGraphRunner::GetRightStats(LeafStat* stats) { 126 std::vector<Tensor> outputs; 127 RunOp(kNoOp, TensorNameValueList(), {kGetRightStatsName}, &outputs); 128 const auto& counts = outputs[0].unaligned_flat<float>(); 129 auto* dense = stats->mutable_classification()->mutable_dense_counts(); 130 for (int i = 0; i < counts.size(); ++i) { 131 dense->add_value()->set_float_value(counts(i)); 132 } 133 } 134 135 } // namespace tensorforest 136 } // namespace tensorflow 137