Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/image_ops.cc
     17 
#include <limits>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gif/gif_io.h"
#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
#include "tensorflow/core/lib/png/png_io.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
     30 
     31 namespace tensorflow {
     32 namespace {
     33 
     34 enum FileFormat {
     35   kUnknownFormat = 0,
     36   kPngFormat = 1,
     37   kJpgFormat = 2,
     38   kGifFormat = 3,
     39 };
     40 
     41 // Classify the contents of a file based on starting bytes (the magic number).
     42 FileFormat ClassifyFileFormat(StringPiece data) {
     43   // The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three
     44   if (data.starts_with("\xff\xd8\xff")) return kJpgFormat;
     45   if (data.starts_with("\x89PNG\r\n\x1a\n")) return kPngFormat;
     46   if (data.starts_with("\x47\x49\x46\x38")) return kGifFormat;
     47   return kUnknownFormat;
     48 }
     49 
     50 string FileFormatString(FileFormat magic, StringPiece data) {
     51   switch (magic) {
     52     case kPngFormat:
     53       return "PNG";
     54     case kJpgFormat:
     55       return "JPEG";
     56     case kGifFormat:
     57       return "GIF";
     58     default: {
     59       if (data.empty()) return "empty file";
     60       return strings::StrCat("unknown format starting with '",
     61                              str_util::CEscape(data.substr(0, 16)), "'");
     62     }
     63   }
     64 }
     65 
     66 // Decode an image (either jpeg, png, or gif).  We use a single op so that
     67 // users don't have to care about which format they have.
     68 class DecodeImageOp : public OpKernel {
     69  public:
     70   explicit DecodeImageOp(OpKernelConstruction* context) : OpKernel(context) {
     71     // Determine which op we are: jpeg, png, gif, or any
     72     if (type_string() == "DecodeJpeg") {
     73       format_ = kJpgFormat;
     74     } else if (type_string() == "DecodeAndCropJpeg") {
     75       format_ = kJpgFormat;
     76       flags_.crop = true;
     77     } else if (type_string() == "DecodePng") {
     78       format_ = kPngFormat;
     79     } else if (type_string() == "DecodeGif") {
     80       format_ = kGifFormat;
     81     } else {
     82       OP_REQUIRES_OK(context,
     83                      errors::InvalidArgument("Bad op type ", type_string()));
     84     }
     85 
     86     if (format_ == kGifFormat) {
     87       channels_ = 3;
     88     } else {
     89       OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
     90       OP_REQUIRES(
     91           context,
     92           channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
     93           errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ",
     94                                   channels_));
     95     }
     96     flags_.components = channels_;
     97 
     98     // In the case of png, we support uint16 output
     99     if (format_ == kPngFormat) {
    100       DataType dt;
    101       OP_REQUIRES_OK(context, context->GetAttr("dtype", &dt));
    102       OP_REQUIRES(
    103           context, dt == DataType::DT_UINT8 || dt == DataType::DT_UINT16,
    104           errors::InvalidArgument("Type must be uint8 or uint16, got ", dt));
    105       if (dt == DataType::DT_UINT8) {
    106         channel_bits_ = 8;
    107       } else {
    108         channel_bits_ = 16;
    109       }
    110     }
    111 
    112     // The TensorFlow-chosen default for jpeg decoding is IFAST, sacrificing
    113     // image quality for speed.
    114     flags_.dct_method = JDCT_IFAST;
    115 
    116     if (format_ == kJpgFormat) {
    117       OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio));
    118       OP_REQUIRES(context,
    119                   flags_.ratio == 1 || flags_.ratio == 2 || flags_.ratio == 4 ||
    120                       flags_.ratio == 8,
    121                   errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ",
    122                                           flags_.ratio));
    123       OP_REQUIRES_OK(context, context->GetAttr("fancy_upscaling",
    124                                                &flags_.fancy_upscaling));
    125       OP_REQUIRES_OK(context,
    126                      context->GetAttr("try_recover_truncated",
    127                                       &flags_.try_recover_truncated_jpeg));
    128       OP_REQUIRES_OK(context,
    129                      context->GetAttr("acceptable_fraction",
    130                                       &flags_.min_acceptable_fraction));
    131 
    132       string dct_method;
    133       OP_REQUIRES_OK(context, context->GetAttr("dct_method", &dct_method));
    134       OP_REQUIRES(
    135           context,
    136           (dct_method.empty() || dct_method == "INTEGER_FAST" ||
    137            dct_method == "INTEGER_ACCURATE"),
    138           errors::InvalidArgument("dct_method must be one of "
    139                                   "{'', 'INTEGER_FAST', 'INTEGER_ACCURATE'}"));
    140       if (dct_method == "INTEGER_FAST") {
    141         flags_.dct_method = JDCT_IFAST;
    142       } else if (dct_method == "INTEGER_ACCURATE") {
    143         flags_.dct_method = JDCT_ISLOW;
    144       }
    145     }
    146   }
    147 
    148   void Compute(OpKernelContext* context) override {
    149     const Tensor& contents = context->input(0);
    150     OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
    151                 errors::InvalidArgument("contents must be scalar, got shape ",
    152                                         contents.shape().DebugString()));
    153 
    154     // Determine format
    155     const StringPiece input = contents.scalar<string>()();
    156     const auto magic = ClassifyFileFormat(input);
    157     OP_REQUIRES(
    158         context,
    159         magic == kJpgFormat || magic == kPngFormat || magic == kGifFormat,
    160         errors::InvalidArgument("Expected image (JPEG, PNG, or GIF), got ",
    161                                 FileFormatString(magic, input)));
    162     OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
    163                 errors::InvalidArgument(
    164                     FileFormatString(magic, input),
    165                     " contents are too large for int: ", input.size()));
    166     OP_REQUIRES(context, magic == kPngFormat || channel_bits_ == 8,
    167                 errors::InvalidArgument(FileFormatString(magic, input),
    168                                         " does not support uint16 output"));
    169 
    170     switch (magic) {
    171       case kJpgFormat:
    172         DecodeJpeg(context, input);
    173         break;
    174       case kPngFormat:
    175         DecodePng(context, input);
    176         break;
    177       case kGifFormat:
    178         DecodeGif(context, input);
    179         break;
    180       default:
    181         LOG(FATAL) << "Should never get here after check above";
    182         break;
    183     }
    184   }
    185 
    186   void DecodeJpeg(OpKernelContext* context, StringPiece input) {
    187     OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3,
    188                 errors::InvalidArgument(
    189                     "channels must be 0, 1, or 3 for JPEG, got ", channels_));
    190 
    191     // Use local copy of flags to avoid race condition as the class member is
    192     // shared among different invocations.
    193     jpeg::UncompressFlags flags = flags_;
    194     if (flags.crop) {
    195       // Update flags to include crop window.
    196       const Tensor& crop_window = context->input(1);
    197       OP_REQUIRES(context, crop_window.dims() == 1,
    198                   errors::InvalidArgument("crop_window must be 1-D, got shape ",
    199                                           crop_window.shape().DebugString()));
    200       OP_REQUIRES(context, crop_window.dim_size(0) == 4,
    201                   errors::InvalidArgument("crop_size must have four elements ",
    202                                           crop_window.shape().DebugString()));
    203       auto crop_window_vec = crop_window.vec<int32>();
    204       flags.crop_y = crop_window_vec(0);
    205       flags.crop_x = crop_window_vec(1);
    206       flags.crop_height = crop_window_vec(2);
    207       flags.crop_width = crop_window_vec(3);
    208     }
    209 
    210     // Decode jpeg, allocating tensor once the size is known.
    211     Tensor* output = nullptr;
    212     OP_REQUIRES(
    213         context,
    214         jpeg::Uncompress(
    215             input.data(), input.size(), flags, nullptr /* nwarn */,
    216             [=, &output](int width, int height, int channels) -> uint8* {
    217               Status status(context->allocate_output(
    218                   0,
    219                   format_ == kGifFormat
    220                       ? TensorShape({1, height, width, channels})
    221                       : TensorShape({height, width, channels}),
    222                   &output));
    223               if (!status.ok()) {
    224                 VLOG(1) << status;
    225                 context->SetStatus(status);
    226                 return nullptr;
    227               }
    228               return output->flat<uint8>().data();
    229             }),
    230         errors::InvalidArgument("Invalid JPEG data or crop window, data size ",
    231                                 input.size()));
    232   }
    233 
    234   void DecodePng(OpKernelContext* context, StringPiece input) {
    235     // Start decoding png to get shape details
    236     png::DecodeContext decode;
    237     OP_REQUIRES(context,
    238                 png::CommonInitDecode(input, channels_, channel_bits_, &decode),
    239                 errors::InvalidArgument("Invalid PNG header, data size ",
    240                                         input.size()));
    241 
    242     // Verify that width and height are not too large:
    243     // - verify width and height don't overflow int.
    244     // - width can later be multiplied by channels_ and sizeof(uint16), so
    245     //   verify single dimension is not too large.
    246     // - verify when width and height are multiplied together, there are a few
    247     //   bits to spare as well.
    248     const int width = static_cast<int>(decode.width);
    249     const int height = static_cast<int>(decode.height);
    250     const int64 total_size =
    251         static_cast<int64>(width) * static_cast<int64>(height);
    252     if (width != static_cast<int64>(decode.width) || width <= 0 ||
    253         width >= (1LL << 27) || height != static_cast<int64>(decode.height) ||
    254         height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
    255       png::CommonFreeDecode(&decode);
    256       OP_REQUIRES(context, false,
    257                   errors::InvalidArgument("PNG size too large for int: ",
    258                                           decode.width, " by ", decode.height));
    259     }
    260 
    261     // Allocate tensor
    262     Tensor* output = nullptr;
    263     const auto status = context->allocate_output(
    264         0,
    265         format_ == kGifFormat ? TensorShape({1, height, width, decode.channels})
    266                               : TensorShape({height, width, decode.channels}),
    267         &output);
    268     if (!status.ok()) png::CommonFreeDecode(&decode);
    269     OP_REQUIRES_OK(context, status);
    270 
    271     if (channel_bits_ == 8) {
    272       // Finish decoding png
    273       OP_REQUIRES(
    274           context,
    275           png::CommonFinishDecode(
    276               reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
    277               decode.channels * width * sizeof(uint8), &decode),
    278           errors::InvalidArgument("Invalid PNG data, size ", input.size()));
    279     } else {
    280       // Finish decoding png
    281       OP_REQUIRES(
    282           context,
    283           png::CommonFinishDecode(
    284               reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
    285               decode.channels * width * sizeof(uint16), &decode),
    286           errors::InvalidArgument("Invalid PNG data, size ", input.size()));
    287     }
    288   }
    289 
    290   void DecodeGif(OpKernelContext* context, StringPiece input) {
    291     OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
    292                 errors::InvalidArgument("channels must be 0 or 3 for GIF, got ",
    293                                         channels_));
    294 
    295     // Decode GIF, allocating tensor once the size is known.
    296     Tensor* output = nullptr;
    297     string error_string;
    298     OP_REQUIRES(
    299         context,
    300         gif::Decode(input.data(), input.size(),
    301                     [=, &output](int num_frames, int width, int height,
    302                                  int channels) -> uint8* {
    303                       Status status;
    304                       if (format_ == kGifFormat) {
    305                         status = context->allocate_output(
    306                             0,
    307                             TensorShape({num_frames, height, width, channels}),
    308                             &output);
    309                       } else if (num_frames == 1) {
    310                         status = context->allocate_output(
    311                             0, TensorShape({height, width, channels}), &output);
    312                       } else {
    313                         status = errors::InvalidArgument(
    314                             "Got ", num_frames, " frames, but animated gifs ",
    315                             "can only be decoded by tf.image.decode_gif or ",
    316                             "tf.image.decode_image");
    317                       }
    318                       if (!status.ok()) {
    319                         VLOG(1) << status;
    320                         context->SetStatus(status);
    321                         return nullptr;
    322                       }
    323                       return output->flat<uint8>().data();
    324                     },
    325                     &error_string),
    326         errors::InvalidArgument("Invalid GIF data (size ", input.size(), "), ",
    327                                 error_string));
    328   }
    329 
    330  private:
    331   FileFormat format_;
    332   int channels_;
    333   int channel_bits_ = 8;
    334   jpeg::UncompressFlags flags_;
    335 };
    336 
    337 REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp);
    338 REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp);
    339 REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp);
    340 REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU),
    341                         DecodeImageOp);
    342 
    343 }  // namespace
    344 }  // namespace tensorflow
    345