Home | History | Annotate | Download | only in png
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Functions to read and write images in PNG format.
     17 
     18 #include <string.h>
     19 #include <sys/types.h>
     20 #include <zlib.h>
     21 #include <string>
     22 #include <utility>
     23 #include <vector>
     24 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise
     25 // provokes a compile error. We instead let png.h include what is needed.
     26 
     27 #include "tensorflow/core/lib/core/casts.h"
     28 #include "tensorflow/core/lib/png/png_io.h"
     29 #include "tensorflow/core/platform/cpu_info.h"  // endian
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/platform/png.h"
     32 
     33 namespace tensorflow {
     34 namespace png {
     35 
     36 ////////////////////////////////////////////////////////////////////////////////
     37 // Encode an 8- or 16-bit rgb/grayscale image to PNG string
     38 ////////////////////////////////////////////////////////////////////////////////
     39 
     40 namespace {
     41 
     42 #define PTR_INC(type, ptr, del) \
     43   (ptr = reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del)))
     44 #define CPTR_INC(type, ptr, del)                                            \
     45   (ptr = reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + \
     46                                        (del)))
     47 
     48 // Convert from 8 bit components to 16. This works in-place.
     49 static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
     50                          int width, int height_in, uint16* p16,
     51                          int p16_row_bytes) {
     52   // Force height*row_bytes computations to use 64 bits. Height*width is
     53   // enforced to < 29 bits in decode_png_op.cc, but height*row_bytes is
     54   // height*width*channels*(8bit?1:2) which is therefore only constrained to <
     55   // 33 bits.
     56   int64 height = static_cast<int64>(height_in);
     57 
     58   // Adjust pointers to copy backwards
     59   width *= num_comps;
     60   CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + (width - 1) * sizeof(*p8));
     61   PTR_INC(uint16, p16,
     62           (height - 1) * p16_row_bytes + (width - 1) * sizeof(*p16));
     63   int bump8 = width * sizeof(*p8) - p8_row_bytes;
     64   int bump16 = width * sizeof(*p16) - p16_row_bytes;
     65   for (; height-- != 0;
     66        CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) {
     67     for (int w = width; w-- != 0; --p8, --p16) {
     68       uint32 pix = *p8;
     69       pix |= pix << 8;
     70       *p16 = static_cast<uint16>(pix);
     71     }
     72   }
     73 }
     74 
     75 #undef PTR_INC
     76 #undef CPTR_INC
     77 
     78 void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
     79   DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
     80   ctx->error_condition = true;
     81   // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
     82   VLOG(1) << "PNG error: " << msg;
     83   longjmp(png_jmpbuf(png_ptr), 1);
     84 }
     85 
     86 void WarningHandler(png_structp png_ptr, png_const_charp msg) {
     87   LOG(WARNING) << "PNG warning: " << msg;
     88 }
     89 
     90 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
     91   DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
     92   if (static_cast<png_size_t>(ctx->data_left) < length) {
     93     memset(data, 0, length);
     94     png_error(png_ptr, "More bytes requested to read than available");
     95   } else {
     96     memcpy(data, ctx->data, length);
     97     ctx->data += length;
     98     ctx->data_left -= length;
     99   }
    100 }
    101 
    102 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
    103   string* const s = bit_cast<string*>(png_get_io_ptr(png_ptr));
    104   s->append(bit_cast<const char*>(data), length);
    105 }
    106 
    107 void StringWriterFlush(png_structp png_ptr) {}
    108 
    109 char* check_metadata_string(const string& s) {
    110   const char* const c_string = s.c_str();
    111   const size_t length = s.size();
    112   if (strlen(c_string) != length) {
    113     LOG(WARNING) << "Warning! Metadata contains \\0 character(s).";
    114   }
    115   return const_cast<char*>(c_string);
    116 }
    117 
    118 }  // namespace
    119 
    120 // We move CommonInitDecode() and CommonFinishDecode()
    121 // out of the CommonDecode() template to save code space.
    122 void CommonFreeDecode(DecodeContext* context) {
    123   if (context->png_ptr) {
    124     png_destroy_read_struct(&context->png_ptr,
    125                             context->info_ptr ? &context->info_ptr : nullptr,
    126                             nullptr);
    127     context->png_ptr = nullptr;
    128     context->info_ptr = nullptr;
    129   }
    130 }
    131 
    132 bool DecodeHeader(StringPiece png_string, int* width, int* height,
    133                   int* components, int* channel_bit_depth,
    134                   std::vector<std::pair<string, string> >* metadata) {
    135   DecodeContext context;
    136   // Ask for 16 bits even if there may be fewer.  This assures that sniffing
    137   // the metadata will succeed in all cases.
    138   //
    139   // TODO(skal): CommonInitDecode() mixes the operation of sniffing the
    140   // metadata with setting up the data conversions.  These should be separated.
    141   constexpr int kDesiredNumChannels = 1;
    142   constexpr int kDesiredChannelBits = 16;
    143   if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits,
    144                         &context)) {
    145     return false;
    146   }
    147   CHECK_NOTNULL(width);
    148   *width = static_cast<int>(context.width);
    149   CHECK_NOTNULL(height);
    150   *height = static_cast<int>(context.height);
    151   if (components != nullptr) {
    152     switch (context.color_type) {
    153       case PNG_COLOR_TYPE_PALETTE:
    154         *components =
    155             (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS))
    156                 ? 4
    157                 : 3;
    158         break;
    159       case PNG_COLOR_TYPE_GRAY:
    160         *components = 1;
    161         break;
    162       case PNG_COLOR_TYPE_GRAY_ALPHA:
    163         *components = 2;
    164         break;
    165       case PNG_COLOR_TYPE_RGB:
    166         *components = 3;
    167         break;
    168       case PNG_COLOR_TYPE_RGB_ALPHA:
    169         *components = 4;
    170         break;
    171       default:
    172         *components = 0;
    173         break;
    174     }
    175   }
    176   if (channel_bit_depth != nullptr) {
    177     *channel_bit_depth = context.bit_depth;
    178   }
    179   if (metadata != nullptr) {
    180     metadata->clear();
    181     png_textp text_ptr = nullptr;
    182     int num_text = 0;
    183     png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text);
    184     for (int i = 0; i < num_text; i++) {
    185       const png_text& text = text_ptr[i];
    186       metadata->push_back(std::make_pair(text.key, text.text));
    187     }
    188   }
    189   CommonFreeDecode(&context);
    190   return true;
    191 }
    192 
    193 bool CommonInitDecode(StringPiece png_string, int desired_channels,
    194                       int desired_channel_bits, DecodeContext* context) {
    195   CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
    196       << "desired_channel_bits = " << desired_channel_bits;
    197   CHECK(0 <= desired_channels && desired_channels <= 4)
    198       << "desired_channels = " << desired_channels;
    199   context->error_condition = false;
    200   context->channels = desired_channels;
    201   context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
    202                                             ErrorHandler, WarningHandler);
    203   if (!context->png_ptr) {
    204     VLOG(1) << ": DecodePNG <- png_create_read_struct failed";
    205     return false;
    206   }
    207   if (setjmp(png_jmpbuf(context->png_ptr))) {
    208     VLOG(1) << ": DecodePNG error trapped.";
    209     CommonFreeDecode(context);
    210     return false;
    211   }
    212   context->info_ptr = png_create_info_struct(context->png_ptr);
    213   if (!context->info_ptr || context->error_condition) {
    214     VLOG(1) << ": DecodePNG <- png_create_info_struct failed";
    215     CommonFreeDecode(context);
    216     return false;
    217   }
    218   context->data = bit_cast<const uint8*>(png_string.data());
    219   context->data_left = png_string.size();
    220   png_set_read_fn(context->png_ptr, context, StringReader);
    221   png_read_info(context->png_ptr, context->info_ptr);
    222   png_get_IHDR(context->png_ptr, context->info_ptr, &context->width,
    223                &context->height, &context->bit_depth, &context->color_type,
    224                nullptr, nullptr, nullptr);
    225   if (context->error_condition) {
    226     VLOG(1) << ": DecodePNG <- error during header parsing.";
    227     CommonFreeDecode(context);
    228     return false;
    229   }
    230   if (context->width <= 0 || context->height <= 0) {
    231     VLOG(1) << ": DecodePNG <- invalid dimensions";
    232     CommonFreeDecode(context);
    233     return false;
    234   }
    235   if (context->channels == 0) {  // Autodetect number of channels
    236     context->channels = png_get_channels(context->png_ptr, context->info_ptr);
    237   }
    238   const bool has_tRNS =
    239       (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0;
    240   const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
    241   if ((context->channels & 1) == 0) {  // We desire alpha
    242     if (has_alpha) {                   // There is alpha
    243     } else if (has_tRNS) {
    244       png_set_tRNS_to_alpha(context->png_ptr);  // Convert transparency to alpha
    245     } else {
    246       png_set_add_alpha(context->png_ptr, (1 << context->bit_depth) - 1,
    247                         PNG_FILLER_AFTER);
    248     }
    249   } else {                                    // We don't want alpha
    250     if (has_alpha || has_tRNS) {              // There is alpha
    251       png_set_strip_alpha(context->png_ptr);  // Strip alpha
    252     }
    253   }
    254 
    255   // If we only want 8 bits, but are given 16, strip off the LS 8 bits
    256   if (context->bit_depth > 8 && desired_channel_bits <= 8)
    257     png_set_strip_16(context->png_ptr);
    258 
    259   context->need_to_synthesize_16 =
    260       (context->bit_depth <= 8 && desired_channel_bits == 16);
    261 
    262   png_set_packing(context->png_ptr);
    263   context->num_passes = png_set_interlace_handling(context->png_ptr);
    264 
    265   if (desired_channel_bits > 8 && port::kLittleEndian) {
    266     png_set_swap(context->png_ptr);
    267   }
    268 
    269   // convert palette to rgb(a) if needs be.
    270   if (context->color_type == PNG_COLOR_TYPE_PALETTE)
    271     png_set_palette_to_rgb(context->png_ptr);
    272 
    273   // handle grayscale case for source or destination
    274   const bool want_gray = (context->channels < 3);
    275   const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
    276   if (is_gray) {  // upconvert gray to 8-bit if needed.
    277     if (context->bit_depth < 8) {
    278       png_set_expand_gray_1_2_4_to_8(context->png_ptr);
    279     }
    280   }
    281   if (want_gray) {  // output is grayscale
    282     if (!is_gray)
    283       png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587);  // 601, JPG
    284   } else {  // output is rgb(a)
    285     if (is_gray)
    286       png_set_gray_to_rgb(context->png_ptr);  // Enable gray -> RGB conversion
    287   }
    288 
    289   // Must come last to incorporate all requested transformations.
    290   png_read_update_info(context->png_ptr, context->info_ptr);
    291   return true;
    292 }
    293 
    294 bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
    295   CHECK_NOTNULL(data);
    296 
    297   // we need to re-set the jump point so that we trap the errors
    298   // within *this* function (and not CommonInitDecode())
    299   if (setjmp(png_jmpbuf(context->png_ptr))) {
    300     VLOG(1) << ": DecodePNG error trapped.";
    301     CommonFreeDecode(context);
    302     return false;
    303   }
    304   // png_read_row() takes care of offsetting the pointer based on interlacing
    305   for (int p = 0; p < context->num_passes; ++p) {
    306     png_bytep row = data;
    307     for (int h = context->height; h-- != 0; row += row_bytes) {
    308       png_read_row(context->png_ptr, row, nullptr);
    309     }
    310   }
    311 
    312   // Marks iDAT as valid.
    313   png_set_rows(context->png_ptr, context->info_ptr,
    314                png_get_rows(context->png_ptr, context->info_ptr));
    315   png_read_end(context->png_ptr, context->info_ptr);
    316 
    317   // Clean up.
    318   const bool ok = !context->error_condition;
    319   CommonFreeDecode(context);
    320 
    321   // Synthesize 16 bits from 8 if requested.
    322   if (context->need_to_synthesize_16)
    323     Convert8to16(bit_cast<uint8*>(data), context->channels, row_bytes,
    324                  context->width, context->height, bit_cast<uint16*>(data),
    325                  row_bytes);
    326   return ok;
    327 }
    328 
    329 bool WriteImageToBuffer(
    330     const void* image, int width, int height, int row_bytes, int num_channels,
    331     int channel_bits, int compression, string* png_string,
    332     const std::vector<std::pair<string, string> >* metadata) {
    333   CHECK_NOTNULL(image);
    334   CHECK_NOTNULL(png_string);
    335   // Although this case is checked inside png.cc and issues an error message,
    336   // that error causes memory corruption.
    337   if (width == 0 || height == 0) return false;
    338 
    339   png_string->resize(0);
    340   png_infop info_ptr = nullptr;
    341   png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr,
    342                                                 ErrorHandler, WarningHandler);
    343   if (png_ptr == nullptr) return false;
    344   if (setjmp(png_jmpbuf(png_ptr))) {
    345     png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr);
    346     return false;
    347   }
    348   info_ptr = png_create_info_struct(png_ptr);
    349   if (info_ptr == nullptr) {
    350     png_destroy_write_struct(&png_ptr, nullptr);
    351     return false;
    352   }
    353 
    354   int color_type = -1;
    355   switch (num_channels) {
    356     case 1:
    357       color_type = PNG_COLOR_TYPE_GRAY;
    358       break;
    359     case 2:
    360       color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
    361       break;
    362     case 3:
    363       color_type = PNG_COLOR_TYPE_RGB;
    364       break;
    365     case 4:
    366       color_type = PNG_COLOR_TYPE_RGB_ALPHA;
    367       break;
    368     default:
    369       png_destroy_write_struct(&png_ptr, &info_ptr);
    370       return false;
    371   }
    372 
    373   png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush);
    374   if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
    375   png_set_compression_level(png_ptr, compression);
    376   png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
    377   // There used to be a call to png_set_filter here turning off filtering
    378   // entirely, but it produced pessimal compression ratios.  I'm not sure
    379   // why it was there.
    380   png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type,
    381                PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
    382                PNG_FILTER_TYPE_DEFAULT);
    383   // If we have metadata write to it.
    384   if (metadata && !metadata->empty()) {
    385     std::vector<png_text> text;
    386     for (const auto& pair : *metadata) {
    387       png_text txt;
    388       txt.compression = PNG_TEXT_COMPRESSION_NONE;
    389       txt.key = check_metadata_string(pair.first);
    390       txt.text = check_metadata_string(pair.second);
    391       text.push_back(txt);
    392     }
    393     png_set_text(png_ptr, info_ptr, &text[0], text.size());
    394   }
    395 
    396   png_write_info(png_ptr, info_ptr);
    397   if (channel_bits > 8 && port::kLittleEndian) png_set_swap(png_ptr);
    398 
    399   png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image));
    400   for (; height--; row += row_bytes) png_write_row(png_ptr, row);
    401   png_write_end(png_ptr, nullptr);
    402 
    403   png_destroy_write_struct(&png_ptr, &info_ptr);
    404   return true;
    405 }
    406 
    407 }  // namespace png
    408 }  // namespace tensorflow
    409