Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/graph.pb.h"
     17 #include "tensorflow/core/framework/op_kernel.h"
     18 #include "tensorflow/core/framework/resource_mgr.h"
     19 #include "tensorflow/core/lib/db/sqlite.h"
     20 #include "tensorflow/core/platform/protobuf.h"
     21 #include "tensorflow/core/summary/schema.h"
     22 #include "tensorflow/core/summary/summary_db_writer.h"
     23 #include "tensorflow/core/summary/summary_file_writer.h"
     24 #include "tensorflow/core/util/event.pb.h"
     25 
     26 namespace tensorflow {
     27 
     28 REGISTER_KERNEL_BUILDER(Name("SummaryWriter").Device(DEVICE_CPU),
     29                         ResourceHandleOp<SummaryWriterInterface>);
     30 
     31 class CreateSummaryFileWriterOp : public OpKernel {
     32  public:
     33   explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx)
     34       : OpKernel(ctx) {}
     35 
     36   void Compute(OpKernelContext* ctx) override {
     37     const Tensor* tmp;
     38     OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp));
     39     const string logdir = tmp->scalar<string>()();
     40     OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp));
     41     const int32 max_queue = tmp->scalar<int32>()();
     42     OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp));
     43     const int32 flush_millis = tmp->scalar<int32>()();
     44     OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
     45     const string filename_suffix = tmp->scalar<string>()();
     46 
     47     SummaryWriterInterface* s = nullptr;
     48     OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
     49                             ctx, HandleFromInput(ctx, 0), &s,
     50                             [max_queue, flush_millis, logdir, filename_suffix,
     51                              ctx](SummaryWriterInterface** s) {
     52                               return CreateSummaryFileWriter(
     53                                   max_queue, flush_millis, logdir,
     54                                   filename_suffix, ctx->env(), s);
     55                             }));
     56     core::ScopedUnref unref(s);
     57   }
     58 };
     59 REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
     60                         CreateSummaryFileWriterOp);
     61 
     62 class CreateSummaryDbWriterOp : public OpKernel {
     63  public:
     64   explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
     65 
     66   void Compute(OpKernelContext* ctx) override {
     67     const Tensor* tmp;
     68     OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp));
     69     const string db_uri = tmp->scalar<string>()();
     70     OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp));
     71     const string experiment_name = tmp->scalar<string>()();
     72     OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp));
     73     const string run_name = tmp->scalar<string>()();
     74     OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
     75     const string user_name = tmp->scalar<string>()();
     76 
     77     SummaryWriterInterface* s = nullptr;
     78     OP_REQUIRES_OK(
     79         ctx,
     80         LookupOrCreateResource<SummaryWriterInterface>(
     81             ctx, HandleFromInput(ctx, 0), &s,
     82             [db_uri, experiment_name, run_name, user_name,
     83              ctx](SummaryWriterInterface** s) {
     84               Sqlite* db;
     85               TF_RETURN_IF_ERROR(Sqlite::Open(
     86                   db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db));
     87               core::ScopedUnref unref(db);
     88               TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
     89               TF_RETURN_IF_ERROR(CreateSummaryDbWriter(
     90                   db, experiment_name, run_name, user_name, ctx->env(), s));
     91               return Status::OK();
     92             }));
     93     core::ScopedUnref unref(s);
     94   }
     95 };
     96 REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
     97                         CreateSummaryDbWriterOp);
     98 
     99 class FlushSummaryWriterOp : public OpKernel {
    100  public:
    101   explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    102 
    103   void Compute(OpKernelContext* ctx) override {
    104     SummaryWriterInterface* s;
    105     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    106     core::ScopedUnref unref(s);
    107     OP_REQUIRES_OK(ctx, s->Flush());
    108   }
    109 };
    110 REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter").Device(DEVICE_CPU),
    111                         FlushSummaryWriterOp);
    112 
    113 class CloseSummaryWriterOp : public OpKernel {
    114  public:
    115   explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    116 
    117   void Compute(OpKernelContext* ctx) override {
    118     OP_REQUIRES_OK(ctx, DeleteResource<SummaryWriterInterface>(
    119                             ctx, HandleFromInput(ctx, 0)));
    120   }
    121 };
    122 REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter").Device(DEVICE_CPU),
    123                         CloseSummaryWriterOp);
    124 
    125 class WriteSummaryOp : public OpKernel {
    126  public:
    127   explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    128 
    129   void Compute(OpKernelContext* ctx) override {
    130     SummaryWriterInterface* s;
    131     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    132     core::ScopedUnref unref(s);
    133     const Tensor* tmp;
    134     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    135     const int64 step = tmp->scalar<int64>()();
    136     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    137     const string& tag = tmp->scalar<string>()();
    138     OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp));
    139     const string& serialized_metadata = tmp->scalar<string>()();
    140 
    141     const Tensor* t;
    142     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    143 
    144     OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata));
    145   }
    146 };
    147 REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
    148                         WriteSummaryOp);
    149 
    150 class ImportEventOp : public OpKernel {
    151  public:
    152   explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    153 
    154   void Compute(OpKernelContext* ctx) override {
    155     SummaryWriterInterface* s;
    156     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    157     core::ScopedUnref unref(s);
    158     const Tensor* t;
    159     OP_REQUIRES_OK(ctx, ctx->input("event", &t));
    160     std::unique_ptr<Event> event{new Event};
    161     if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) {
    162       ctx->CtxFailureWithWarning(
    163           errors::DataLoss("Bad tf.Event binary proto tensor string"));
    164       return;
    165     }
    166     OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
    167   }
    168 };
    169 REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp);
    170 
    171 class WriteScalarSummaryOp : public OpKernel {
    172  public:
    173   explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    174 
    175   void Compute(OpKernelContext* ctx) override {
    176     SummaryWriterInterface* s;
    177     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    178     core::ScopedUnref unref(s);
    179     const Tensor* tmp;
    180     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    181     const int64 step = tmp->scalar<int64>()();
    182     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    183     const string& tag = tmp->scalar<string>()();
    184 
    185     const Tensor* t;
    186     OP_REQUIRES_OK(ctx, ctx->input("value", &t));
    187 
    188     OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag));
    189   }
    190 };
    191 REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU),
    192                         WriteScalarSummaryOp);
    193 
    194 class WriteHistogramSummaryOp : public OpKernel {
    195  public:
    196   explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    197 
    198   void Compute(OpKernelContext* ctx) override {
    199     SummaryWriterInterface* s;
    200     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    201     core::ScopedUnref unref(s);
    202     const Tensor* tmp;
    203     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    204     const int64 step = tmp->scalar<int64>()();
    205     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    206     const string& tag = tmp->scalar<string>()();
    207 
    208     const Tensor* t;
    209     OP_REQUIRES_OK(ctx, ctx->input("values", &t));
    210 
    211     OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag));
    212   }
    213 };
    214 REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU),
    215                         WriteHistogramSummaryOp);
    216 
    217 class WriteImageSummaryOp : public OpKernel {
    218  public:
    219   explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    220     int64 max_images_tmp;
    221     OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images", &max_images_tmp));
    222     OP_REQUIRES(ctx, max_images_tmp < (1LL << 31),
    223                 errors::InvalidArgument("max_images must be < 2^31"));
    224     max_images_ = static_cast<int32>(max_images_tmp);
    225   }
    226 
    227   void Compute(OpKernelContext* ctx) override {
    228     SummaryWriterInterface* s;
    229     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    230     core::ScopedUnref unref(s);
    231     const Tensor* tmp;
    232     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    233     const int64 step = tmp->scalar<int64>()();
    234     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    235     const string& tag = tmp->scalar<string>()();
    236     const Tensor* bad_color;
    237     OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color));
    238     OP_REQUIRES(
    239         ctx, TensorShapeUtils::IsVector(bad_color->shape()),
    240         errors::InvalidArgument("bad_color must be a vector, got shape ",
    241                                 bad_color->shape().DebugString()));
    242 
    243     const Tensor* t;
    244     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    245 
    246     OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color));
    247   }
    248 
    249  private:
    250   int32 max_images_;
    251 };
    252 REGISTER_KERNEL_BUILDER(Name("WriteImageSummary").Device(DEVICE_CPU),
    253                         WriteImageSummaryOp);
    254 
    255 class WriteAudioSummaryOp : public OpKernel {
    256  public:
    257   explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    258     OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs", &max_outputs_));
    259     OP_REQUIRES(ctx, max_outputs_ > 0,
    260                 errors::InvalidArgument("max_outputs must be > 0"));
    261   }
    262 
    263   void Compute(OpKernelContext* ctx) override {
    264     SummaryWriterInterface* s;
    265     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    266     core::ScopedUnref unref(s);
    267     const Tensor* tmp;
    268     OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    269     const int64 step = tmp->scalar<int64>()();
    270     OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    271     const string& tag = tmp->scalar<string>()();
    272     OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp));
    273     const float sample_rate = tmp->scalar<float>()();
    274 
    275     const Tensor* t;
    276     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    277 
    278     OP_REQUIRES_OK(ctx,
    279                    s->WriteAudio(step, *t, tag, max_outputs_, sample_rate));
    280   }
    281 
    282  private:
    283   int max_outputs_;
    284 };
    285 REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU),
    286                         WriteAudioSummaryOp);
    287 
    288 class WriteGraphSummaryOp : public OpKernel {
    289  public:
    290   explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    291 
    292   void Compute(OpKernelContext* ctx) override {
    293     SummaryWriterInterface* s;
    294     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    295     core::ScopedUnref unref(s);
    296     const Tensor* t;
    297     OP_REQUIRES_OK(ctx, ctx->input("step", &t));
    298     const int64 step = t->scalar<int64>()();
    299     OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    300     std::unique_ptr<GraphDef> graph{new GraphDef};
    301     if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) {
    302       ctx->CtxFailureWithWarning(
    303           errors::DataLoss("Bad tf.GraphDef binary proto tensor string"));
    304       return;
    305     }
    306     OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph)));
    307   }
    308 };
    309 REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU),
    310                         WriteGraphSummaryOp);
    311 
    312 }  // namespace tensorflow
    313