1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Helper routines for encoding/decoding tensor contents. 17 #ifndef TENSORFLOW_PLATFORM_TENSOR_CODING_H_ 18 #define TENSORFLOW_PLATFORM_TENSOR_CODING_H_ 19 20 #include <string> 21 #include "tensorflow/core/lib/core/refcount.h" 22 #include "tensorflow/core/lib/core/stringpiece.h" 23 #include "tensorflow/core/platform/platform.h" 24 #include "tensorflow/core/platform/protobuf.h" 25 #include "tensorflow/core/platform/types.h" 26 27 namespace tensorflow { 28 namespace port { 29 30 // Store src contents in *out. If backing memory for src is shared with *out, 31 // will ref obj during the call and will arrange to unref obj when no 32 // longer needed. 33 void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out); 34 35 // Copy contents of src to dst[0,src.size()-1]. 36 inline void CopyToArray(const string& src, char* dst) { 37 memcpy(dst, src.data(), src.size()); 38 } 39 40 // Copy subrange [pos:(pos + n)) from src to dst. If pos >= src.size() the 41 // result is empty. If pos + n > src.size() the subrange [pos, size()) is 42 // copied. 43 inline void CopySubrangeToArray(const string& src, size_t pos, size_t n, 44 char* dst) { 45 if (pos >= src.size()) return; 46 memcpy(dst, src.data() + pos, std::min(n, src.size() - pos)); 47 } 48 49 // Store encoding of strings[0..n-1] in *out. 50 void EncodeStringList(const string* strings, int64 n, string* out); 51 52 // Decode n strings from src and store in strings[0..n-1]. 53 // Returns true if successful, false on parse error. 54 bool DecodeStringList(const string& src, string* strings, int64 n); 55 56 // Assigns base[0..bytes-1] to *s 57 void CopyFromArray(string* s, const char* base, size_t bytes); 58 59 // Encodes sequences of strings and serialized protocol buffers into a string. 60 // Normal usage consists of zero or more calls to Append() and a single call to 61 // Finalize(). 62 class StringListEncoder { 63 public: 64 virtual ~StringListEncoder() = default; 65 66 // Encodes the given protocol buffer. This may not be called after Finalize(). 67 virtual void Append(const protobuf::MessageLite& m) = 0; 68 69 // Encodes the given string. This may not be called after Finalize(). 70 virtual void Append(const string& s) = 0; 71 72 // Signals end of the encoding process. No other calls are allowed after this. 73 virtual void Finalize() = 0; 74 }; 75 76 // Decodes a string into sequences of strings (which may represent serialized 77 // protocol buffers). Normal usage involves a single call to ReadSizes() in 78 // order to retrieve the length of all the strings in the sequence. For each 79 // size returned a call to Data() is expected and will return the actual 80 // string. 81 class StringListDecoder { 82 public: 83 virtual ~StringListDecoder() = default; 84 85 // Populates the given vector with the lengths of each string in the sequence 86 // being decoded. Upon returning the vector is guaranteed to contain as many 87 // elements as there are strings in the sequence. 88 virtual bool ReadSizes(std::vector<uint32>* sizes) = 0; 89 90 // Returns a pointer to the next string in the sequence, then prepares for the 91 // next call by advancing 'size' characters in the sequence. 92 virtual const char* Data(uint32 size) = 0; 93 }; 94 95 std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out); 96 std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in); 97 98 #if defined(TENSORFLOW_PROTOBUF_USES_CORD) 99 // Store src contents in *out. If backing memory for src is shared with *out, 100 // will ref obj during the call and will arrange to unref obj when no 101 // longer needed. 102 void AssignRefCounted(StringPiece src, core::RefCounted* obj, Cord* out); 103 104 // TODO(kmensah): Macro guard this with a check for Cord support. 105 inline void CopyToArray(const Cord& src, char* dst) { src.CopyToArray(dst); } 106 107 // Copy n bytes of src to dst. If pos >= src.size() the result is empty. 108 // If pos + n > src.size() the subrange [pos, size()) is copied. 109 inline void CopySubrangeToArray(const Cord& src, int64 pos, int64 n, 110 char* dst) { 111 src.Subcord(pos, n).CopyToArray(dst); 112 } 113 114 // Store encoding of strings[0..n-1] in *out. 115 void EncodeStringList(const string* strings, int64 n, Cord* out); 116 117 // Decode n strings from src and store in strings[0..n-1]. 118 // Returns true if successful, false on parse error. 119 bool DecodeStringList(const Cord& src, string* strings, int64 n); 120 121 // Assigns base[0..bytes-1] to *c 122 void CopyFromArray(Cord* c, const char* base, size_t bytes); 123 124 std::unique_ptr<StringListEncoder> NewStringListEncoder(Cord* out); 125 std::unique_ptr<StringListDecoder> NewStringListDecoder(const Cord& in); 126 #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) 127 128 } // namespace port 129 } // namespace tensorflow 130 131 #endif // TENSORFLOW_PLATFORM_TENSOR_CODING_H_ 132