1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Functions to read and write images in PNG format. 17 18 #include <string.h> 19 #include <sys/types.h> 20 #include <zlib.h> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise 25 // provokes a compile error. We instead let png.h include what is needed. 26 27 #include "absl/base/casts.h" 28 #include "tensorflow/core/lib/png/png_io.h" 29 #include "tensorflow/core/platform/byte_order.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/png.h" 32 33 namespace tensorflow { 34 namespace png { 35 36 //////////////////////////////////////////////////////////////////////////////// 37 // Encode an 8- or 16-bit rgb/grayscale image to PNG string 38 //////////////////////////////////////////////////////////////////////////////// 39 40 namespace { 41 42 #define PTR_INC(type, ptr, del) \ 43 (ptr = reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del))) 44 #define CPTR_INC(type, ptr, del) \ 45 (ptr = reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + \ 46 (del))) 47 48 // Convert from 8 bit components to 16. This works in-place. 49 static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes, 50 int width, int height_in, uint16* p16, 51 int p16_row_bytes) { 52 // Force height*row_bytes computations to use 64 bits. Height*width is 53 // enforced to < 29 bits in decode_png_op.cc, but height*row_bytes is 54 // height*width*channels*(8bit?1:2) which is therefore only constrained to < 55 // 33 bits. 56 int64 height = static_cast<int64>(height_in); 57 58 // Adjust pointers to copy backwards 59 width *= num_comps; 60 CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + (width - 1) * sizeof(*p8)); 61 PTR_INC(uint16, p16, 62 (height - 1) * p16_row_bytes + (width - 1) * sizeof(*p16)); 63 int bump8 = width * sizeof(*p8) - p8_row_bytes; 64 int bump16 = width * sizeof(*p16) - p16_row_bytes; 65 for (; height-- != 0; 66 CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) { 67 for (int w = width; w-- != 0; --p8, --p16) { 68 uint32 pix = *p8; 69 pix |= pix << 8; 70 *p16 = static_cast<uint16>(pix); 71 } 72 } 73 } 74 75 #undef PTR_INC 76 #undef CPTR_INC 77 78 void ErrorHandler(png_structp png_ptr, png_const_charp msg) { 79 DecodeContext* const ctx = 80 absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr)); 81 ctx->error_condition = true; 82 // To prevent log spam, errors are logged as VLOG(1) instead of ERROR. 83 VLOG(1) << "PNG error: " << msg; 84 longjmp(png_jmpbuf(png_ptr), 1); 85 } 86 87 void WarningHandler(png_structp png_ptr, png_const_charp msg) { 88 LOG(WARNING) << "PNG warning: " << msg; 89 } 90 91 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) { 92 DecodeContext* const ctx = 93 absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr)); 94 if (static_cast<png_size_t>(ctx->data_left) < length) { 95 // Don't zero out the data buffer as it has been lazily allocated (copy on 96 // write) and zeroing it out here can produce an OOM. Since the buffer is 97 // only used for reading data from the image, this doesn't result in any 98 // data leak, so it is safe to just leave the buffer be as it is and just 99 // exit with error. 100 png_error(png_ptr, "More bytes requested to read than available"); 101 } else { 102 memcpy(data, ctx->data, length); 103 ctx->data += length; 104 ctx->data_left -= length; 105 } 106 } 107 108 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) { 109 string* const s = absl::bit_cast<string*>(png_get_io_ptr(png_ptr)); 110 s->append(absl::bit_cast<const char*>(data), length); 111 } 112 113 void StringWriterFlush(png_structp png_ptr) {} 114 115 char* check_metadata_string(const string& s) { 116 const char* const c_string = s.c_str(); 117 const size_t length = s.size(); 118 if (strlen(c_string) != length) { 119 LOG(WARNING) << "Warning! Metadata contains \\0 character(s)."; 120 } 121 return const_cast<char*>(c_string); 122 } 123 124 } // namespace 125 126 // We move CommonInitDecode() and CommonFinishDecode() 127 // out of the CommonDecode() template to save code space. 128 void CommonFreeDecode(DecodeContext* context) { 129 if (context->png_ptr) { 130 png_destroy_read_struct(&context->png_ptr, 131 context->info_ptr ? &context->info_ptr : nullptr, 132 nullptr); 133 context->png_ptr = nullptr; 134 context->info_ptr = nullptr; 135 } 136 } 137 138 bool DecodeHeader(StringPiece png_string, int* width, int* height, 139 int* components, int* channel_bit_depth, 140 std::vector<std::pair<string, string> >* metadata) { 141 DecodeContext context; 142 // Ask for 16 bits even if there may be fewer. This assures that sniffing 143 // the metadata will succeed in all cases. 144 // 145 // TODO(skal): CommonInitDecode() mixes the operation of sniffing the 146 // metadata with setting up the data conversions. These should be separated. 147 constexpr int kDesiredNumChannels = 1; 148 constexpr int kDesiredChannelBits = 16; 149 if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits, 150 &context)) { 151 return false; 152 } 153 CHECK_NOTNULL(width); 154 *width = static_cast<int>(context.width); 155 CHECK_NOTNULL(height); 156 *height = static_cast<int>(context.height); 157 if (components != nullptr) { 158 switch (context.color_type) { 159 case PNG_COLOR_TYPE_PALETTE: 160 *components = 161 (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS)) 162 ? 4 163 : 3; 164 break; 165 case PNG_COLOR_TYPE_GRAY: 166 *components = 1; 167 break; 168 case PNG_COLOR_TYPE_GRAY_ALPHA: 169 *components = 2; 170 break; 171 case PNG_COLOR_TYPE_RGB: 172 *components = 3; 173 break; 174 case PNG_COLOR_TYPE_RGB_ALPHA: 175 *components = 4; 176 break; 177 default: 178 *components = 0; 179 break; 180 } 181 } 182 if (channel_bit_depth != nullptr) { 183 *channel_bit_depth = context.bit_depth; 184 } 185 if (metadata != nullptr) { 186 metadata->clear(); 187 png_textp text_ptr = nullptr; 188 int num_text = 0; 189 png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text); 190 for (int i = 0; i < num_text; i++) { 191 const png_text& text = text_ptr[i]; 192 metadata->push_back(std::make_pair(text.key, text.text)); 193 } 194 } 195 CommonFreeDecode(&context); 196 return true; 197 } 198 199 bool CommonInitDecode(StringPiece png_string, int desired_channels, 200 int desired_channel_bits, DecodeContext* context) { 201 CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) 202 << "desired_channel_bits = " << desired_channel_bits; 203 CHECK(0 <= desired_channels && desired_channels <= 4) 204 << "desired_channels = " << desired_channels; 205 context->error_condition = false; 206 context->channels = desired_channels; 207 context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, 208 ErrorHandler, WarningHandler); 209 if (!context->png_ptr) { 210 VLOG(1) << ": DecodePNG <- png_create_read_struct failed"; 211 return false; 212 } 213 if (setjmp(png_jmpbuf(context->png_ptr))) { 214 VLOG(1) << ": DecodePNG error trapped."; 215 CommonFreeDecode(context); 216 return false; 217 } 218 context->info_ptr = png_create_info_struct(context->png_ptr); 219 if (!context->info_ptr || context->error_condition) { 220 VLOG(1) << ": DecodePNG <- png_create_info_struct failed"; 221 CommonFreeDecode(context); 222 return false; 223 } 224 context->data = absl::bit_cast<const uint8*>(png_string.data()); 225 context->data_left = png_string.size(); 226 png_set_read_fn(context->png_ptr, context, StringReader); 227 png_read_info(context->png_ptr, context->info_ptr); 228 png_get_IHDR(context->png_ptr, context->info_ptr, &context->width, 229 &context->height, &context->bit_depth, &context->color_type, 230 nullptr, nullptr, nullptr); 231 if (context->error_condition) { 232 VLOG(1) << ": DecodePNG <- error during header parsing."; 233 CommonFreeDecode(context); 234 return false; 235 } 236 if (context->width <= 0 || context->height <= 0) { 237 VLOG(1) << ": DecodePNG <- invalid dimensions"; 238 CommonFreeDecode(context); 239 return false; 240 } 241 const bool has_tRNS = 242 (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0; 243 if (context->channels == 0) { // Autodetect number of channels 244 if (context->color_type == PNG_COLOR_TYPE_PALETTE) { 245 if (has_tRNS) { 246 context->channels = 4; // RGB + A(tRNS) 247 } else { 248 context->channels = 3; // RGB 249 } 250 } else { 251 context->channels = png_get_channels(context->png_ptr, context->info_ptr); 252 } 253 } 254 const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0; 255 if ((context->channels & 1) == 0) { // We desire alpha 256 if (has_alpha) { // There is alpha 257 } else if (has_tRNS) { 258 png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha 259 } else { 260 png_set_add_alpha(context->png_ptr, (1 << context->bit_depth) - 1, 261 PNG_FILLER_AFTER); 262 } 263 } else { // We don't want alpha 264 if (has_alpha || has_tRNS) { // There is alpha 265 png_set_strip_alpha(context->png_ptr); // Strip alpha 266 } 267 } 268 269 // If we only want 8 bits, but are given 16, strip off the LS 8 bits 270 if (context->bit_depth > 8 && desired_channel_bits <= 8) 271 png_set_strip_16(context->png_ptr); 272 273 context->need_to_synthesize_16 = 274 (context->bit_depth <= 8 && desired_channel_bits == 16); 275 276 png_set_packing(context->png_ptr); 277 context->num_passes = png_set_interlace_handling(context->png_ptr); 278 279 if (desired_channel_bits > 8 && port::kLittleEndian) { 280 png_set_swap(context->png_ptr); 281 } 282 283 // convert palette to rgb(a) if needs be. 284 if (context->color_type == PNG_COLOR_TYPE_PALETTE) 285 png_set_palette_to_rgb(context->png_ptr); 286 287 // handle grayscale case for source or destination 288 const bool want_gray = (context->channels < 3); 289 const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR); 290 if (is_gray) { // upconvert gray to 8-bit if needed. 291 if (context->bit_depth < 8) { 292 png_set_expand_gray_1_2_4_to_8(context->png_ptr); 293 } 294 } 295 if (want_gray) { // output is grayscale 296 if (!is_gray) 297 png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG 298 } else { // output is rgb(a) 299 if (is_gray) 300 png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion 301 } 302 303 // Must come last to incorporate all requested transformations. 304 png_read_update_info(context->png_ptr, context->info_ptr); 305 return true; 306 } 307 308 bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) { 309 CHECK_NOTNULL(data); 310 311 // we need to re-set the jump point so that we trap the errors 312 // within *this* function (and not CommonInitDecode()) 313 if (setjmp(png_jmpbuf(context->png_ptr))) { 314 VLOG(1) << ": DecodePNG error trapped."; 315 CommonFreeDecode(context); 316 return false; 317 } 318 // png_read_row() takes care of offsetting the pointer based on interlacing 319 for (int p = 0; p < context->num_passes; ++p) { 320 png_bytep row = data; 321 for (int h = context->height; h-- != 0; row += row_bytes) { 322 png_read_row(context->png_ptr, row, nullptr); 323 } 324 } 325 326 // Marks iDAT as valid. 327 png_set_rows(context->png_ptr, context->info_ptr, 328 png_get_rows(context->png_ptr, context->info_ptr)); 329 png_read_end(context->png_ptr, context->info_ptr); 330 331 // Clean up. 332 const bool ok = !context->error_condition; 333 CommonFreeDecode(context); 334 335 // Synthesize 16 bits from 8 if requested. 336 if (context->need_to_synthesize_16) 337 Convert8to16(absl::bit_cast<uint8*>(data), context->channels, row_bytes, 338 context->width, context->height, absl::bit_cast<uint16*>(data), 339 row_bytes); 340 return ok; 341 } 342 343 bool WriteImageToBuffer( 344 const void* image, int width, int height, int row_bytes, int num_channels, 345 int channel_bits, int compression, string* png_string, 346 const std::vector<std::pair<string, string> >* metadata) { 347 CHECK_NOTNULL(image); 348 CHECK_NOTNULL(png_string); 349 // Although this case is checked inside png.cc and issues an error message, 350 // that error causes memory corruption. 351 if (width == 0 || height == 0) return false; 352 353 png_string->resize(0); 354 png_infop info_ptr = nullptr; 355 png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr, 356 ErrorHandler, WarningHandler); 357 if (png_ptr == nullptr) return false; 358 if (setjmp(png_jmpbuf(png_ptr))) { 359 png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr); 360 return false; 361 } 362 info_ptr = png_create_info_struct(png_ptr); 363 if (info_ptr == nullptr) { 364 png_destroy_write_struct(&png_ptr, nullptr); 365 return false; 366 } 367 368 int color_type = -1; 369 switch (num_channels) { 370 case 1: 371 color_type = PNG_COLOR_TYPE_GRAY; 372 break; 373 case 2: 374 color_type = PNG_COLOR_TYPE_GRAY_ALPHA; 375 break; 376 case 3: 377 color_type = PNG_COLOR_TYPE_RGB; 378 break; 379 case 4: 380 color_type = PNG_COLOR_TYPE_RGB_ALPHA; 381 break; 382 default: 383 png_destroy_write_struct(&png_ptr, &info_ptr); 384 return false; 385 } 386 387 png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush); 388 if (compression < 0) compression = Z_DEFAULT_COMPRESSION; 389 png_set_compression_level(png_ptr, compression); 390 png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL); 391 // There used to be a call to png_set_filter here turning off filtering 392 // entirely, but it produced pessimal compression ratios. I'm not sure 393 // why it was there. 394 png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type, 395 PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, 396 PNG_FILTER_TYPE_DEFAULT); 397 // If we have metadata write to it. 398 if (metadata && !metadata->empty()) { 399 std::vector<png_text> text; 400 for (const auto& pair : *metadata) { 401 png_text txt; 402 txt.compression = PNG_TEXT_COMPRESSION_NONE; 403 txt.key = check_metadata_string(pair.first); 404 txt.text = check_metadata_string(pair.second); 405 text.push_back(txt); 406 } 407 png_set_text(png_ptr, info_ptr, &text[0], text.size()); 408 } 409 410 png_write_info(png_ptr, info_ptr); 411 if (channel_bits > 8 && port::kLittleEndian) png_set_swap(png_ptr); 412 413 png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image)); 414 for (; height--; row += row_bytes) png_write_row(png_ptr, row); 415 png_write_end(png_ptr, nullptr); 416 417 png_destroy_write_struct(&png_ptr, &info_ptr); 418 return true; 419 } 420 421 } // namespace png 422 } // namespace tensorflow 423