// Home | History | Annotate | Download | only in kernels
      1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 // This is a surrogate for using a proto, since it doesn't seem to be possible
     16 // to use protos in a dynamically-loaded/shared-linkage library, which is
     17 // what is used for custom ops in tensorflow/contrib.
     18 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
     19 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
     20 #include <unordered_map>
     21 
     22 #include "tensorflow/core/lib/strings/numbers.h"
     23 #include "tensorflow/core/lib/strings/str_util.h"
     24 
     25 namespace tensorflow {
     26 namespace tensorforest {
     27 
     28 using tensorflow::strings::safe_strto32;
     29 
     30 // DataColumn holds information about one feature of the original data.
     31 // A feature could be dense or sparse, and be of any size.
     32 class DataColumn {
     33  public:
     34   DataColumn() {}
     35 
     36   // Parses a serialized DataColumn produced from the SerializeToString()
     37   // function of a python data_ops.DataColumn object.
     38   // It should look like a proto ASCII format, i.e.
     39   // name: <name> original_type: <type> size: <size>
     40   void ParseFromString(const string& serialized) {
     41     std::vector<string> tokens = tensorflow::str_util::Split(serialized, ' ');
     42     CHECK_EQ(tokens.size(), 6);
     43     name_ = tokens[1];
     44     safe_strto32(tokens[3], &original_type_);
     45     safe_strto32(tokens[5], &size_);
     46   }
     47 
     48   const string& name() const { return name_; }
     49 
     50   int original_type() const { return original_type_; }
     51 
     52   int size() const { return size_; }
     53 
     54   void set_name(const string& n) { name_ = n; }
     55 
     56   void set_original_type(int o) { original_type_ = o; }
     57 
     58   void set_size(int s) { size_ = s; }
     59 
     60  private:
     61   string name_;
     62   int original_type_;
     63   int size_;
     64 };
     65 
     66 // TensorForestDataSpec holds information about the original features of the
     67 // data set, which were flattened to a single dense float tensor and/or a
     68 // single sparse float tensor.
     69 class TensorForestDataSpec {
     70  public:
     71   TensorForestDataSpec() {}
     72 
     73   // Parses a serialized DataColumn produced from the SerializeToString()
     74   // function of a python data_ops.TensorForestDataSpec object.
     75   // It should look something like:
     76   // dense_features_size: <size> dense: [{<col1>}{<col2>}] sparse: [{<col3>}]
     77   void ParseFromString(const string& serialized) {
     78     std::vector<string> tokens = tensorflow::str_util::Split(serialized, "[]");
     79     std::vector<string> first_part =
     80         tensorflow::str_util::Split(tokens[0], ' ');
     81     safe_strto32(first_part[1], &dense_features_size_);
     82     ParseColumns(tokens[1], &dense_);
     83     ParseColumns(tokens[3], &sparse_);
     84 
     85     int total = 0;
     86     for (const DataColumn& col : dense_) {
     87       for (int i = 0; i < col.size(); ++i) {
     88         feature_to_type_.push_back(col.original_type());
     89         ++total;
     90       }
     91     }
     92   }
     93 
     94   const DataColumn& dense(int i) const { return dense_.at(i); }
     95 
     96   const DataColumn& sparse(int i) const { return sparse_.at(i); }
     97 
     98   DataColumn* mutable_sparse(int i) { return &sparse_[i]; }
     99 
    100   int dense_size() const { return dense_.size(); }
    101 
    102   int sparse_size() const { return sparse_.size(); }
    103 
    104   int dense_features_size() const { return dense_features_size_; }
    105 
    106   void set_dense_features_size(int s) { dense_features_size_ = s; }
    107 
    108   DataColumn* add_dense() {
    109     dense_.push_back(DataColumn());
    110     return &dense_[dense_.size() - 1];
    111   }
    112 
    113   int GetDenseFeatureType(int feature) const {
    114     return feature_to_type_[feature];
    115   }
    116 
    117  private:
    118   void ParseColumns(const string& cols, std::vector<DataColumn>* vec) {
    119     std::vector<string> tokens = tensorflow::str_util::Split(cols, "{}");
    120     for (const string& tok : tokens) {
    121       if (!tok.empty()) {
    122         DataColumn col;
    123         col.ParseFromString(tok);
    124         vec->push_back(col);
    125       }
    126     }
    127   }
    128 
    129   std::vector<DataColumn> dense_;
    130   std::vector<DataColumn> sparse_;
    131   int dense_features_size_;
    132 
    133   // This map tracks features in the total dense feature space to their
    134   // original type for fast lookup.
    135   std::vector<int> feature_to_type_;
    136 };
    137 
    138 }  // namespace tensorforest
    139 }  // namespace tensorflow
    140 
    141 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
    142