/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc

#include <limits>
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gif/gif_io.h"
#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
#include "tensorflow/core/lib/png/png_io.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {

enum FileFormat {
  kUnknownFormat = 0,
  kPngFormat = 1,
  kJpgFormat = 2,
  kGifFormat = 3,
};

// Classify the contents of a file based on starting bytes (the magic number).
42 FileFormat ClassifyFileFormat(StringPiece data) { 43 // The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three 44 if (data.starts_with("\xff\xd8\xff")) return kJpgFormat; 45 if (data.starts_with("\x89PNG\r\n\x1a\n")) return kPngFormat; 46 if (data.starts_with("\x47\x49\x46\x38")) return kGifFormat; 47 return kUnknownFormat; 48 } 49 50 string FileFormatString(FileFormat magic, StringPiece data) { 51 switch (magic) { 52 case kPngFormat: 53 return "PNG"; 54 case kJpgFormat: 55 return "JPEG"; 56 case kGifFormat: 57 return "GIF"; 58 default: { 59 if (data.empty()) return "empty file"; 60 return strings::StrCat("unknown format starting with '", 61 str_util::CEscape(data.substr(0, 16)), "'"); 62 } 63 } 64 } 65 66 // Decode an image (either jpeg, png, or gif). We use a single op so that 67 // users don't have to care about which format they have. 68 class DecodeImageOp : public OpKernel { 69 public: 70 explicit DecodeImageOp(OpKernelConstruction* context) : OpKernel(context) { 71 // Determine which op we are: jpeg, png, gif, or any 72 if (type_string() == "DecodeJpeg") { 73 format_ = kJpgFormat; 74 } else if (type_string() == "DecodeAndCropJpeg") { 75 format_ = kJpgFormat; 76 flags_.crop = true; 77 } else if (type_string() == "DecodePng") { 78 format_ = kPngFormat; 79 } else if (type_string() == "DecodeGif") { 80 format_ = kGifFormat; 81 } else { 82 OP_REQUIRES_OK(context, 83 errors::InvalidArgument("Bad op type ", type_string())); 84 } 85 86 if (format_ == kGifFormat) { 87 channels_ = 3; 88 } else { 89 OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_)); 90 OP_REQUIRES( 91 context, 92 channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4, 93 errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ", 94 channels_)); 95 } 96 flags_.components = channels_; 97 98 // In the case of png, we support uint16 output 99 if (format_ == kPngFormat) { 100 DataType dt; 101 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dt)); 102 
OP_REQUIRES( 103 context, dt == DataType::DT_UINT8 || dt == DataType::DT_UINT16, 104 errors::InvalidArgument("Type must be uint8 or uint16, got ", dt)); 105 if (dt == DataType::DT_UINT8) { 106 channel_bits_ = 8; 107 } else { 108 channel_bits_ = 16; 109 } 110 } 111 112 // The TensorFlow-chosen default for jpeg decoding is IFAST, sacrificing 113 // image quality for speed. 114 flags_.dct_method = JDCT_IFAST; 115 116 if (format_ == kJpgFormat) { 117 OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio)); 118 OP_REQUIRES(context, 119 flags_.ratio == 1 || flags_.ratio == 2 || flags_.ratio == 4 || 120 flags_.ratio == 8, 121 errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ", 122 flags_.ratio)); 123 OP_REQUIRES_OK(context, context->GetAttr("fancy_upscaling", 124 &flags_.fancy_upscaling)); 125 OP_REQUIRES_OK(context, 126 context->GetAttr("try_recover_truncated", 127 &flags_.try_recover_truncated_jpeg)); 128 OP_REQUIRES_OK(context, 129 context->GetAttr("acceptable_fraction", 130 &flags_.min_acceptable_fraction)); 131 132 string dct_method; 133 OP_REQUIRES_OK(context, context->GetAttr("dct_method", &dct_method)); 134 OP_REQUIRES( 135 context, 136 (dct_method.empty() || dct_method == "INTEGER_FAST" || 137 dct_method == "INTEGER_ACCURATE"), 138 errors::InvalidArgument("dct_method must be one of " 139 "{'', 'INTEGER_FAST', 'INTEGER_ACCURATE'}")); 140 if (dct_method == "INTEGER_FAST") { 141 flags_.dct_method = JDCT_IFAST; 142 } else if (dct_method == "INTEGER_ACCURATE") { 143 flags_.dct_method = JDCT_ISLOW; 144 } 145 } 146 } 147 148 void Compute(OpKernelContext* context) override { 149 const Tensor& contents = context->input(0); 150 OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), 151 errors::InvalidArgument("contents must be scalar, got shape ", 152 contents.shape().DebugString())); 153 154 // Determine format 155 const StringPiece input = contents.scalar<string>()(); 156 const auto magic = ClassifyFileFormat(input); 157 OP_REQUIRES( 158 
context, 159 magic == kJpgFormat || magic == kPngFormat || magic == kGifFormat, 160 errors::InvalidArgument("Expected image (JPEG, PNG, or GIF), got ", 161 FileFormatString(magic, input))); 162 OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(), 163 errors::InvalidArgument( 164 FileFormatString(magic, input), 165 " contents are too large for int: ", input.size())); 166 OP_REQUIRES(context, magic == kPngFormat || channel_bits_ == 8, 167 errors::InvalidArgument(FileFormatString(magic, input), 168 " does not support uint16 output")); 169 170 switch (magic) { 171 case kJpgFormat: 172 DecodeJpeg(context, input); 173 break; 174 case kPngFormat: 175 DecodePng(context, input); 176 break; 177 case kGifFormat: 178 DecodeGif(context, input); 179 break; 180 default: 181 LOG(FATAL) << "Should never get here after check above"; 182 break; 183 } 184 } 185 186 void DecodeJpeg(OpKernelContext* context, StringPiece input) { 187 OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3, 188 errors::InvalidArgument( 189 "channels must be 0, 1, or 3 for JPEG, got ", channels_)); 190 191 // Use local copy of flags to avoid race condition as the class member is 192 // shared among different invocations. 193 jpeg::UncompressFlags flags = flags_; 194 if (flags.crop) { 195 // Update flags to include crop window. 
196 const Tensor& crop_window = context->input(1); 197 OP_REQUIRES(context, crop_window.dims() == 1, 198 errors::InvalidArgument("crop_window must be 1-D, got shape ", 199 crop_window.shape().DebugString())); 200 OP_REQUIRES(context, crop_window.dim_size(0) == 4, 201 errors::InvalidArgument("crop_size must have four elements ", 202 crop_window.shape().DebugString())); 203 auto crop_window_vec = crop_window.vec<int32>(); 204 flags.crop_y = crop_window_vec(0); 205 flags.crop_x = crop_window_vec(1); 206 flags.crop_height = crop_window_vec(2); 207 flags.crop_width = crop_window_vec(3); 208 } 209 210 // Decode jpeg, allocating tensor once the size is known. 211 Tensor* output = nullptr; 212 OP_REQUIRES( 213 context, 214 jpeg::Uncompress( 215 input.data(), input.size(), flags, nullptr /* nwarn */, 216 [=, &output](int width, int height, int channels) -> uint8* { 217 Status status(context->allocate_output( 218 0, 219 format_ == kGifFormat 220 ? TensorShape({1, height, width, channels}) 221 : TensorShape({height, width, channels}), 222 &output)); 223 if (!status.ok()) { 224 VLOG(1) << status; 225 context->SetStatus(status); 226 return nullptr; 227 } 228 return output->flat<uint8>().data(); 229 }), 230 errors::InvalidArgument("Invalid JPEG data or crop window, data size ", 231 input.size())); 232 } 233 234 void DecodePng(OpKernelContext* context, StringPiece input) { 235 // Start decoding png to get shape details 236 png::DecodeContext decode; 237 OP_REQUIRES(context, 238 png::CommonInitDecode(input, channels_, channel_bits_, &decode), 239 errors::InvalidArgument("Invalid PNG header, data size ", 240 input.size())); 241 242 // Verify that width and height are not too large: 243 // - verify width and height don't overflow int. 244 // - width can later be multiplied by channels_ and sizeof(uint16), so 245 // verify single dimension is not too large. 246 // - verify when width and height are multiplied together, there are a few 247 // bits to spare as well. 
248 const int width = static_cast<int>(decode.width); 249 const int height = static_cast<int>(decode.height); 250 const int64 total_size = 251 static_cast<int64>(width) * static_cast<int64>(height); 252 if (width != static_cast<int64>(decode.width) || width <= 0 || 253 width >= (1LL << 27) || height != static_cast<int64>(decode.height) || 254 height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) { 255 png::CommonFreeDecode(&decode); 256 OP_REQUIRES(context, false, 257 errors::InvalidArgument("PNG size too large for int: ", 258 decode.width, " by ", decode.height)); 259 } 260 261 // Allocate tensor 262 Tensor* output = nullptr; 263 const auto status = context->allocate_output( 264 0, 265 format_ == kGifFormat ? TensorShape({1, height, width, decode.channels}) 266 : TensorShape({height, width, decode.channels}), 267 &output); 268 if (!status.ok()) png::CommonFreeDecode(&decode); 269 OP_REQUIRES_OK(context, status); 270 271 if (channel_bits_ == 8) { 272 // Finish decoding png 273 OP_REQUIRES( 274 context, 275 png::CommonFinishDecode( 276 reinterpret_cast<png_bytep>(output->flat<uint8>().data()), 277 decode.channels * width * sizeof(uint8), &decode), 278 errors::InvalidArgument("Invalid PNG data, size ", input.size())); 279 } else { 280 // Finish decoding png 281 OP_REQUIRES( 282 context, 283 png::CommonFinishDecode( 284 reinterpret_cast<png_bytep>(output->flat<uint16>().data()), 285 decode.channels * width * sizeof(uint16), &decode), 286 errors::InvalidArgument("Invalid PNG data, size ", input.size())); 287 } 288 } 289 290 void DecodeGif(OpKernelContext* context, StringPiece input) { 291 OP_REQUIRES(context, channels_ == 0 || channels_ == 3, 292 errors::InvalidArgument("channels must be 0 or 3 for GIF, got ", 293 channels_)); 294 295 // Decode GIF, allocating tensor once the size is known. 
296 Tensor* output = nullptr; 297 string error_string; 298 OP_REQUIRES( 299 context, 300 gif::Decode(input.data(), input.size(), 301 [=, &output](int num_frames, int width, int height, 302 int channels) -> uint8* { 303 Status status; 304 if (format_ == kGifFormat) { 305 status = context->allocate_output( 306 0, 307 TensorShape({num_frames, height, width, channels}), 308 &output); 309 } else if (num_frames == 1) { 310 status = context->allocate_output( 311 0, TensorShape({height, width, channels}), &output); 312 } else { 313 status = errors::InvalidArgument( 314 "Got ", num_frames, " frames, but animated gifs ", 315 "can only be decoded by tf.image.decode_gif or ", 316 "tf.image.decode_image"); 317 } 318 if (!status.ok()) { 319 VLOG(1) << status; 320 context->SetStatus(status); 321 return nullptr; 322 } 323 return output->flat<uint8>().data(); 324 }, 325 &error_string), 326 errors::InvalidArgument("Invalid GIF data (size ", input.size(), "), ", 327 error_string)); 328 } 329 330 private: 331 FileFormat format_; 332 int channels_; 333 int channel_bits_ = 8; 334 jpeg::UncompressFlags flags_; 335 }; 336 337 REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp); 338 REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp); 339 REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp); 340 REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU), 341 DecodeImageOp); 342 343 } // namespace 344 } // namespace tensorflow 345