Home | History | Annotate | Download | only in platform
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Helper routines for encoding/decoding tensor contents.
     17 #ifndef TENSORFLOW_PLATFORM_TENSOR_CODING_H_
     18 #define TENSORFLOW_PLATFORM_TENSOR_CODING_H_
     19 
#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
     26 
     27 namespace tensorflow {
     28 namespace port {
     29 
     30 // Store src contents in *out.  If backing memory for src is shared with *out,
     31 // will ref obj during the call and will arrange to unref obj when no
     32 // longer needed.
     33 void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out);
     34 
     35 // Copy contents of src to dst[0,src.size()-1].
     36 inline void CopyToArray(const string& src, char* dst) {
     37   memcpy(dst, src.data(), src.size());
     38 }
     39 
     40 // Copy subrange [pos:(pos + n)) from src to dst. If pos >= src.size() the
     41 // result is empty. If pos + n > src.size() the subrange [pos, size()) is
     42 // copied.
     43 inline void CopySubrangeToArray(const string& src, size_t pos, size_t n,
     44                                 char* dst) {
     45   if (pos >= src.size()) return;
     46   memcpy(dst, src.data() + pos, std::min(n, src.size() - pos));
     47 }
     48 
     49 // Store encoding of strings[0..n-1] in *out.
     50 void EncodeStringList(const string* strings, int64 n, string* out);
     51 
     52 // Decode n strings from src and store in strings[0..n-1].
     53 // Returns true if successful, false on parse error.
     54 bool DecodeStringList(const string& src, string* strings, int64 n);
     55 
     56 // Assigns base[0..bytes-1] to *s
     57 void CopyFromArray(string* s, const char* base, size_t bytes);
     58 
// Abstract interface for encoding a sequence of strings and serialized
// protocol buffers into a single string.
//
// Normal usage consists of zero or more calls to Append() followed by a single
// call to Finalize(). Instances are handed out as
// std::unique_ptr<StringListEncoder> by the NewStringListEncoder() factories
// declared in this header.
class StringListEncoder {
 public:
  // Virtual so that a concrete encoder can be destroyed through a
  // StringListEncoder pointer.
  virtual ~StringListEncoder() = default;

  // Encodes the given protocol buffer as the next element of the sequence.
  // This may not be called after Finalize().
  virtual void Append(const protobuf::MessageLite& m) = 0;

  // Encodes the given string as the next element of the sequence.
  // This may not be called after Finalize().
  virtual void Append(const string& s) = 0;

  // Signals end of the encoding process. No other calls are allowed after this.
  virtual void Finalize() = 0;
};
     75 
// Abstract interface for decoding a string into a sequence of strings (which
// may represent serialized protocol buffers).
//
// Normal usage involves a single call to ReadSizes() to retrieve the lengths
// of all the strings in the sequence, followed by one call to Data() for each
// returned size, in order; each Data() call yields the corresponding string.
// Instances are handed out as std::unique_ptr<StringListDecoder> by the
// NewStringListDecoder() factories declared in this header.
class StringListDecoder {
 public:
  // Virtual so that a concrete decoder can be destroyed through a
  // StringListDecoder pointer.
  virtual ~StringListDecoder() = default;

  // Populates the given vector with the lengths of each string in the sequence
  // being decoded. Upon returning the vector is guaranteed to contain as many
  // elements as there are strings in the sequence.
  // NOTE(review): the meaning of the bool result is not stated here;
  // presumably false signals a parse failure — confirm against the
  // implementation.
  virtual bool ReadSizes(std::vector<uint32>* sizes) = 0;

  // Returns a pointer to the next string in the sequence, then prepares for the
  // next call by advancing 'size' characters in the sequence. 'size' is
  // expected to be the matching entry produced by ReadSizes().
  virtual const char* Data(uint32 size) = 0;
};
     94 
// Factories for the string-backed encoder and decoder. The caller owns the
// returned object through the unique_ptr.
// NOTE(review): presumably the encoder writes its output into *out and the
// decoder reads from `in`, in which case both must outlive the returned
// object — confirm against the implementation.
std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out);
std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in);
     97 
     98 #if defined(TENSORFLOW_PROTOBUF_USES_CORD)
     99 // Store src contents in *out.  If backing memory for src is shared with *out,
    100 // will ref obj during the call and will arrange to unref obj when no
    101 // longer needed.
    102 void AssignRefCounted(StringPiece src, core::RefCounted* obj, Cord* out);
    103 
    104 // TODO(kmensah): Macro guard this with a check for Cord support.
    105 inline void CopyToArray(const Cord& src, char* dst) { src.CopyToArray(dst); }
    106 
    107 // Copy n bytes of src to dst. If pos >= src.size() the result is empty.
    108 // If pos + n > src.size() the subrange [pos, size()) is copied.
    109 inline void CopySubrangeToArray(const Cord& src, int64 pos, int64 n,
    110                                 char* dst) {
    111   src.Subcord(pos, n).CopyToArray(dst);
    112 }
    113 
    114 // Store encoding of strings[0..n-1] in *out.
    115 void EncodeStringList(const string* strings, int64 n, Cord* out);
    116 
    117 // Decode n strings from src and store in strings[0..n-1].
    118 // Returns true if successful, false on parse error.
    119 bool DecodeStringList(const Cord& src, string* strings, int64 n);
    120 
    121 // Assigns base[0..bytes-1] to *c
    122 void CopyFromArray(Cord* c, const char* base, size_t bytes);
    123 
// Cord-backed counterparts of the string factories. The caller owns the
// returned object through the unique_ptr.
// NOTE(review): presumably the encoder writes its output into *out and the
// decoder reads from `in`, in which case both must outlive the returned
// object — confirm against the implementation.
std::unique_ptr<StringListEncoder> NewStringListEncoder(Cord* out);
std::unique_ptr<StringListDecoder> NewStringListDecoder(const Cord& in);
    126 #endif  // defined(TENSORFLOW_PROTOBUF_USES_CORD)
    127 
    128 }  // namespace port
    129 }  // namespace tensorflow
    130 
    131 #endif  // TENSORFLOW_PLATFORM_TENSOR_CODING_H_
    132