1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/platform/tensor_coding.h" 17 18 #include <vector> 19 20 #include "tensorflow/core/lib/core/coding.h" 21 #include "tensorflow/core/lib/core/stringpiece.h" 22 #include "tensorflow/core/lib/strings/strcat.h" 23 #include "tensorflow/core/platform/protobuf.h" 24 25 #if defined(TENSORFLOW_PROTOBUF_USES_CORD) 26 #include "strings/cord_varint.h" 27 #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) 28 29 namespace tensorflow { 30 namespace port { 31 32 void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) { 33 out->assign(src.data(), src.size()); 34 } 35 36 void EncodeStringList(const string* strings, int64 n, string* out) { 37 out->clear(); 38 for (int i = 0; i < n; ++i) { 39 core::PutVarint32(out, strings[i].size()); 40 } 41 for (int i = 0; i < n; ++i) { 42 out->append(strings[i]); 43 } 44 } 45 46 bool DecodeStringList(const string& src, string* strings, int64 n) { 47 std::vector<uint32> sizes(n); 48 StringPiece reader(src); 49 int64 tot = 0; 50 for (auto& v : sizes) { 51 if (!core::GetVarint32(&reader, &v)) return false; 52 tot += v; 53 } 54 if (tot != static_cast<int64>(reader.size())) { 55 return false; 56 } 57 58 string* data = strings; 59 for (int64 i = 0; i < n; ++i, ++data) { 60 auto size = sizes[i]; 61 if (size > reader.size()) { 62 return false; 63 } 64 data->assign(reader.data(), size); 65 reader.remove_prefix(size); 66 } 67 68 return true; 69 } 70 71 void CopyFromArray(string* s, const char* base, size_t bytes) { 72 s->assign(base, bytes); 73 } 74 75 class StringListEncoderImpl : public StringListEncoder { 76 public: 77 explicit StringListEncoderImpl(string* out) : out_(out) {} 78 ~StringListEncoderImpl() override = default; 79 80 void Append(const protobuf::MessageLite& m) override { 81 core::PutVarint32(out_, m.ByteSizeLong()); 82 tensorflow::string serialized_message; 83 m.AppendToString(&serialized_message); 84 strings::StrAppend(&rest_, serialized_message); 85 } 86 87 void Append(const string& s) override { 88 core::PutVarint32(out_, s.length()); 89 strings::StrAppend(&rest_, s); 90 } 91 92 void Finalize() override { strings::StrAppend(out_, rest_); } 93 94 private: 95 string* out_; 96 string rest_; 97 }; 98 99 class StringListDecoderImpl : public StringListDecoder { 100 public: 101 explicit StringListDecoderImpl(const string& in) : reader_(in) {} 102 ~StringListDecoderImpl() override = default; 103 104 bool ReadSizes(std::vector<uint32>* sizes) override { 105 int64 total = 0; 106 for (auto& size : *sizes) { 107 if (!core::GetVarint32(&reader_, &size)) return false; 108 total += size; 109 } 110 if (total != static_cast<int64>(reader_.size())) { 111 return false; 112 } 113 return true; 114 } 115 116 const char* Data(uint32 size) override { 117 const char* data = reader_.data(); 118 reader_.remove_prefix(size); 119 return data; 120 } 121 122 private: 123 StringPiece reader_; 124 }; 125 126 std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) { 127 return std::unique_ptr<StringListEncoder>(new StringListEncoderImpl(out)); 128 } 129 130 std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) { 131 return std::unique_ptr<StringListDecoder>(new StringListDecoderImpl(in)); 132 } 133 134 #if defined(TENSORFLOW_PROTOBUF_USES_CORD) 135 void AssignRefCounted(StringPiece src, core::RefCounted* obj, Cord* out) { 136 obj->Ref(); 137 out->Clear(); 138 // Defines a lambda to unref "obj" when Cord deletes this piece of 139 // memory. +[] converts the lambda to a C style function pointer. 140 auto cleanup = +[](absl::string_view donotcare, void* obj) { 141 reinterpret_cast<core::RefCounted*>(obj)->Unref(); 142 }; 143 out->AppendExternalMemory(absl::string_view(src.data(), src.size()), obj, 144 cleanup); 145 } 146 147 void EncodeStringList(const string* strings, int64 n, Cord* out) { 148 out->Clear(); 149 for (int i = 0; i < n; ++i) { 150 ::strings::CordAppendVarint(strings[i].size(), out); 151 } 152 for (int i = 0; i < n; ++i) { 153 out->Append(strings[i]); 154 } 155 } 156 157 bool DecodeStringList(const Cord& src, string* strings, int64 n) { 158 std::vector<uint32> sizes(n); 159 CordReader reader(src); 160 int64 tot = 0; 161 for (auto& v : sizes) { 162 if (!::strings::CordReaderReadVarint(&reader, &v)) return false; 163 tot += v; 164 } 165 if (tot != reader.Available()) { 166 return false; 167 } 168 string* data = strings; 169 for (int i = 0; i < n; ++i, ++data) { 170 auto size = sizes[i]; 171 if (size > reader.Available()) { 172 return false; 173 } 174 gtl::STLStringResizeUninitialized(data, size); 175 reader.ReadN(size, gtl::string_as_array(data)); 176 } 177 return true; 178 } 179 180 void CopyFromArray(Cord* c, const char* base, size_t bytes) { 181 c->CopyFrom(base, bytes); 182 } 183 184 class CordStringListEncoderImpl : public StringListEncoder { 185 public: 186 explicit CordStringListEncoderImpl(Cord* out) : out_(out) {} 187 ~CordStringListEncoderImpl() override = default; 188 189 void Append(const protobuf::MessageLite& m) override { 190 ::strings::CordAppendVarint(m.ByteSizeLong(), out_); 191 m.AppendToString(&rest_); 192 } 193 194 void Append(const string& s) override { 195 ::strings::CordAppendVarint(s.length(), out_); 196 rest_.append(s.data(), s.size()); 197 } 198 199 void Finalize() override { out_->Append(rest_); } 200 201 private: 202 Cord* out_; 203 string rest_; 204 }; 205 206 class CordStringListDecoderImpl : public StringListDecoder { 207 public: 208 explicit CordStringListDecoderImpl(const Cord& in) : reader_(in) {} 209 ~CordStringListDecoderImpl() override = default; 210 211 bool ReadSizes(std::vector<uint32>* sizes) override { 212 int64 total = 0; 213 for (auto& size : *sizes) { 214 if (!::strings::CordReaderReadVarint(&reader_, &size)) return false; 215 total += size; 216 } 217 if (total != static_cast<int64>(reader_.Available())) { 218 return false; 219 } 220 return true; 221 } 222 223 const char* Data(uint32 size) override { 224 tmp_.resize(size); 225 reader_.ReadN(size, tmp_.data()); 226 return tmp_.data(); 227 } 228 229 private: 230 CordReader reader_; 231 std::vector<char> tmp_; 232 }; 233 234 std::unique_ptr<StringListEncoder> NewStringListEncoder(Cord* out) { 235 return std::unique_ptr<StringListEncoder>(new CordStringListEncoderImpl(out)); 236 } 237 238 std::unique_ptr<StringListDecoder> NewStringListDecoder(const Cord& in) { 239 return std::unique_ptr<StringListDecoder>(new CordStringListDecoderImpl(in)); 240 } 241 242 #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) 243 244 } // namespace port 245 } // namespace tensorflow 246