Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
     17 // inputs or outputs in various ways.
     18 
     19 // See docs in ../ops/summary_ops.cc.
     20 
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/summary.pb.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/lib/png/png_io.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 
     27 namespace tensorflow {
     28 
     29 class SummaryImageOp : public OpKernel {
     30  public:
     31   typedef Eigen::Tensor<uint8, 2, Eigen::RowMajor> Uint8Image;
     32 
     33   explicit SummaryImageOp(OpKernelConstruction* context) : OpKernel(context) {
     34     int64 max_images_tmp;
     35     OP_REQUIRES_OK(context, context->GetAttr("max_images", &max_images_tmp));
     36     OP_REQUIRES(context, max_images_tmp < (1LL << 31),
     37                 errors::InvalidArgument("max_images must be < 2^31"));
     38     max_images_ = static_cast<int32>(max_images_tmp);
     39     const TensorProto* proto;
     40     OP_REQUIRES_OK(context, context->GetAttr("bad_color", &proto));
     41     OP_REQUIRES_OK(context, context->device()->MakeTensorFromProto(
     42                                 *proto, AllocatorAttributes(), &bad_color_));
     43     OP_REQUIRES(context, bad_color_.dtype() == DT_UINT8,
     44                 errors::InvalidArgument("bad_color must be uint8, got ",
     45                                         DataTypeString(bad_color_.dtype())));
     46     OP_REQUIRES(
     47         context, TensorShapeUtils::IsVector(bad_color_.shape()),
     48         errors::InvalidArgument("bad_color must be a vector, got shape ",
     49                                 bad_color_.shape().DebugString()));
     50   }
     51 
     52   void Compute(OpKernelContext* c) override {
     53     const Tensor& tags = c->input(0);
     54     const Tensor& tensor = c->input(1);
     55     OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
     56                 errors::InvalidArgument("Tags must be a scalar"));
     57     OP_REQUIRES(c,
     58                 tensor.dims() == 4 &&
     59                     (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
     60                      tensor.dim_size(3) == 4),
     61                 errors::InvalidArgument(
     62                     "Tensor must be 4-D with last dim 1, 3, or 4, not ",
     63                     tensor.shape().DebugString()));
     64     const string& base_tag = tags.scalar<string>()();
     65 
     66     OP_REQUIRES(c,
     67                 tensor.dim_size(0) < (1LL << 31) &&
     68                     tensor.dim_size(1) < (1LL << 31) &&
     69                     tensor.dim_size(2) < (1LL << 31) &&
     70                     (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29),
     71                 errors::InvalidArgument("Tensor too large for summary ",
     72                                         tensor.shape().DebugString()));
     73 
     74     // The casts and h * w cannot overflow because of the limits above.
     75     const int batch_size = static_cast<int>(tensor.dim_size(0));
     76     const int h = static_cast<int>(tensor.dim_size(1));
     77     const int w = static_cast<int>(tensor.dim_size(2));
     78     const int hw = h * w;  // Compact these two dims for simplicity
     79     const int depth = static_cast<int>(tensor.dim_size(3));
     80 
     81     Summary s;
     82     if (tensor.dtype() == DT_UINT8) {
     83       // For uint8 input, no normalization is necessary
     84       auto ith_image = [&tensor, batch_size, hw, depth](int i) {
     85         auto values = tensor.shaped<uint8, 3>({batch_size, hw, depth});
     86         return typename TTypes<uint8>::ConstMatrix(
     87             &values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
     88       };
     89       OP_REQUIRES_OK(
     90           c, AddImages(base_tag, batch_size, w, h, depth, ith_image, &s));
     91     } else if (tensor.dtype() == DT_HALF) {
     92       NormalizeAndAddImages<Eigen::half>(c, tensor, h, w, hw, depth, batch_size,
     93                                          base_tag, &s);
     94     } else if (tensor.dtype() == DT_FLOAT) {
     95       NormalizeAndAddImages<float>(c, tensor, h, w, hw, depth, batch_size,
     96                                    base_tag, &s);
     97     } else {  // tensor.dtype() = DT_DOUBLE
     98       NormalizeAndAddImages<double>(c, tensor, h, w, hw, depth, batch_size,
     99                                     base_tag, &s);
    100     }
    101 
    102     Tensor* summary_tensor = nullptr;
    103     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
    104     CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
    105   }
    106 
    107   template <class T>
    108   void NormalizeAndAddImages(OpKernelContext* c, const Tensor& tensor, int h,
    109                              int w, int hw, int depth, int batch_size,
    110                              const string& base_tag, Summary* s) {
    111     // For float and half images, nans and infs are replaced with bad_color.
    112     OP_REQUIRES(c, bad_color_.dim_size(0) >= depth,
    113                 errors::InvalidArgument(
    114                     "expected depth <= bad_color.size, got depth = ", depth,
    115                     ", bad_color.size = ", bad_color_.dim_size(0)));
    116     auto bad_color_full = bad_color_.vec<uint8>();
    117     typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth);
    118 
    119     // Float images must be scaled and translated.
    120     Uint8Image image(hw, depth);
    121     auto ith_image = [&tensor, &image, bad_color, batch_size, hw,
    122                       depth](int i) {
    123       auto tensor_eigen = tensor.template shaped<T, 3>({batch_size, hw, depth});
    124       typename TTypes<T>::ConstMatrix values(
    125           &tensor_eigen(i, 0, 0),
    126           Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
    127       NormalizeFloatImage<T>(hw, depth, values, bad_color, &image);
    128       return image;
    129     };
    130     OP_REQUIRES_OK(c,
    131                    AddImages(base_tag, batch_size, w, h, depth, ith_image, s));
    132   }
    133 
    134   // Add the sequence of images specified by ith_image to the summary.
    135   //
    136   // Factoring this loop out into a helper function lets ith_image behave
    137   // differently in the float and uint8 cases: the float case needs a temporary
    138   // buffer which can be shared across calls to ith_image, but the uint8 case
    139   // does not.
    140   Status AddImages(const string& tag, int batch_size, int w, int h, int depth,
    141                    const std::function<Uint8Image(int)>& ith_image,
    142                    Summary* s) {
    143     const int N = std::min<int>(max_images_, batch_size);
    144     for (int i = 0; i < N; ++i) {
    145       Summary::Value* v = s->add_value();
    146       // The tag depends on the number of requested images (not the number
    147       // produced.)
    148       //
    149       // Note that later on avisu uses "/" to figure out a consistent naming
    150       // convention for display, so we append "/image" to guarantee that the
    151       // image(s) won't be displayed in the global scope with no name.
    152       if (max_images_ > 1) {
    153         v->set_tag(strings::StrCat(tag, "/image/", i));
    154       } else {
    155         v->set_tag(strings::StrCat(tag, "/image"));
    156       }
    157 
    158       auto image = ith_image(i);
    159       Summary::Image* si = v->mutable_image();
    160       si->set_height(h);
    161       si->set_width(w);
    162       si->set_colorspace(depth);
    163       const int channel_bits = 8;
    164       const int compression = -1;  // Use zlib default
    165       if (!png::WriteImageToBuffer(
    166               image.data(), w, h, w * depth, depth, channel_bits, compression,
    167               si->mutable_encoded_image_string(), nullptr)) {
    168         return errors::Internal("PNG encoding failed");
    169       }
    170     }
    171     return Status::OK();
    172   }
    173 
    174   template <class T>
    175   static void NormalizeFloatImage(int hw, int depth,
    176                                   typename TTypes<T>::ConstMatrix values,
    177                                   typename TTypes<uint8>::ConstVec bad_color,
    178                                   Uint8Image* image) {
    179     if (!image->size()) return;  // Nothing to do for empty images
    180 
    181     // Rescale the image to uint8 range.
    182     //
    183     // We are trying to generate an RGB image from a float/half tensor.  We do
    184     // not have any info about the expected range of values in the tensor
    185     // but the generated image needs to have all RGB values within [0, 255].
    186     //
    187     // We use two different algorithms to generate these values.  If the
    188     // tensor has only positive values we scale them all by 255/max(values).
    189     // If the tensor has both negative and positive values we scale them by
    190     // the max of their absolute values and center them around 127.
    191     //
    192     // This works for most cases, but does not respect the relative dynamic
    193     // range across different instances of the tensor.
    194 
    195     // Compute min and max ignoring nonfinite pixels
    196     float image_min = std::numeric_limits<float>::infinity();
    197     float image_max = -image_min;
    198     for (int i = 0; i < hw; i++) {
    199       bool finite = true;
    200       for (int j = 0; j < depth; j++) {
    201         if (!Eigen::numext::isfinite(values(i, j))) {
    202           finite = false;
    203           break;
    204         }
    205       }
    206       if (finite) {
    207         for (int j = 0; j < depth; j++) {
    208           float value(values(i, j));
    209           image_min = std::min(image_min, value);
    210           image_max = std::max(image_max, value);
    211         }
    212       }
    213     }
    214 
    215     // Pick an affine transform into uint8
    216     const float kZeroThreshold = 1e-6;
    217     T scale, offset;
    218     if (image_min < 0) {
    219       float max_val = std::max(std::abs(image_min), std::abs(image_max));
    220       scale = T(max_val < kZeroThreshold ? 0.0f : 127.0f / max_val);
    221       offset = T(128.0f);
    222     } else {
    223       scale = T(image_max < kZeroThreshold ? 0.0f : 255.0f / image_max);
    224       offset = T(0.0f);
    225     }
    226 
    227     // Transform image, turning nonfinite values to bad_color
    228     for (int i = 0; i < hw; i++) {
    229       bool finite = true;
    230       for (int j = 0; j < depth; j++) {
    231         if (!Eigen::numext::isfinite(values(i, j))) {
    232           finite = false;
    233           break;
    234         }
    235       }
    236       if (finite) {
    237         image->chip<0>(i) = (values.template chip<0>(i) * scale + offset)
    238                                 .template cast<uint8>();
    239       } else {
    240         image->chip<0>(i) = bad_color;
    241       }
    242     }
    243   }
    244 
    245  private:
    246   int32 max_images_;
    247   Tensor bad_color_;
    248 };
    249 
    250 REGISTER_KERNEL_BUILDER(Name("ImageSummary").Device(DEVICE_CPU),
    251                         SummaryImageOp);
    252 
    253 }  // namespace tensorflow
    254