1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 // This is a surrogate for using a proto, since it doesn't seem to be possible 16 // to use protos in a dynamically-loaded/shared-linkage library, which is 17 // what is used for custom ops in tensorflow/contrib. 18 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ 19 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ 20 #include <unordered_map> 21 22 #include "tensorflow/core/lib/strings/numbers.h" 23 #include "tensorflow/core/lib/strings/str_util.h" 24 25 namespace tensorflow { 26 namespace tensorforest { 27 28 using tensorflow::strings::safe_strto32; 29 30 // DataColumn holds information about one feature of the original data. 31 // A feature could be dense or sparse, and be of any size. 32 class DataColumn { 33 public: 34 DataColumn() {} 35 36 // Parses a serialized DataColumn produced from the SerializeToString() 37 // function of a python data_ops.DataColumn object. 38 // It should look like a proto ASCII format, i.e. 39 // name: <name> original_type: <type> size: <size> 40 void ParseFromString(const string& serialized) { 41 std::vector<string> tokens = tensorflow::str_util::Split(serialized, ' '); 42 CHECK_EQ(tokens.size(), 6); 43 name_ = tokens[1]; 44 safe_strto32(tokens[3], &original_type_); 45 safe_strto32(tokens[5], &size_); 46 } 47 48 const string& name() const { return name_; } 49 50 int original_type() const { return original_type_; } 51 52 int size() const { return size_; } 53 54 void set_name(const string& n) { name_ = n; } 55 56 void set_original_type(int o) { original_type_ = o; } 57 58 void set_size(int s) { size_ = s; } 59 60 private: 61 string name_; 62 int original_type_; 63 int size_; 64 }; 65 66 // TensorForestDataSpec holds information about the original features of the 67 // data set, which were flattened to a single dense float tensor and/or a 68 // single sparse float tensor. 69 class TensorForestDataSpec { 70 public: 71 TensorForestDataSpec() {} 72 73 // Parses a serialized DataColumn produced from the SerializeToString() 74 // function of a python data_ops.TensorForestDataSpec object. 75 // It should look something like: 76 // dense_features_size: <size> dense: [{<col1>}{<col2>}] sparse: [{<col3>}] 77 void ParseFromString(const string& serialized) { 78 std::vector<string> tokens = tensorflow::str_util::Split(serialized, "[]"); 79 std::vector<string> first_part = 80 tensorflow::str_util::Split(tokens[0], ' '); 81 safe_strto32(first_part[1], &dense_features_size_); 82 ParseColumns(tokens[1], &dense_); 83 ParseColumns(tokens[3], &sparse_); 84 85 int total = 0; 86 for (const DataColumn& col : dense_) { 87 for (int i = 0; i < col.size(); ++i) { 88 feature_to_type_.push_back(col.original_type()); 89 ++total; 90 } 91 } 92 } 93 94 const DataColumn& dense(int i) const { return dense_.at(i); } 95 96 const DataColumn& sparse(int i) const { return sparse_.at(i); } 97 98 DataColumn* mutable_sparse(int i) { return &sparse_[i]; } 99 100 int dense_size() const { return dense_.size(); } 101 102 int sparse_size() const { return sparse_.size(); } 103 104 int dense_features_size() const { return dense_features_size_; } 105 106 void set_dense_features_size(int s) { dense_features_size_ = s; } 107 108 DataColumn* add_dense() { 109 dense_.push_back(DataColumn()); 110 return &dense_[dense_.size() - 1]; 111 } 112 113 int GetDenseFeatureType(int feature) const { 114 return feature_to_type_[feature]; 115 } 116 117 private: 118 void ParseColumns(const string& cols, std::vector<DataColumn>* vec) { 119 std::vector<string> tokens = tensorflow::str_util::Split(cols, "{}"); 120 for (const string& tok : tokens) { 121 if (!tok.empty()) { 122 DataColumn col; 123 col.ParseFromString(tok); 124 vec->push_back(col); 125 } 126 } 127 } 128 129 std::vector<DataColumn> dense_; 130 std::vector<DataColumn> sparse_; 131 int dense_features_size_; 132 133 // This map tracks features in the total dense feature space to their 134 // original type for fast lookup. 135 std::vector<int> feature_to_type_; 136 }; 137 138 } // namespace tensorforest 139 } // namespace tensorflow 140 141 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_ 142