1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/graph.pb.h" 17 #include "tensorflow/core/framework/op_kernel.h" 18 #include "tensorflow/core/framework/resource_mgr.h" 19 #include "tensorflow/core/lib/db/sqlite.h" 20 #include "tensorflow/core/platform/protobuf.h" 21 #include "tensorflow/core/summary/schema.h" 22 #include "tensorflow/core/summary/summary_db_writer.h" 23 #include "tensorflow/core/summary/summary_file_writer.h" 24 #include "tensorflow/core/util/event.pb.h" 25 26 namespace tensorflow { 27 28 REGISTER_KERNEL_BUILDER(Name("SummaryWriter").Device(DEVICE_CPU), 29 ResourceHandleOp<SummaryWriterInterface>); 30 31 class CreateSummaryFileWriterOp : public OpKernel { 32 public: 33 explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx) 34 : OpKernel(ctx) {} 35 36 void Compute(OpKernelContext* ctx) override { 37 const Tensor* tmp; 38 OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp)); 39 const string logdir = tmp->scalar<string>()(); 40 OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp)); 41 const int32 max_queue = tmp->scalar<int32>()(); 42 OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp)); 43 const int32 flush_millis = tmp->scalar<int32>()(); 44 OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp)); 45 const string filename_suffix = tmp->scalar<string>()(); 46 47 SummaryWriterInterface* s = nullptr; 48 OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>( 49 ctx, HandleFromInput(ctx, 0), &s, 50 [max_queue, flush_millis, logdir, filename_suffix, 51 ctx](SummaryWriterInterface** s) { 52 return CreateSummaryFileWriter( 53 max_queue, flush_millis, logdir, 54 filename_suffix, ctx->env(), s); 55 })); 56 core::ScopedUnref unref(s); 57 } 58 }; 59 REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU), 60 CreateSummaryFileWriterOp); 61 62 class CreateSummaryDbWriterOp : public OpKernel { 63 public: 64 explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 65 66 void Compute(OpKernelContext* ctx) override { 67 const Tensor* tmp; 68 OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp)); 69 const string db_uri = tmp->scalar<string>()(); 70 OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp)); 71 const string experiment_name = tmp->scalar<string>()(); 72 OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp)); 73 const string run_name = tmp->scalar<string>()(); 74 OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp)); 75 const string user_name = tmp->scalar<string>()(); 76 77 SummaryWriterInterface* s = nullptr; 78 OP_REQUIRES_OK( 79 ctx, 80 LookupOrCreateResource<SummaryWriterInterface>( 81 ctx, HandleFromInput(ctx, 0), &s, 82 [db_uri, experiment_name, run_name, user_name, 83 ctx](SummaryWriterInterface** s) { 84 Sqlite* db; 85 TF_RETURN_IF_ERROR(Sqlite::Open( 86 db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db)); 87 core::ScopedUnref unref(db); 88 TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db)); 89 TF_RETURN_IF_ERROR(CreateSummaryDbWriter( 90 db, experiment_name, run_name, user_name, ctx->env(), s)); 91 return Status::OK(); 92 })); 93 core::ScopedUnref unref(s); 94 } 95 }; 96 REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU), 97 CreateSummaryDbWriterOp); 98 99 class FlushSummaryWriterOp : public OpKernel { 100 public: 101 explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 102 103 void Compute(OpKernelContext* ctx) override { 104 SummaryWriterInterface* s; 105 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 106 core::ScopedUnref unref(s); 107 OP_REQUIRES_OK(ctx, s->Flush()); 108 } 109 }; 110 REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter").Device(DEVICE_CPU), 111 FlushSummaryWriterOp); 112 113 class CloseSummaryWriterOp : public OpKernel { 114 public: 115 explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 116 117 void Compute(OpKernelContext* ctx) override { 118 OP_REQUIRES_OK(ctx, DeleteResource<SummaryWriterInterface>( 119 ctx, HandleFromInput(ctx, 0))); 120 } 121 }; 122 REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter").Device(DEVICE_CPU), 123 CloseSummaryWriterOp); 124 125 class WriteSummaryOp : public OpKernel { 126 public: 127 explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 128 129 void Compute(OpKernelContext* ctx) override { 130 SummaryWriterInterface* s; 131 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 132 core::ScopedUnref unref(s); 133 const Tensor* tmp; 134 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 135 const int64 step = tmp->scalar<int64>()(); 136 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 137 const string& tag = tmp->scalar<string>()(); 138 OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp)); 139 const string& serialized_metadata = tmp->scalar<string>()(); 140 141 const Tensor* t; 142 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 143 144 OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata)); 145 } 146 }; 147 REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU), 148 WriteSummaryOp); 149 150 class ImportEventOp : public OpKernel { 151 public: 152 explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 153 154 void Compute(OpKernelContext* ctx) override { 155 SummaryWriterInterface* s; 156 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 157 core::ScopedUnref unref(s); 158 const Tensor* t; 159 OP_REQUIRES_OK(ctx, ctx->input("event", &t)); 160 std::unique_ptr<Event> event{new Event}; 161 if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) { 162 ctx->CtxFailureWithWarning( 163 errors::DataLoss("Bad tf.Event binary proto tensor string")); 164 return; 165 } 166 OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event))); 167 } 168 }; 169 REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp); 170 171 class WriteScalarSummaryOp : public OpKernel { 172 public: 173 explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 174 175 void Compute(OpKernelContext* ctx) override { 176 SummaryWriterInterface* s; 177 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 178 core::ScopedUnref unref(s); 179 const Tensor* tmp; 180 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 181 const int64 step = tmp->scalar<int64>()(); 182 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 183 const string& tag = tmp->scalar<string>()(); 184 185 const Tensor* t; 186 OP_REQUIRES_OK(ctx, ctx->input("value", &t)); 187 188 OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag)); 189 } 190 }; 191 REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU), 192 WriteScalarSummaryOp); 193 194 class WriteHistogramSummaryOp : public OpKernel { 195 public: 196 explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 197 198 void Compute(OpKernelContext* ctx) override { 199 SummaryWriterInterface* s; 200 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 201 core::ScopedUnref unref(s); 202 const Tensor* tmp; 203 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 204 const int64 step = tmp->scalar<int64>()(); 205 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 206 const string& tag = tmp->scalar<string>()(); 207 208 const Tensor* t; 209 OP_REQUIRES_OK(ctx, ctx->input("values", &t)); 210 211 OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag)); 212 } 213 }; 214 REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU), 215 WriteHistogramSummaryOp); 216 217 class WriteImageSummaryOp : public OpKernel { 218 public: 219 explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 220 int64 max_images_tmp; 221 OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images", &max_images_tmp)); 222 OP_REQUIRES(ctx, max_images_tmp < (1LL << 31), 223 errors::InvalidArgument("max_images must be < 2^31")); 224 max_images_ = static_cast<int32>(max_images_tmp); 225 } 226 227 void Compute(OpKernelContext* ctx) override { 228 SummaryWriterInterface* s; 229 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 230 core::ScopedUnref unref(s); 231 const Tensor* tmp; 232 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 233 const int64 step = tmp->scalar<int64>()(); 234 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 235 const string& tag = tmp->scalar<string>()(); 236 const Tensor* bad_color; 237 OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color)); 238 OP_REQUIRES( 239 ctx, TensorShapeUtils::IsVector(bad_color->shape()), 240 errors::InvalidArgument("bad_color must be a vector, got shape ", 241 bad_color->shape().DebugString())); 242 243 const Tensor* t; 244 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 245 246 OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color)); 247 } 248 249 private: 250 int32 max_images_; 251 }; 252 REGISTER_KERNEL_BUILDER(Name("WriteImageSummary").Device(DEVICE_CPU), 253 WriteImageSummaryOp); 254 255 class WriteAudioSummaryOp : public OpKernel { 256 public: 257 explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 258 OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs", &max_outputs_)); 259 OP_REQUIRES(ctx, max_outputs_ > 0, 260 errors::InvalidArgument("max_outputs must be > 0")); 261 } 262 263 void Compute(OpKernelContext* ctx) override { 264 SummaryWriterInterface* s; 265 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 266 core::ScopedUnref unref(s); 267 const Tensor* tmp; 268 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 269 const int64 step = tmp->scalar<int64>()(); 270 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 271 const string& tag = tmp->scalar<string>()(); 272 OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp)); 273 const float sample_rate = tmp->scalar<float>()(); 274 275 const Tensor* t; 276 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 277 278 OP_REQUIRES_OK(ctx, 279 s->WriteAudio(step, *t, tag, max_outputs_, sample_rate)); 280 } 281 282 private: 283 int max_outputs_; 284 }; 285 REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU), 286 WriteAudioSummaryOp); 287 288 class WriteGraphSummaryOp : public OpKernel { 289 public: 290 explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 291 292 void Compute(OpKernelContext* ctx) override { 293 SummaryWriterInterface* s; 294 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 295 core::ScopedUnref unref(s); 296 const Tensor* t; 297 OP_REQUIRES_OK(ctx, ctx->input("step", &t)); 298 const int64 step = t->scalar<int64>()(); 299 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 300 std::unique_ptr<GraphDef> graph{new GraphDef}; 301 if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) { 302 ctx->CtxFailureWithWarning( 303 errors::DataLoss("Bad tf.GraphDef binary proto tensor string")); 304 return; 305 } 306 OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph))); 307 } 308 }; 309 REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU), 310 WriteGraphSummaryOp); 311 312 } // namespace tensorflow 313