Home | History | Annotate | Download | only in v4
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
     16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/public/session.h"
     26 
     27 namespace tensorflow {
     28 namespace tensorforest {
     29 
     30 typedef std::vector<std::pair<string, ::tensorflow::Tensor>>
     31     TensorNameValueList;
     32 
     33 // Class that represents one split candidate, and can perform operations
     34 // on a session created from a graph.
     35 class CandidateGraphRunner {
     36  public:
     37   // split should contain the features that are being used.
     38   CandidateGraphRunner(const string& graph_dir,
     39                        const decision_trees::BinaryNode& split);
     40 
     41   // Input the given data and target Tensors to the add_example op.
     42   void AddExample(const Tensor& input_data, const Tensor& target,
     43                   const Tensor& examples);
     44 
     45   // Get the candidates' split score with the split_score op.
     46   float SplitScore();
     47 
     48   // Fills in the split in node with weights and threshold.
     49   void GetSplit(decision_trees::BinaryNode* node);
     50 
     51   // Fills in the stats for the left-branch taken.
     52   void GetLeftStats(LeafStat* stats);
     53 
     54   // Fills in the stats for the right-branch taken.
     55   void GetRightStats(LeafStat* stats);
     56 
     57   // Initializes variables, must be run before other ops.
     58   void Init();
     59 
     60  protected:
     61   void RunOp(const string& name, const TensorNameValueList& inputs,
     62              const std::vector<string>& output_tensor_names,
     63              std::vector<Tensor>* outputs);
     64 
     65   std::unique_ptr<Session> session_;
     66   decision_trees::BinaryNode split_;
     67   std::unique_ptr<Tensor> features_;
     68 };
     69 
     70 }  // namespace tensorforest
     71 }  // namespace tensorflow
     72 
     73 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
     74