1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "zlib-utils.h" 18 19 #include <memory> 20 21 #include "util/base/logging.h" 22 #include "util/flatbuffers.h" 23 24 namespace libtextclassifier2 { 25 26 std::unique_ptr<ZlibDecompressor> ZlibDecompressor::Instance() { 27 std::unique_ptr<ZlibDecompressor> result(new ZlibDecompressor()); 28 if (!result->initialized_) { 29 result.reset(); 30 } 31 return result; 32 } 33 34 ZlibDecompressor::ZlibDecompressor() { 35 memset(&stream_, 0, sizeof(stream_)); 36 stream_.zalloc = Z_NULL; 37 stream_.zfree = Z_NULL; 38 initialized_ = (inflateInit(&stream_) == Z_OK); 39 } 40 41 ZlibDecompressor::~ZlibDecompressor() { 42 if (initialized_) { 43 inflateEnd(&stream_); 44 } 45 } 46 47 bool ZlibDecompressor::Decompress(const CompressedBuffer* compressed_buffer, 48 std::string* out) { 49 out->resize(compressed_buffer->uncompressed_size()); 50 stream_.next_in = 51 reinterpret_cast<const Bytef*>(compressed_buffer->buffer()->Data()); 52 stream_.avail_in = compressed_buffer->buffer()->Length(); 53 stream_.next_out = reinterpret_cast<Bytef*>(const_cast<char*>(out->c_str())); 54 stream_.avail_out = compressed_buffer->uncompressed_size(); 55 return (inflate(&stream_, Z_SYNC_FLUSH) == Z_OK); 56 } 57 58 std::unique_ptr<ZlibCompressor> ZlibCompressor::Instance() { 59 std::unique_ptr<ZlibCompressor> result(new ZlibCompressor()); 60 if (!result->initialized_) { 61 result.reset(); 62 } 63 return result; 64 } 65 66 ZlibCompressor::ZlibCompressor(int level, int tmp_buffer_size) { 67 memset(&stream_, 0, sizeof(stream_)); 68 stream_.zalloc = Z_NULL; 69 stream_.zfree = Z_NULL; 70 buffer_size_ = tmp_buffer_size; 71 buffer_.reset(new Bytef[buffer_size_]); 72 initialized_ = (deflateInit(&stream_, level) == Z_OK); 73 } 74 75 ZlibCompressor::~ZlibCompressor() { deflateEnd(&stream_); } 76 77 void ZlibCompressor::Compress(const std::string& uncompressed_content, 78 CompressedBufferT* out) { 79 out->uncompressed_size = uncompressed_content.size(); 80 out->buffer.clear(); 81 stream_.next_in = 82 reinterpret_cast<const Bytef*>(uncompressed_content.c_str()); 83 stream_.avail_in = uncompressed_content.size(); 84 stream_.next_out = buffer_.get(); 85 stream_.avail_out = buffer_size_; 86 unsigned char* buffer_deflate_start_position = 87 reinterpret_cast<unsigned char*>(buffer_.get()); 88 int status; 89 do { 90 // Deflate chunk-wise. 91 // Z_SYNC_FLUSH causes all pending output to be flushed, but doesn't 92 // reset the compression state. 93 // As we do not know how big the compressed buffer will be, we compress 94 // chunk wise and append the flushed content to the output string buffer. 95 // As we store the uncompressed size, we do not have to do this during 96 // decompression. 97 status = deflate(&stream_, Z_SYNC_FLUSH); 98 unsigned char* buffer_deflate_end_position = 99 reinterpret_cast<unsigned char*>(stream_.next_out); 100 if (buffer_deflate_end_position != buffer_deflate_start_position) { 101 out->buffer.insert(out->buffer.end(), buffer_deflate_start_position, 102 buffer_deflate_end_position); 103 stream_.next_out = buffer_deflate_start_position; 104 stream_.avail_out = buffer_size_; 105 } else { 106 break; 107 } 108 } while (status == Z_OK); 109 } 110 111 // Compress rule fields in the model. 112 bool CompressModel(ModelT* model) { 113 std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance(); 114 if (!zlib_compressor) { 115 TC_LOG(ERROR) << "Cannot compress model."; 116 return false; 117 } 118 119 // Compress regex rules. 120 if (model->regex_model != nullptr) { 121 for (int i = 0; i < model->regex_model->patterns.size(); i++) { 122 RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); 123 pattern->compressed_pattern.reset(new CompressedBufferT); 124 zlib_compressor->Compress(pattern->pattern, 125 pattern->compressed_pattern.get()); 126 pattern->pattern.clear(); 127 } 128 } 129 130 // Compress date-time rules. 131 if (model->datetime_model != nullptr) { 132 for (int i = 0; i < model->datetime_model->patterns.size(); i++) { 133 DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); 134 for (int j = 0; j < pattern->regexes.size(); j++) { 135 DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); 136 regex->compressed_pattern.reset(new CompressedBufferT); 137 zlib_compressor->Compress(regex->pattern, 138 regex->compressed_pattern.get()); 139 regex->pattern.clear(); 140 } 141 } 142 for (int i = 0; i < model->datetime_model->extractors.size(); i++) { 143 DatetimeModelExtractorT* extractor = 144 model->datetime_model->extractors[i].get(); 145 extractor->compressed_pattern.reset(new CompressedBufferT); 146 zlib_compressor->Compress(extractor->pattern, 147 extractor->compressed_pattern.get()); 148 extractor->pattern.clear(); 149 } 150 } 151 return true; 152 } 153 154 namespace { 155 156 bool DecompressBuffer(const CompressedBufferT* compressed_pattern, 157 ZlibDecompressor* zlib_decompressor, 158 std::string* uncompressed_pattern) { 159 std::string packed_pattern = 160 PackFlatbuffer<CompressedBuffer>(compressed_pattern); 161 if (!zlib_decompressor->Decompress( 162 LoadAndVerifyFlatbuffer<CompressedBuffer>(packed_pattern), 163 uncompressed_pattern)) { 164 return false; 165 } 166 return true; 167 } 168 169 } // namespace 170 171 bool DecompressModel(ModelT* model) { 172 std::unique_ptr<ZlibDecompressor> zlib_decompressor = 173 ZlibDecompressor::Instance(); 174 if (!zlib_decompressor) { 175 TC_LOG(ERROR) << "Cannot initialize decompressor."; 176 return false; 177 } 178 179 // Decompress regex rules. 180 if (model->regex_model != nullptr) { 181 for (int i = 0; i < model->regex_model->patterns.size(); i++) { 182 RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get(); 183 if (!DecompressBuffer(pattern->compressed_pattern.get(), 184 zlib_decompressor.get(), &pattern->pattern)) { 185 TC_LOG(ERROR) << "Cannot decompress pattern: " << i; 186 return false; 187 } 188 pattern->compressed_pattern.reset(nullptr); 189 } 190 } 191 192 // Decompress date-time rules. 193 if (model->datetime_model != nullptr) { 194 for (int i = 0; i < model->datetime_model->patterns.size(); i++) { 195 DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get(); 196 for (int j = 0; j < pattern->regexes.size(); j++) { 197 DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get(); 198 if (!DecompressBuffer(regex->compressed_pattern.get(), 199 zlib_decompressor.get(), ®ex->pattern)) { 200 TC_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j; 201 return false; 202 } 203 regex->compressed_pattern.reset(nullptr); 204 } 205 } 206 for (int i = 0; i < model->datetime_model->extractors.size(); i++) { 207 DatetimeModelExtractorT* extractor = 208 model->datetime_model->extractors[i].get(); 209 if (!DecompressBuffer(extractor->compressed_pattern.get(), 210 zlib_decompressor.get(), &extractor->pattern)) { 211 TC_LOG(ERROR) << "Cannot decompress pattern: " << i; 212 return false; 213 } 214 extractor->compressed_pattern.reset(nullptr); 215 } 216 } 217 return true; 218 } 219 220 std::string CompressSerializedModel(const std::string& model) { 221 std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str()); 222 TC_CHECK(unpacked_model != nullptr); 223 TC_CHECK(CompressModel(unpacked_model.get())); 224 flatbuffers::FlatBufferBuilder builder; 225 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get())); 226 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), 227 builder.GetSize()); 228 } 229 230 std::unique_ptr<UniLib::RegexPattern> UncompressMakeRegexPattern( 231 const UniLib& unilib, const flatbuffers::String* uncompressed_pattern, 232 const CompressedBuffer* compressed_pattern, ZlibDecompressor* decompressor, 233 std::string* result_pattern_text) { 234 UnicodeText unicode_regex_pattern; 235 std::string decompressed_pattern; 236 if (compressed_pattern != nullptr && 237 compressed_pattern->buffer() != nullptr) { 238 if (decompressor == nullptr || 239 !decompressor->Decompress(compressed_pattern, &decompressed_pattern)) { 240 TC_LOG(ERROR) << "Cannot decompress pattern."; 241 return nullptr; 242 } 243 unicode_regex_pattern = 244 UTF8ToUnicodeText(decompressed_pattern.data(), 245 decompressed_pattern.size(), /*do_copy=*/false); 246 } else { 247 if (uncompressed_pattern == nullptr) { 248 TC_LOG(ERROR) << "Cannot load uncompressed pattern."; 249 return nullptr; 250 } 251 unicode_regex_pattern = 252 UTF8ToUnicodeText(uncompressed_pattern->c_str(), 253 uncompressed_pattern->Length(), /*do_copy=*/false); 254 } 255 256 if (result_pattern_text != nullptr) { 257 *result_pattern_text = unicode_regex_pattern.ToUTF8String(); 258 } 259 260 std::unique_ptr<UniLib::RegexPattern> regex_pattern = 261 unilib.CreateRegexPattern(unicode_regex_pattern); 262 if (!regex_pattern) { 263 TC_LOG(ERROR) << "Could not create pattern: " 264 << unicode_regex_pattern.ToUTF8String(); 265 } 266 return regex_pattern; 267 } 268 269 } // namespace libtextclassifier2 270