1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as 17 // inputs or outputs in various ways. 18 19 // See docs in ../ops/summary_ops.cc. 20 21 #include <unordered_set> 22 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/resource_mgr.h" 26 #include "tensorflow/core/framework/summary.pb.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/core/lib/histogram/histogram.h" 29 #include "tensorflow/core/platform/logging.h" 30 #include "tensorflow/core/platform/protobuf.h" 31 32 namespace tensorflow { 33 34 template <typename T> 35 class SummaryScalarOp : public OpKernel { 36 public: 37 explicit SummaryScalarOp(OpKernelConstruction* context) : OpKernel(context) {} 38 39 void Compute(OpKernelContext* c) override { 40 const Tensor& tags = c->input(0); 41 const Tensor& values = c->input(1); 42 43 OP_REQUIRES( 44 c, 45 tags.IsSameSize(values) || 46 (IsLegacyScalar(tags.shape()) && IsLegacyScalar(values.shape())), 47 errors::InvalidArgument( 48 "tags and values not the same shape: ", tags.shape().DebugString(), 49 " != ", values.shape().DebugString(), SingleTag(tags))); 50 auto Ttags = tags.flat<string>(); 51 auto Tvalues = values.flat<T>(); 52 Summary s; 53 for (int i = 0; i < Ttags.size(); i++) { 54 Summary::Value* v = s.add_value(); 55 v->set_tag(Ttags(i)); 56 v->set_simple_value(float(Tvalues(i))); 57 } 58 59 Tensor* summary_tensor = nullptr; 60 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); 61 CHECK(s.SerializeToString(&summary_tensor->scalar<string>()())); 62 } 63 64 // If there's only one tag, include it in the error message 65 static string SingleTag(const Tensor& tags) { 66 if (tags.NumElements() == 1) { 67 return strings::StrCat(" (tag '", tags.flat<string>()(0), "')"); 68 } else { 69 return ""; 70 } 71 } 72 }; 73 74 template <typename T> 75 class SummaryHistoOp : public OpKernel { 76 public: 77 // SummaryHistoOp could be extended to take a list of custom bucket 78 // boundaries as an option. 79 explicit SummaryHistoOp(OpKernelConstruction* context) : OpKernel(context) {} 80 81 void Compute(OpKernelContext* c) override { 82 const Tensor& tags = c->input(0); 83 const Tensor& values = c->input(1); 84 const auto flat = values.flat<T>(); 85 OP_REQUIRES(c, IsLegacyScalar(tags.shape()), 86 errors::InvalidArgument("tags must be scalar")); 87 // Build histogram of values in "values" tensor 88 histogram::Histogram histo; 89 for (int64 i = 0; i < flat.size(); i++) { 90 const double double_val = static_cast<double>(flat(i)); 91 if (Eigen::numext::isnan(double_val)) { 92 c->SetStatus( 93 errors::InvalidArgument("Nan in summary histogram for: ", name())); 94 break; 95 } else if (Eigen::numext::isinf(double_val)) { 96 c->SetStatus(errors::InvalidArgument( 97 "Infinity in summary histogram for: ", name())); 98 break; 99 } 100 histo.Add(double_val); 101 } 102 103 Summary s; 104 Summary::Value* v = s.add_value(); 105 v->set_tag(tags.scalar<string>()()); 106 histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); 107 108 Tensor* summary_tensor = nullptr; 109 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); 110 CHECK(s.SerializeToString(&summary_tensor->scalar<string>()())); 111 } 112 }; 113 114 #define REGISTER(T) \ 115 REGISTER_KERNEL_BUILDER( \ 116 Name("ScalarSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 117 SummaryScalarOp<T>); \ 118 REGISTER_KERNEL_BUILDER( \ 119 Name("HistogramSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 120 SummaryHistoOp<T>); 121 TF_CALL_REAL_NUMBER_TYPES(REGISTER) 122 #undef REGISTER 123 124 struct HistogramResource : public ResourceBase { 125 histogram::ThreadSafeHistogram histogram; 126 127 string DebugString() override { return "A histogram summary. Stats ..."; } 128 }; 129 130 class SummaryMergeOp : public OpKernel { 131 public: 132 explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {} 133 134 void Compute(OpKernelContext* c) override { 135 Summary s; 136 std::unordered_set<string> tags; 137 for (int input_num = 0; input_num < c->num_inputs(); input_num++) { 138 const Tensor& in = c->input(input_num); 139 auto in_vec = in.flat<string>(); 140 for (int i = 0; i < in_vec.dimension(0); i++) { 141 const string& s_in = in_vec(i); 142 Summary summary_in; 143 if (!ParseProtoUnlimited(&summary_in, s_in)) { 144 c->SetStatus(errors::InvalidArgument( 145 "Could not parse one of the summary inputs")); 146 return; 147 } 148 149 for (int v = 0; v < summary_in.value_size(); v++) { 150 const string& tag = summary_in.value(v).tag(); 151 // The tag is unused by the TensorSummary op, so no need to check 152 // for duplicates. 153 if ((!tag.empty()) && !tags.insert(tag).second) { 154 c->SetStatus(errors::InvalidArgument(strings::StrCat( 155 "Duplicate tag ", tag, " found in summary inputs"))); 156 return; 157 } 158 *s.add_value() = summary_in.value(v); 159 } 160 } 161 } 162 163 Tensor* summary_tensor = nullptr; 164 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); 165 CHECK(s.SerializeToString(&summary_tensor->scalar<string>()())); 166 } 167 }; 168 169 REGISTER_KERNEL_BUILDER(Name("MergeSummary").Device(DEVICE_CPU), 170 SummaryMergeOp); 171 172 } // namespace tensorflow 173