Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/tensorboard/db/schema.h"
     17 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
     18 #include "tensorflow/contrib/tensorboard/db/summary_file_writer.h"
     19 #include "tensorflow/core/framework/graph.pb.h"
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/resource_mgr.h"
     22 #include "tensorflow/core/lib/db/sqlite.h"
     23 #include "tensorflow/core/platform/protobuf.h"
     24 
     25 namespace tensorflow {
     26 
     27 REGISTER_KERNEL_BUILDER(Name("SummaryWriter").Device(DEVICE_CPU),
     28                         ResourceHandleOp<SummaryWriterInterface>);
     29 
     30 class CreateSummaryFileWriterOp : public OpKernel {
     31  public:
     32   explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx)
     33       : OpKernel(ctx) {}
     34 
     35   void Compute(OpKernelContext* ctx) override {
     36     const Tensor* tmp;
     37     OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp));
     38     const string logdir = tmp->scalar<string>()();
     39     OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp));
     40     const int32 max_queue = tmp->scalar<int32>()();
     41     OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp));
     42     const int32 flush_millis = tmp->scalar<int32>()();
     43     OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
     44     const string filename_suffix = tmp->scalar<string>()();
     45 
     46     SummaryWriterInterface* s = nullptr;
     47     OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
     48                             ctx, HandleFromInput(ctx, 0), &s,
     49                             [max_queue, flush_millis, logdir, filename_suffix,
     50                              ctx](SummaryWriterInterface** s) {
     51                               return CreateSummaryFileWriter(
     52                                   max_queue, flush_millis, logdir,
     53                                   filename_suffix, ctx->env(), s);
     54                             }));
     55   }
     56 };
     57 REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
     58                         CreateSummaryFileWriterOp);
     59 
     60 class CreateSummaryDbWriterOp : public OpKernel {
     61  public:
     62   explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
     63 
     64   void Compute(OpKernelContext* ctx) override {
     65     const Tensor* tmp;
     66     OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp));
     67     const string db_uri = tmp->scalar<string>()();
     68     OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp));
     69     const string experiment_name = tmp->scalar<string>()();
     70     OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp));
     71     const string run_name = tmp->scalar<string>()();
     72     OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
     73     const string user_name = tmp->scalar<string>()();
     74 
     75     SummaryWriterInterface* s = nullptr;
     76     OP_REQUIRES_OK(
     77         ctx,
     78         LookupOrCreateResource<SummaryWriterInterface>(
     79             ctx, HandleFromInput(ctx, 0), &s,
     80             [db_uri, experiment_name, run_name, user_name,
     81              ctx](SummaryWriterInterface** s) {
     82               Sqlite* db;
     83               TF_RETURN_IF_ERROR(Sqlite::Open(
     84                   db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db));
     85               core::ScopedUnref unref(db);
     86               TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
     87               TF_RETURN_IF_ERROR(CreateSummaryDbWriter(
     88                   db, experiment_name, run_name, user_name, ctx->env(), s));
     89               return Status::OK();
     90             }));
     91   }
     92 };
     93 REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
     94                         CreateSummaryDbWriterOp);
     95 
     96 class FlushSummaryWriterOp : public OpKernel {
     97  public:
     98   explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
     99 
    100   void Compute(OpKernelContext* ctx) override {
    101     SummaryWriterInterface* s;
    102     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    103     core::ScopedUnref unref(s);
    104     OP_REQUIRES_OK(ctx, s->Flush());
    105   }
    106 };
    107 REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter").Device(DEVICE_CPU),
    108                         FlushSummaryWriterOp);
    109 
    110 class CloseSummaryWriterOp : public OpKernel {
    111  public:
    112   explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    113 
    114   void Compute(OpKernelContext* ctx) override {
    115     OP_REQUIRES_OK(ctx, DeleteResource<SummaryWriterInterface>(
    116                             ctx, HandleFromInput(ctx, 0)));
    117   }
    118 };
    119 REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter").Device(DEVICE_CPU),
    120                         CloseSummaryWriterOp);
    121 
    122 class WriteSummaryOp : public OpKernel {
    123  public:
    124   explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    125 
    126   void Compute(OpKernelContext* ctx) override {
    127     SummaryWriterInterface* s;
    128     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    129     core::ScopedUnref unref(s);
    130     const Tensor* tmp;
    131     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    132     const int64 step = tmp->scalar<int64>()();
    133     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    134     const string& tag = tmp->scalar<string>()();
    135     OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp));
    136     const string& serialized_metadata = tmp->scalar<string>()();
    137 
    138     const Tensor* t;
    139     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    140 
    141     OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata));
    142   }
    143 };
    144 REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
    145                         WriteSummaryOp);
    146 
    147 class ImportEventOp : public OpKernel {
    148  public:
    149   explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    150 
    151   void Compute(OpKernelContext* ctx) override {
    152     SummaryWriterInterface* s;
    153     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    154     core::ScopedUnref unref(s);
    155     const Tensor* t;
    156     OP_REQUIRES_OK(ctx, ctx->input("event", &t));
    157     std::unique_ptr<Event> event{new Event};
    158     if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) {
    159       ctx->CtxFailureWithWarning(
    160           errors::DataLoss("Bad tf.Event binary proto tensor string"));
    161       return;
    162     }
    163     OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
    164   }
    165 };
    166 REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp);
    167 
    168 class WriteScalarSummaryOp : public OpKernel {
    169  public:
    170   explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    171 
    172   void Compute(OpKernelContext* ctx) override {
    173     SummaryWriterInterface* s;
    174     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    175     core::ScopedUnref unref(s);
    176     const Tensor* tmp;
    177     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    178     const int64 step = tmp->scalar<int64>()();
    179     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    180     const string& tag = tmp->scalar<string>()();
    181 
    182     const Tensor* t;
    183     OP_REQUIRES_OK(ctx, ctx->input("value", &t));
    184 
    185     OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag));
    186   }
    187 };
    188 REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU),
    189                         WriteScalarSummaryOp);
    190 
    191 class WriteHistogramSummaryOp : public OpKernel {
    192  public:
    193   explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    194 
    195   void Compute(OpKernelContext* ctx) override {
    196     SummaryWriterInterface* s;
    197     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    198     core::ScopedUnref unref(s);
    199     const Tensor* tmp;
    200     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    201     const int64 step = tmp->scalar<int64>()();
    202     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    203     const string& tag = tmp->scalar<string>()();
    204 
    205     const Tensor* t;
    206     OP_REQUIRES_OK(ctx, ctx->input("values", &t));
    207 
    208     OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag));
    209   }
    210 };
    211 REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU),
    212                         WriteHistogramSummaryOp);
    213 
    214 class WriteImageSummaryOp : public OpKernel {
    215  public:
    216   explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    217     int64 max_images_tmp;
    218     OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images", &max_images_tmp));
    219     OP_REQUIRES(ctx, max_images_tmp < (1LL << 31),
    220                 errors::InvalidArgument("max_images must be < 2^31"));
    221     max_images_ = static_cast<int32>(max_images_tmp);
    222   }
    223 
    224   void Compute(OpKernelContext* ctx) override {
    225     SummaryWriterInterface* s;
    226     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    227     core::ScopedUnref unref(s);
    228     const Tensor* tmp;
    229     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    230     const int64 step = tmp->scalar<int64>()();
    231     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    232     const string& tag = tmp->scalar<string>()();
    233     const Tensor* bad_color;
    234     OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color));
    235     OP_REQUIRES(
    236         ctx, TensorShapeUtils::IsVector(bad_color->shape()),
    237         errors::InvalidArgument("bad_color must be a vector, got shape ",
    238                                 bad_color->shape().DebugString()));
    239 
    240     const Tensor* t;
    241     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    242 
    243     OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color));
    244   }
    245 
    246  private:
    247   int32 max_images_;
    248 };
    249 REGISTER_KERNEL_BUILDER(Name("WriteImageSummary").Device(DEVICE_CPU),
    250                         WriteImageSummaryOp);
    251 
    252 class WriteAudioSummaryOp : public OpKernel {
    253  public:
    254   explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    255     OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs", &max_outputs_));
    256     OP_REQUIRES(ctx, max_outputs_ > 0,
    257                 errors::InvalidArgument("max_outputs must be > 0"));
    258   }
    259 
    260   void Compute(OpKernelContext* ctx) override {
    261     SummaryWriterInterface* s;
    262     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    263     core::ScopedUnref unref(s);
    264     const Tensor* tmp;
    265     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    266     const int64 step = tmp->scalar<int64>()();
    267     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    268     const string& tag = tmp->scalar<string>()();
    269     OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp));
    270     const float sample_rate = tmp->scalar<float>()();
    271 
    272     const Tensor* t;
    273     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    274 
    275     OP_REQUIRES_OK(ctx,
    276                    s->WriteAudio(step, *t, tag, max_outputs_, sample_rate));
    277   }
    278 
    279  private:
    280   int max_outputs_;
    281 };
    282 REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU),
    283                         WriteAudioSummaryOp);
    284 
    285 class WriteGraphSummaryOp : public OpKernel {
    286  public:
    287   explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    288 
    289   void Compute(OpKernelContext* ctx) override {
    290     SummaryWriterInterface* s;
    291     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    292     core::ScopedUnref unref(s);
    293     const Tensor* t;
    294     OP_REQUIRES_OK(ctx, ctx->input("step", &t));
    295     const int64 step = t->scalar<int64>()();
    296     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    297     std::unique_ptr<GraphDef> graph{new GraphDef};
    298     if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) {
    299       ctx->CtxFailureWithWarning(
    300           errors::DataLoss("Bad tf.GraphDef binary proto tensor string"));
    301       return;
    302     }
    303     OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph)));
    304   }
    305 };
    306 REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU),
    307                         WriteGraphSummaryOp);
    308 
    309 }  // namespace tensorflow
    310