// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h"

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace tensorforest {

// Names of ops to run and tensors to fetch in the graph.
constexpr char kInitializeOp[] = "init";
constexpr char kAddExampleOp[] = "add_example";
constexpr char kSplitScoreName[] = "split_score";
constexpr char kGetSplitName[] = "get_split";
constexpr char kGetLeftStatsName[] = "get_left_stats";
constexpr char kGetRightStatsName[] = "get_right_stats";

// Names of files written by the Python graph builder.
constexpr char kGraphFilename[] = "graph";
constexpr char kSaverDefFilename[] = "saver";
constexpr char kMetaDefFilename[] = "meta";

// Names of Tensor inputs.
constexpr char kFeaturesName[] = "features";
constexpr char kInputDataName[] = "input_data";
constexpr char kTargetsName[] = "targets";
constexpr char kExamplesName[] = "examples";

// Sentinel passed to RunOp when there is no op to run, only tensors to fetch.
constexpr char kNoOp[] = "none";

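// Reads the candidate's GraphDef from graph_dir, creates a session for it,
// and caches the split's oblique feature ids in an int32 tensor.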
CandidateGraphRunner::CandidateGraphRunner(
    const string& graph_dir, const decision_trees::BinaryNode& split)
    : split_(split) {
  // Read the graph from file.
  GraphDef graph_def;
  TF_CHECK_OK(ReadBinaryProto(
      Env::Default(), io::JoinPath(graph_dir, kGraphFilename), &graph_def))
      << "Could not read graph def.";

  // Create a session for the graph.
  session_.reset(::tensorflow::NewSession(SessionOptions()));
  TF_CHECK_OK(session_->Create(graph_def)) << "Failed to create session";

  // Features don't change, so store them in a tensor once. Feature ids are
  // stored as strings in the proto; parse each into int32.
  const auto& oblique = split.inequality_left_child_test().oblique();
  const int32 feat_size = oblique.features_size();
  features_.reset(new Tensor(tensorflow::DT_INT32, TensorShape({feat_size})));
  auto feat = features_->flat<int32>();
  int i = 0;
  for (const auto& id : oblique.features()) {
    safe_strto32(id.id().value(), &feat(i++));
  }
}

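// Runs the op named `name` in the candidate graph, feeding `inputs` and
// fetching `output_tensor_names` into `outputs`. When `name` is kNoOp, no op
// is run and only the requested tensors are fetched.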
void CandidateGraphRunner::RunOp(const string& name,
                                 const TensorNameValueList& inputs,
                                 const std::vector<string>& output_tensor_names,
                                 std::vector<Tensor>* outputs) {
  std::vector<string> op_name;
  if (name != kNoOp) {
    op_name.push_back(name);
  }
  TF_CHECK_OK(session_->Run(inputs, output_tensor_names, op_name, outputs))
      << "Failed to run: " << name;
}

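// Runs the graph's init op with no inputs or fetches.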
void CandidateGraphRunner::Init() {
  RunOp(kInitializeOp, TensorNameValueList(), std::vector<string>(), nullptr);
}

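// Feeds the input data, targets, and examples tensors, along with the cached
// feature ids, to the graph's add_example op.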
void CandidateGraphRunner::AddExample(const Tensor& input_data,
                                      const Tensor& target,
                                      const Tensor& examples) {
  TensorNameValueList inputs;
  inputs.emplace_back(kFeaturesName, *features_);
  inputs.emplace_back(kExamplesName, examples);
  inputs.emplace_back(kInputDataName, input_data);
  inputs.emplace_back(kTargetsName, target);

  RunOp(kAddExampleOp, inputs, std::vector<string>(), nullptr);
}

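// Fetches the split_score tensor and returns its first (scalar) float value.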
float CandidateGraphRunner::SplitScore() {
  std::vector<Tensor> outputs;
  RunOp(kNoOp, TensorNameValueList(), {kSplitScoreName}, &outputs);
  return outputs[0].unaligned_flat<float>()(0);
}

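// Fetches the serialized split from the get_split tensor, parses it into
// `node`, and copies the original split's oblique feature ids onto it.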
void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) {
  std::vector<Tensor> outputs;
  RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs);
  ParseProtoUnlimited(node, outputs[0].unaligned_flat<string>()(0));
  const auto& oblique = split_.inequality_left_child_test().oblique();
  auto* new_split =
      node->mutable_inequality_left_child_test()->mutable_oblique();
  for (const auto& id : oblique.features()) {
    *new_split->add_features() = id;
  }
}

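// Fetches the left child's per-class counts and copies them into `stats` as
// dense float counts.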
void CandidateGraphRunner::GetLeftStats(LeafStat* stats) {
  std::vector<Tensor> outputs;
  RunOp(kNoOp, TensorNameValueList(), {kGetLeftStatsName}, &outputs);
  const auto& counts = outputs[0].unaligned_flat<float>();
  auto* dense = stats->mutable_classification()->mutable_dense_counts();
  for (int i = 0; i < counts.size(); ++i) {
    dense->add_value()->set_float_value(counts(i));
  }
}

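// Fetches the right child's per-class counts and copies them into `stats` as
// dense float counts.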
void CandidateGraphRunner::GetRightStats(LeafStat* stats) {
  std::vector<Tensor> outputs;
  RunOp(kNoOp, TensorNameValueList(), {kGetRightStatsName}, &outputs);
  const auto& counts = outputs[0].unaligned_flat<float>();
  auto* dense = stats->mutable_classification()->mutable_dense_counts();
  for (int i = 0; i < counts.size(); ++i) {
    dense->add_value()->set_float_value(counts(i));
  }
}

}  // namespace tensorforest
}  // namespace tensorflow