Home | History | Annotate | Download | only in platform
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/platform/tensor_coding.h"
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/core/lib/core/coding.h"
     21 #include "tensorflow/core/lib/core/stringpiece.h"
     22 #include "tensorflow/core/lib/strings/strcat.h"
     23 #include "tensorflow/core/platform/protobuf.h"
     24 
     25 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
     26 #include "strings/cord_varint.h"
     27 #endif  // defined(TENSORFLOW_PROTOBUF_USES_CORD)
     28 
     29 namespace tensorflow {
     30 namespace port {
     31 
     32 void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) {
     33   out->assign(src.data(), src.size());
     34 }
     35 
     36 void EncodeStringList(const string* strings, int64 n, string* out) {
     37   out->clear();
     38   for (int i = 0; i < n; ++i) {
     39     core::PutVarint32(out, strings[i].size());
     40   }
     41   for (int i = 0; i < n; ++i) {
     42     out->append(strings[i]);
     43   }
     44 }
     45 
     46 bool DecodeStringList(const string& src, string* strings, int64 n) {
     47   std::vector<uint32> sizes(n);
     48   StringPiece reader(src);
     49   int64 tot = 0;
     50   for (auto& v : sizes) {
     51     if (!core::GetVarint32(&reader, &v)) return false;
     52     tot += v;
     53   }
     54   if (tot != static_cast<int64>(reader.size())) {
     55     return false;
     56   }
     57 
     58   string* data = strings;
     59   for (int64 i = 0; i < n; ++i, ++data) {
     60     auto size = sizes[i];
     61     if (size > reader.size()) {
     62       return false;
     63     }
     64     data->assign(reader.data(), size);
     65     reader.remove_prefix(size);
     66   }
     67 
     68   return true;
     69 }
     70 
     71 void CopyFromArray(string* s, const char* base, size_t bytes) {
     72   s->assign(base, bytes);
     73 }
     74 
     75 class StringListEncoderImpl : public StringListEncoder {
     76  public:
     77   explicit StringListEncoderImpl(string* out) : out_(out) {}
     78   ~StringListEncoderImpl() override = default;
     79 
     80   void Append(const protobuf::MessageLite& m) override {
     81     core::PutVarint32(out_, m.ByteSizeLong());
     82     tensorflow::string serialized_message;
     83     m.AppendToString(&serialized_message);
     84     strings::StrAppend(&rest_, serialized_message);
     85   }
     86 
     87   void Append(const string& s) override {
     88     core::PutVarint32(out_, s.length());
     89     strings::StrAppend(&rest_, s);
     90   }
     91 
     92   void Finalize() override { strings::StrAppend(out_, rest_); }
     93 
     94  private:
     95   string* out_;
     96   string rest_;
     97 };
     98 
     99 class StringListDecoderImpl : public StringListDecoder {
    100  public:
    101   explicit StringListDecoderImpl(const string& in) : reader_(in) {}
    102   ~StringListDecoderImpl() override = default;
    103 
    104   bool ReadSizes(std::vector<uint32>* sizes) override {
    105     int64 total = 0;
    106     for (auto& size : *sizes) {
    107       if (!core::GetVarint32(&reader_, &size)) return false;
    108       total += size;
    109     }
    110     if (total != static_cast<int64>(reader_.size())) {
    111       return false;
    112     }
    113     return true;
    114   }
    115 
    116   const char* Data(uint32 size) override {
    117     const char* data = reader_.data();
    118     reader_.remove_prefix(size);
    119     return data;
    120   }
    121 
    122  private:
    123   StringPiece reader_;
    124 };
    125 
    126 std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) {
    127   return std::unique_ptr<StringListEncoder>(new StringListEncoderImpl(out));
    128 }
    129 
    130 std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) {
    131   return std::unique_ptr<StringListDecoder>(new StringListDecoderImpl(in));
    132 }
    133 
    134 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
    135 void AssignRefCounted(StringPiece src, core::RefCounted* obj, Cord* out) {
    136   obj->Ref();
    137   out->Clear();
    138   // Defines a lambda to unref "obj" when Cord deletes this piece of
    139   // memory. +[] converts the lambda to a C style function pointer.
    140   auto cleanup = +[](absl::string_view donotcare, void* obj) {
    141     reinterpret_cast<core::RefCounted*>(obj)->Unref();
    142   };
    143   out->AppendExternalMemory(absl::string_view(src.data(), src.size()), obj,
    144                             cleanup);
    145 }
    146 
    147 void EncodeStringList(const string* strings, int64 n, Cord* out) {
    148   out->Clear();
    149   for (int i = 0; i < n; ++i) {
    150     ::strings::CordAppendVarint(strings[i].size(), out);
    151   }
    152   for (int i = 0; i < n; ++i) {
    153     out->Append(strings[i]);
    154   }
    155 }
    156 
    157 bool DecodeStringList(const Cord& src, string* strings, int64 n) {
    158   std::vector<uint32> sizes(n);
    159   CordReader reader(src);
    160   int64 tot = 0;
    161   for (auto& v : sizes) {
    162     if (!::strings::CordReaderReadVarint(&reader, &v)) return false;
    163     tot += v;
    164   }
    165   if (tot != reader.Available()) {
    166     return false;
    167   }
    168   string* data = strings;
    169   for (int i = 0; i < n; ++i, ++data) {
    170     auto size = sizes[i];
    171     if (size > reader.Available()) {
    172       return false;
    173     }
    174     gtl::STLStringResizeUninitialized(data, size);
    175     reader.ReadN(size, gtl::string_as_array(data));
    176   }
    177   return true;
    178 }
    179 
    180 void CopyFromArray(Cord* c, const char* base, size_t bytes) {
    181   c->CopyFrom(base, bytes);
    182 }
    183 
    184 class CordStringListEncoderImpl : public StringListEncoder {
    185  public:
    186   explicit CordStringListEncoderImpl(Cord* out) : out_(out) {}
    187   ~CordStringListEncoderImpl() override = default;
    188 
    189   void Append(const protobuf::MessageLite& m) override {
    190     ::strings::CordAppendVarint(m.ByteSizeLong(), out_);
    191     m.AppendToString(&rest_);
    192   }
    193 
    194   void Append(const string& s) override {
    195     ::strings::CordAppendVarint(s.length(), out_);
    196     rest_.append(s.data(), s.size());
    197   }
    198 
    199   void Finalize() override { out_->Append(rest_); }
    200 
    201  private:
    202   Cord* out_;
    203   string rest_;
    204 };
    205 
    206 class CordStringListDecoderImpl : public StringListDecoder {
    207  public:
    208   explicit CordStringListDecoderImpl(const Cord& in) : reader_(in) {}
    209   ~CordStringListDecoderImpl() override = default;
    210 
    211   bool ReadSizes(std::vector<uint32>* sizes) override {
    212     int64 total = 0;
    213     for (auto& size : *sizes) {
    214       if (!::strings::CordReaderReadVarint(&reader_, &size)) return false;
    215       total += size;
    216     }
    217     if (total != static_cast<int64>(reader_.Available())) {
    218       return false;
    219     }
    220     return true;
    221   }
    222 
    223   const char* Data(uint32 size) override {
    224     tmp_.resize(size);
    225     reader_.ReadN(size, tmp_.data());
    226     return tmp_.data();
    227   }
    228 
    229  private:
    230   CordReader reader_;
    231   std::vector<char> tmp_;
    232 };
    233 
    234 std::unique_ptr<StringListEncoder> NewStringListEncoder(Cord* out) {
    235   return std::unique_ptr<StringListEncoder>(new CordStringListEncoderImpl(out));
    236 }
    237 
    238 std::unique_ptr<StringListDecoder> NewStringListDecoder(const Cord& in) {
    239   return std::unique_ptr<StringListDecoder>(new CordStringListDecoderImpl(in));
    240 }
    241 
    242 #endif  // defined(TENSORFLOW_PROTOBUF_USES_CORD)
    243 
    244 }  // namespace port
    245 }  // namespace tensorflow
    246