Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/image_ops.cc
     17 
#include <limits>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
#include "tensorflow/core/platform/logging.h"
     28 
     29 namespace tensorflow {
     30 
     31 // Encode an image to a JPEG stream
     32 class EncodeJpegOp : public OpKernel {
     33  public:
     34   explicit EncodeJpegOp(OpKernelConstruction* context) : OpKernel(context) {
     35     OP_REQUIRES_OK(context, context->GetAttr("format", &format_));
     36     if (format_.empty()) {
     37       flags_.format = static_cast<jpeg::Format>(0);
     38     } else if (format_ == "grayscale") {
     39       flags_.format = jpeg::FORMAT_GRAYSCALE;
     40     } else if (format_ == "rgb") {
     41       flags_.format = jpeg::FORMAT_RGB;
     42     } else {
     43       OP_REQUIRES(context, false,
     44                   errors::InvalidArgument(
     45                       "format must be '', grayscale or rgb, got ", format_));
     46     }
     47 
     48     OP_REQUIRES_OK(context, context->GetAttr("quality", &flags_.quality));
     49     OP_REQUIRES(context, 0 <= flags_.quality && flags_.quality <= 100,
     50                 errors::InvalidArgument("quality must be in [0,100], got ",
     51                                         flags_.quality));
     52     OP_REQUIRES_OK(context,
     53                    context->GetAttr("progressive", &flags_.progressive));
     54     OP_REQUIRES_OK(
     55         context, context->GetAttr("optimize_size", &flags_.optimize_jpeg_size));
     56     OP_REQUIRES_OK(context, context->GetAttr("chroma_downsampling",
     57                                              &flags_.chroma_downsampling));
     58 
     59     string density_unit;
     60     OP_REQUIRES_OK(context, context->GetAttr("density_unit", &density_unit));
     61     if (density_unit == "in") {
     62       flags_.density_unit = 1;
     63     } else if (density_unit == "cm") {
     64       flags_.density_unit = 2;
     65     } else {
     66       OP_REQUIRES(context, false,
     67                   errors::InvalidArgument("density_unit must be 'in' or 'cm'",
     68                                           density_unit));
     69     }
     70 
     71     OP_REQUIRES_OK(context, context->GetAttr("x_density", &flags_.x_density));
     72     OP_REQUIRES_OK(context, context->GetAttr("y_density", &flags_.y_density));
     73     OP_REQUIRES_OK(context, context->GetAttr("xmp_metadata", &xmp_metadata_));
     74     flags_.xmp_metadata = xmp_metadata_;  // StringPiece doesn't own data
     75   }
     76 
     77   void Compute(OpKernelContext* context) override {
     78     const Tensor& image = context->input(0);
     79     OP_REQUIRES(context, image.dims() == 3,
     80                 errors::InvalidArgument("image must be 3-dimensional",
     81                                         image.shape().DebugString()));
     82 
     83     OP_REQUIRES(
     84         context,
     85         FastBoundsCheck(image.NumElements(), std::numeric_limits<int32>::max()),
     86         errors::InvalidArgument(
     87             "Cannot encode images with >= max int32 elements"));
     88 
     89     const int32 dim_size0 = static_cast<int32>(image.dim_size(0));
     90     const int32 dim_size1 = static_cast<int32>(image.dim_size(1));
     91     const int32 dim_size2 = static_cast<int32>(image.dim_size(2));
     92 
     93     // Autodetect format if desired, otherwise make sure format and
     94     // image channels are consistent.
     95     int channels;
     96     jpeg::CompressFlags adjusted_flags = flags_;
     97     if (flags_.format == 0) {
     98       channels = dim_size2;
     99       if (channels == 1) {
    100         adjusted_flags.format = jpeg::FORMAT_GRAYSCALE;
    101       } else if (channels == 3) {
    102         adjusted_flags.format = jpeg::FORMAT_RGB;
    103       } else {
    104         OP_REQUIRES(
    105             context, false,
    106             errors::InvalidArgument("image must have 1 or 3 channels, got ",
    107                                     image.shape().DebugString()));
    108       }
    109     } else {
    110       if (flags_.format == jpeg::FORMAT_GRAYSCALE) {
    111         channels = 1;
    112       } else {  // RGB
    113         channels = 3;
    114       }
    115       OP_REQUIRES(context, channels == dim_size2,
    116                   errors::InvalidArgument("format ", format_, " expects ",
    117                                           channels, " channels, got ",
    118                                           image.shape().DebugString()));
    119     }
    120 
    121     // Encode image to jpeg string
    122     Tensor* output = nullptr;
    123     OP_REQUIRES_OK(context,
    124                    context->allocate_output(0, TensorShape({}), &output));
    125     OP_REQUIRES(context,
    126                 jpeg::Compress(image.flat<uint8>().data(), dim_size1, dim_size0,
    127                                adjusted_flags, &output->scalar<string>()()),
    128                 errors::Internal("JPEG encoding failed"));
    129   }
    130 
    131  private:
    132   string format_;
    133   string xmp_metadata_;  // Owns data referenced by flags_
    134   jpeg::CompressFlags flags_;
    135 };
    136 REGISTER_KERNEL_BUILDER(Name("EncodeJpeg").Device(DEVICE_CPU), EncodeJpegOp);
    137 
    138 }  // namespace tensorflow
    139