1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Functions to read and write images in PNG format. 17 18 #include <string.h> 19 #include <sys/types.h> 20 #include <zlib.h> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise 25 // provokes a compile error. We instead let png.h include what is needed. 26 27 #include "tensorflow/core/lib/core/casts.h" 28 #include "tensorflow/core/lib/png/png_io.h" 29 #include "tensorflow/core/platform/cpu_info.h" // endian 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/png.h" 32 33 namespace tensorflow { 34 namespace png { 35 36 //////////////////////////////////////////////////////////////////////////////// 37 // Encode an 8- or 16-bit rgb/grayscale image to PNG string 38 //////////////////////////////////////////////////////////////////////////////// 39 40 namespace { 41 42 #define PTR_INC(type, ptr, del) \ 43 (ptr = reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del))) 44 #define CPTR_INC(type, ptr, del) \ 45 (ptr = reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + \ 46 (del))) 47 48 // Convert from 8 bit components to 16. This works in-place. 49 static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes, 50 int width, int height_in, uint16* p16, 51 int p16_row_bytes) { 52 // Force height*row_bytes computations to use 64 bits. Height*width is 53 // enforced to < 29 bits in decode_png_op.cc, but height*row_bytes is 54 // height*width*channels*(8bit?1:2) which is therefore only constrained to < 55 // 33 bits. 56 int64 height = static_cast<int64>(height_in); 57 58 // Adjust pointers to copy backwards 59 width *= num_comps; 60 CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + (width - 1) * sizeof(*p8)); 61 PTR_INC(uint16, p16, 62 (height - 1) * p16_row_bytes + (width - 1) * sizeof(*p16)); 63 int bump8 = width * sizeof(*p8) - p8_row_bytes; 64 int bump16 = width * sizeof(*p16) - p16_row_bytes; 65 for (; height-- != 0; 66 CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) { 67 for (int w = width; w-- != 0; --p8, --p16) { 68 uint32 pix = *p8; 69 pix |= pix << 8; 70 *p16 = static_cast<uint16>(pix); 71 } 72 } 73 } 74 75 #undef PTR_INC 76 #undef CPTR_INC 77 78 void ErrorHandler(png_structp png_ptr, png_const_charp msg) { 79 DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr)); 80 ctx->error_condition = true; 81 // To prevent log spam, errors are logged as VLOG(1) instead of ERROR. 82 VLOG(1) << "PNG error: " << msg; 83 longjmp(png_jmpbuf(png_ptr), 1); 84 } 85 86 void WarningHandler(png_structp png_ptr, png_const_charp msg) { 87 LOG(WARNING) << "PNG warning: " << msg; 88 } 89 90 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) { 91 DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr)); 92 if (static_cast<png_size_t>(ctx->data_left) < length) { 93 memset(data, 0, length); 94 png_error(png_ptr, "More bytes requested to read than available"); 95 } else { 96 memcpy(data, ctx->data, length); 97 ctx->data += length; 98 ctx->data_left -= length; 99 } 100 } 101 102 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) { 103 string* const s = bit_cast<string*>(png_get_io_ptr(png_ptr)); 104 s->append(bit_cast<const char*>(data), length); 105 } 106 107 void StringWriterFlush(png_structp png_ptr) {} 108 109 char* check_metadata_string(const string& s) { 110 const char* const c_string = s.c_str(); 111 const size_t length = s.size(); 112 if (strlen(c_string) != length) { 113 LOG(WARNING) << "Warning! Metadata contains \\0 character(s)."; 114 } 115 return const_cast<char*>(c_string); 116 } 117 118 } // namespace 119 120 // We move CommonInitDecode() and CommonFinishDecode() 121 // out of the CommonDecode() template to save code space. 122 void CommonFreeDecode(DecodeContext* context) { 123 if (context->png_ptr) { 124 png_destroy_read_struct(&context->png_ptr, 125 context->info_ptr ? &context->info_ptr : nullptr, 126 nullptr); 127 context->png_ptr = nullptr; 128 context->info_ptr = nullptr; 129 } 130 } 131 132 bool DecodeHeader(StringPiece png_string, int* width, int* height, 133 int* components, int* channel_bit_depth, 134 std::vector<std::pair<string, string> >* metadata) { 135 DecodeContext context; 136 // Ask for 16 bits even if there may be fewer. This assures that sniffing 137 // the metadata will succeed in all cases. 138 // 139 // TODO(skal): CommonInitDecode() mixes the operation of sniffing the 140 // metadata with setting up the data conversions. These should be separated. 141 constexpr int kDesiredNumChannels = 1; 142 constexpr int kDesiredChannelBits = 16; 143 if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits, 144 &context)) { 145 return false; 146 } 147 CHECK_NOTNULL(width); 148 *width = static_cast<int>(context.width); 149 CHECK_NOTNULL(height); 150 *height = static_cast<int>(context.height); 151 if (components != nullptr) { 152 switch (context.color_type) { 153 case PNG_COLOR_TYPE_PALETTE: 154 *components = 155 (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS)) 156 ? 4 157 : 3; 158 break; 159 case PNG_COLOR_TYPE_GRAY: 160 *components = 1; 161 break; 162 case PNG_COLOR_TYPE_GRAY_ALPHA: 163 *components = 2; 164 break; 165 case PNG_COLOR_TYPE_RGB: 166 *components = 3; 167 break; 168 case PNG_COLOR_TYPE_RGB_ALPHA: 169 *components = 4; 170 break; 171 default: 172 *components = 0; 173 break; 174 } 175 } 176 if (channel_bit_depth != nullptr) { 177 *channel_bit_depth = context.bit_depth; 178 } 179 if (metadata != nullptr) { 180 metadata->clear(); 181 png_textp text_ptr = nullptr; 182 int num_text = 0; 183 png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text); 184 for (int i = 0; i < num_text; i++) { 185 const png_text& text = text_ptr[i]; 186 metadata->push_back(std::make_pair(text.key, text.text)); 187 } 188 } 189 CommonFreeDecode(&context); 190 return true; 191 } 192 193 bool CommonInitDecode(StringPiece png_string, int desired_channels, 194 int desired_channel_bits, DecodeContext* context) { 195 CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) 196 << "desired_channel_bits = " << desired_channel_bits; 197 CHECK(0 <= desired_channels && desired_channels <= 4) 198 << "desired_channels = " << desired_channels; 199 context->error_condition = false; 200 context->channels = desired_channels; 201 context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, 202 ErrorHandler, WarningHandler); 203 if (!context->png_ptr) { 204 VLOG(1) << ": DecodePNG <- png_create_read_struct failed"; 205 return false; 206 } 207 if (setjmp(png_jmpbuf(context->png_ptr))) { 208 VLOG(1) << ": DecodePNG error trapped."; 209 CommonFreeDecode(context); 210 return false; 211 } 212 context->info_ptr = png_create_info_struct(context->png_ptr); 213 if (!context->info_ptr || context->error_condition) { 214 VLOG(1) << ": DecodePNG <- png_create_info_struct failed"; 215 CommonFreeDecode(context); 216 return false; 217 } 218 context->data = bit_cast<const uint8*>(png_string.data()); 219 context->data_left = png_string.size(); 220 png_set_read_fn(context->png_ptr, context, StringReader); 221 png_read_info(context->png_ptr, context->info_ptr); 222 png_get_IHDR(context->png_ptr, context->info_ptr, &context->width, 223 &context->height, &context->bit_depth, &context->color_type, 224 nullptr, nullptr, nullptr); 225 if (context->error_condition) { 226 VLOG(1) << ": DecodePNG <- error during header parsing."; 227 CommonFreeDecode(context); 228 return false; 229 } 230 if (context->width <= 0 || context->height <= 0) { 231 VLOG(1) << ": DecodePNG <- invalid dimensions"; 232 CommonFreeDecode(context); 233 return false; 234 } 235 if (context->channels == 0) { // Autodetect number of channels 236 context->channels = png_get_channels(context->png_ptr, context->info_ptr); 237 } 238 const bool has_tRNS = 239 (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0; 240 const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0; 241 if ((context->channels & 1) == 0) { // We desire alpha 242 if (has_alpha) { // There is alpha 243 } else if (has_tRNS) { 244 png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha 245 } else { 246 png_set_add_alpha(context->png_ptr, (1 << context->bit_depth) - 1, 247 PNG_FILLER_AFTER); 248 } 249 } else { // We don't want alpha 250 if (has_alpha || has_tRNS) { // There is alpha 251 png_set_strip_alpha(context->png_ptr); // Strip alpha 252 } 253 } 254 255 // If we only want 8 bits, but are given 16, strip off the LS 8 bits 256 if (context->bit_depth > 8 && desired_channel_bits <= 8) 257 png_set_strip_16(context->png_ptr); 258 259 context->need_to_synthesize_16 = 260 (context->bit_depth <= 8 && desired_channel_bits == 16); 261 262 png_set_packing(context->png_ptr); 263 context->num_passes = png_set_interlace_handling(context->png_ptr); 264 265 if (desired_channel_bits > 8 && port::kLittleEndian) { 266 png_set_swap(context->png_ptr); 267 } 268 269 // convert palette to rgb(a) if needs be. 270 if (context->color_type == PNG_COLOR_TYPE_PALETTE) 271 png_set_palette_to_rgb(context->png_ptr); 272 273 // handle grayscale case for source or destination 274 const bool want_gray = (context->channels < 3); 275 const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR); 276 if (is_gray) { // upconvert gray to 8-bit if needed. 277 if (context->bit_depth < 8) { 278 png_set_expand_gray_1_2_4_to_8(context->png_ptr); 279 } 280 } 281 if (want_gray) { // output is grayscale 282 if (!is_gray) 283 png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG 284 } else { // output is rgb(a) 285 if (is_gray) 286 png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion 287 } 288 289 // Must come last to incorporate all requested transformations. 290 png_read_update_info(context->png_ptr, context->info_ptr); 291 return true; 292 } 293 294 bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) { 295 CHECK_NOTNULL(data); 296 297 // we need to re-set the jump point so that we trap the errors 298 // within *this* function (and not CommonInitDecode()) 299 if (setjmp(png_jmpbuf(context->png_ptr))) { 300 VLOG(1) << ": DecodePNG error trapped."; 301 CommonFreeDecode(context); 302 return false; 303 } 304 // png_read_row() takes care of offsetting the pointer based on interlacing 305 for (int p = 0; p < context->num_passes; ++p) { 306 png_bytep row = data; 307 for (int h = context->height; h-- != 0; row += row_bytes) { 308 png_read_row(context->png_ptr, row, nullptr); 309 } 310 } 311 312 // Marks iDAT as valid. 313 png_set_rows(context->png_ptr, context->info_ptr, 314 png_get_rows(context->png_ptr, context->info_ptr)); 315 png_read_end(context->png_ptr, context->info_ptr); 316 317 // Clean up. 318 const bool ok = !context->error_condition; 319 CommonFreeDecode(context); 320 321 // Synthesize 16 bits from 8 if requested. 322 if (context->need_to_synthesize_16) 323 Convert8to16(bit_cast<uint8*>(data), context->channels, row_bytes, 324 context->width, context->height, bit_cast<uint16*>(data), 325 row_bytes); 326 return ok; 327 } 328 329 bool WriteImageToBuffer( 330 const void* image, int width, int height, int row_bytes, int num_channels, 331 int channel_bits, int compression, string* png_string, 332 const std::vector<std::pair<string, string> >* metadata) { 333 CHECK_NOTNULL(image); 334 CHECK_NOTNULL(png_string); 335 // Although this case is checked inside png.cc and issues an error message, 336 // that error causes memory corruption. 337 if (width == 0 || height == 0) return false; 338 339 png_string->resize(0); 340 png_infop info_ptr = nullptr; 341 png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr, 342 ErrorHandler, WarningHandler); 343 if (png_ptr == nullptr) return false; 344 if (setjmp(png_jmpbuf(png_ptr))) { 345 png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr); 346 return false; 347 } 348 info_ptr = png_create_info_struct(png_ptr); 349 if (info_ptr == nullptr) { 350 png_destroy_write_struct(&png_ptr, nullptr); 351 return false; 352 } 353 354 int color_type = -1; 355 switch (num_channels) { 356 case 1: 357 color_type = PNG_COLOR_TYPE_GRAY; 358 break; 359 case 2: 360 color_type = PNG_COLOR_TYPE_GRAY_ALPHA; 361 break; 362 case 3: 363 color_type = PNG_COLOR_TYPE_RGB; 364 break; 365 case 4: 366 color_type = PNG_COLOR_TYPE_RGB_ALPHA; 367 break; 368 default: 369 png_destroy_write_struct(&png_ptr, &info_ptr); 370 return false; 371 } 372 373 png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush); 374 if (compression < 0) compression = Z_DEFAULT_COMPRESSION; 375 png_set_compression_level(png_ptr, compression); 376 png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL); 377 // There used to be a call to png_set_filter here turning off filtering 378 // entirely, but it produced pessimal compression ratios. I'm not sure 379 // why it was there. 380 png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type, 381 PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, 382 PNG_FILTER_TYPE_DEFAULT); 383 // If we have metadata write to it. 384 if (metadata && !metadata->empty()) { 385 std::vector<png_text> text; 386 for (const auto& pair : *metadata) { 387 png_text txt; 388 txt.compression = PNG_TEXT_COMPRESSION_NONE; 389 txt.key = check_metadata_string(pair.first); 390 txt.text = check_metadata_string(pair.second); 391 text.push_back(txt); 392 } 393 png_set_text(png_ptr, info_ptr, &text[0], text.size()); 394 } 395 396 png_write_info(png_ptr, info_ptr); 397 if (channel_bits > 8 && port::kLittleEndian) png_set_swap(png_ptr); 398 399 png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image)); 400 for (; height--; row += row_bytes) png_write_row(png_ptr, row); 401 png_write_end(png_ptr, nullptr); 402 403 png_destroy_write_struct(&png_ptr, &info_ptr); 404 return true; 405 } 406 407 } // namespace png 408 } // namespace tensorflow 409