1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/tensorboard/db/schema.h" 17 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" 18 #include "tensorflow/contrib/tensorboard/db/summary_file_writer.h" 19 #include "tensorflow/core/framework/graph.pb.h" 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/framework/resource_mgr.h" 22 #include "tensorflow/core/lib/db/sqlite.h" 23 #include "tensorflow/core/platform/protobuf.h" 24 25 namespace tensorflow { 26 27 REGISTER_KERNEL_BUILDER(Name("SummaryWriter").Device(DEVICE_CPU), 28 ResourceHandleOp<SummaryWriterInterface>); 29 30 class CreateSummaryFileWriterOp : public OpKernel { 31 public: 32 explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx) 33 : OpKernel(ctx) {} 34 35 void Compute(OpKernelContext* ctx) override { 36 const Tensor* tmp; 37 OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp)); 38 const string logdir = tmp->scalar<string>()(); 39 OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp)); 40 const int32 max_queue = tmp->scalar<int32>()(); 41 OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp)); 42 const int32 flush_millis = tmp->scalar<int32>()(); 43 OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp)); 44 const string filename_suffix = tmp->scalar<string>()(); 45 46 SummaryWriterInterface* s = nullptr; 47 OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>( 48 ctx, HandleFromInput(ctx, 0), &s, 49 [max_queue, flush_millis, logdir, filename_suffix, 50 ctx](SummaryWriterInterface** s) { 51 return CreateSummaryFileWriter( 52 max_queue, flush_millis, logdir, 53 filename_suffix, ctx->env(), s); 54 })); 55 } 56 }; 57 REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU), 58 CreateSummaryFileWriterOp); 59 60 class CreateSummaryDbWriterOp : public OpKernel { 61 public: 62 explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 63 64 void Compute(OpKernelContext* ctx) override { 65 const Tensor* tmp; 66 OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp)); 67 const string db_uri = tmp->scalar<string>()(); 68 OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp)); 69 const string experiment_name = tmp->scalar<string>()(); 70 OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp)); 71 const string run_name = tmp->scalar<string>()(); 72 OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp)); 73 const string user_name = tmp->scalar<string>()(); 74 75 SummaryWriterInterface* s = nullptr; 76 OP_REQUIRES_OK( 77 ctx, 78 LookupOrCreateResource<SummaryWriterInterface>( 79 ctx, HandleFromInput(ctx, 0), &s, 80 [db_uri, experiment_name, run_name, user_name, 81 ctx](SummaryWriterInterface** s) { 82 Sqlite* db; 83 TF_RETURN_IF_ERROR(Sqlite::Open( 84 db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db)); 85 core::ScopedUnref unref(db); 86 TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db)); 87 TF_RETURN_IF_ERROR(CreateSummaryDbWriter( 88 db, experiment_name, run_name, user_name, ctx->env(), s)); 89 return Status::OK(); 90 })); 91 } 92 }; 93 REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU), 94 CreateSummaryDbWriterOp); 95 96 class FlushSummaryWriterOp : public OpKernel { 97 public: 98 explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 99 100 void Compute(OpKernelContext* ctx) override { 101 SummaryWriterInterface* s; 102 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 103 core::ScopedUnref unref(s); 104 OP_REQUIRES_OK(ctx, s->Flush()); 105 } 106 }; 107 REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter").Device(DEVICE_CPU), 108 FlushSummaryWriterOp); 109 110 class CloseSummaryWriterOp : public OpKernel { 111 public: 112 explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 113 114 void Compute(OpKernelContext* ctx) override { 115 OP_REQUIRES_OK(ctx, DeleteResource<SummaryWriterInterface>( 116 ctx, HandleFromInput(ctx, 0))); 117 } 118 }; 119 REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter").Device(DEVICE_CPU), 120 CloseSummaryWriterOp); 121 122 class WriteSummaryOp : public OpKernel { 123 public: 124 explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 125 126 void Compute(OpKernelContext* ctx) override { 127 SummaryWriterInterface* s; 128 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 129 core::ScopedUnref unref(s); 130 const Tensor* tmp; 131 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 132 const int64 step = tmp->scalar<int64>()(); 133 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 134 const string& tag = tmp->scalar<string>()(); 135 OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp)); 136 const string& serialized_metadata = tmp->scalar<string>()(); 137 138 const Tensor* t; 139 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 140 141 OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata)); 142 } 143 }; 144 REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU), 145 WriteSummaryOp); 146 147 class ImportEventOp : public OpKernel { 148 public: 149 explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 150 151 void Compute(OpKernelContext* ctx) override { 152 SummaryWriterInterface* s; 153 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 154 core::ScopedUnref unref(s); 155 const Tensor* t; 156 OP_REQUIRES_OK(ctx, ctx->input("event", &t)); 157 std::unique_ptr<Event> event{new Event}; 158 if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) { 159 ctx->CtxFailureWithWarning( 160 errors::DataLoss("Bad tf.Event binary proto tensor string")); 161 return; 162 } 163 OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event))); 164 } 165 }; 166 REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp); 167 168 class WriteScalarSummaryOp : public OpKernel { 169 public: 170 explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 171 172 void Compute(OpKernelContext* ctx) override { 173 SummaryWriterInterface* s; 174 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 175 core::ScopedUnref unref(s); 176 const Tensor* tmp; 177 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 178 const int64 step = tmp->scalar<int64>()(); 179 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 180 const string& tag = tmp->scalar<string>()(); 181 182 const Tensor* t; 183 OP_REQUIRES_OK(ctx, ctx->input("value", &t)); 184 185 OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag)); 186 } 187 }; 188 REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU), 189 WriteScalarSummaryOp); 190 191 class WriteHistogramSummaryOp : public OpKernel { 192 public: 193 explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 194 195 void Compute(OpKernelContext* ctx) override { 196 SummaryWriterInterface* s; 197 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 198 core::ScopedUnref unref(s); 199 const Tensor* tmp; 200 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 201 const int64 step = tmp->scalar<int64>()(); 202 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 203 const string& tag = tmp->scalar<string>()(); 204 205 const Tensor* t; 206 OP_REQUIRES_OK(ctx, ctx->input("values", &t)); 207 208 OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag)); 209 } 210 }; 211 REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU), 212 WriteHistogramSummaryOp); 213 214 class WriteImageSummaryOp : public OpKernel { 215 public: 216 explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 217 int64 max_images_tmp; 218 OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images", &max_images_tmp)); 219 OP_REQUIRES(ctx, max_images_tmp < (1LL << 31), 220 errors::InvalidArgument("max_images must be < 2^31")); 221 max_images_ = static_cast<int32>(max_images_tmp); 222 } 223 224 void Compute(OpKernelContext* ctx) override { 225 SummaryWriterInterface* s; 226 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 227 core::ScopedUnref unref(s); 228 const Tensor* tmp; 229 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 230 const int64 step = tmp->scalar<int64>()(); 231 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 232 const string& tag = tmp->scalar<string>()(); 233 const Tensor* bad_color; 234 OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color)); 235 OP_REQUIRES( 236 ctx, TensorShapeUtils::IsVector(bad_color->shape()), 237 errors::InvalidArgument("bad_color must be a vector, got shape ", 238 bad_color->shape().DebugString())); 239 240 const Tensor* t; 241 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 242 243 OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color)); 244 } 245 246 private: 247 int32 max_images_; 248 }; 249 REGISTER_KERNEL_BUILDER(Name("WriteImageSummary").Device(DEVICE_CPU), 250 WriteImageSummaryOp); 251 252 class WriteAudioSummaryOp : public OpKernel { 253 public: 254 explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 255 OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs", &max_outputs_)); 256 OP_REQUIRES(ctx, max_outputs_ > 0, 257 errors::InvalidArgument("max_outputs must be > 0")); 258 } 259 260 void Compute(OpKernelContext* ctx) override { 261 SummaryWriterInterface* s; 262 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 263 core::ScopedUnref unref(s); 264 const Tensor* tmp; 265 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); 266 const int64 step = tmp->scalar<int64>()(); 267 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); 268 const string& tag = tmp->scalar<string>()(); 269 OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp)); 270 const float sample_rate = tmp->scalar<float>()(); 271 272 const Tensor* t; 273 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 274 275 OP_REQUIRES_OK(ctx, 276 s->WriteAudio(step, *t, tag, max_outputs_, sample_rate)); 277 } 278 279 private: 280 int max_outputs_; 281 }; 282 REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU), 283 WriteAudioSummaryOp); 284 285 class WriteGraphSummaryOp : public OpKernel { 286 public: 287 explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 288 289 void Compute(OpKernelContext* ctx) override { 290 SummaryWriterInterface* s; 291 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); 292 core::ScopedUnref unref(s); 293 const Tensor* t; 294 OP_REQUIRES_OK(ctx, ctx->input("step", &t)); 295 const int64 step = t->scalar<int64>()(); 296 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); 297 std::unique_ptr<GraphDef> graph{new GraphDef}; 298 if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) { 299 ctx->CtxFailureWithWarning( 300 errors::DataLoss("Bad tf.GraphDef binary proto tensor string")); 301 return; 302 } 303 OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph))); 304 } 305 }; 306 REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU), 307 WriteGraphSummaryOp); 308 309 } // namespace tensorflow 310