Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_FRAMEWORK_OP_H_
     17 #define TENSORFLOW_FRAMEWORK_OP_H_
     18 
     19 #include <functional>
     20 #include <unordered_map>
     21 
     22 #include <vector>
     23 #include "tensorflow/core/framework/op_def_builder.h"
     24 #include "tensorflow/core/framework/op_def_util.h"
     25 #include "tensorflow/core/framework/selective_registration.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/lib/strings/str_util.h"
     29 #include "tensorflow/core/lib/strings/strcat.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/platform/macros.h"
     32 #include "tensorflow/core/platform/mutex.h"
     33 #include "tensorflow/core/platform/thread_annotations.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace tensorflow {
     37 
     38 // Users that want to look up an OpDef by type name should take an
     39 // OpRegistryInterface.  Functions accepting a
     40 // (const) OpRegistryInterface* may call LookUp() from multiple threads.
     41 class OpRegistryInterface {
     42  public:
     43   virtual ~OpRegistryInterface();
     44 
     45   // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
     46   // registered under that name, otherwise returns the registered OpDef.
     47   // Caller must not delete the returned pointer.
     48   virtual Status LookUp(const string& op_type_name,
     49                         const OpRegistrationData** op_reg_data) const = 0;
     50 
     51   // Shorthand for calling LookUp to get the OpDef.
     52   Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
     53 };
     54 
     55 // The standard implementation of OpRegistryInterface, along with a
     56 // global singleton used for registering ops via the REGISTER
     57 // macros below.  Thread-safe.
     58 //
     59 // Example registration:
     60 //   OpRegistry::Global()->Register(
     61 //     [](OpRegistrationData* op_reg_data)->Status {
     62 //       // Populate *op_reg_data here.
     63 //       return Status::OK();
     64 //   });
     65 class OpRegistry : public OpRegistryInterface {
     66  public:
     67   typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
     68 
     69   OpRegistry();
     70   ~OpRegistry() override;
     71 
     72   void Register(const OpRegistrationDataFactory& op_data_factory);
     73 
     74   Status LookUp(const string& op_type_name,
     75                 const OpRegistrationData** op_reg_data) const override;
     76 
     77   // Fills *ops with all registered OpDefs (except those with names
     78   // starting with '_' if include_internal == false) sorted in
     79   // ascending alphabetical order.
     80   void Export(bool include_internal, OpList* ops) const;
     81 
     82   // Returns ASCII-format OpList for all registered OpDefs (except
     83   // those with names starting with '_' if include_internal == false).
     84   string DebugString(bool include_internal) const;
     85 
     86   // A singleton available at startup.
     87   static OpRegistry* Global();
     88 
     89   // Get all registered ops.
     90   void GetRegisteredOps(std::vector<OpDef>* op_defs);
     91 
     92   // Watcher, a function object.
     93   // The watcher, if set by SetWatcher(), is called every time an op is
     94   // registered via the Register function. The watcher is passed the Status
     95   // obtained from building and adding the OpDef to the registry, and the OpDef
     96   // itself if it was successfully built. A watcher returns a Status which is in
     97   // turn returned as the final registration status.
     98   typedef std::function<Status(const Status&, const OpDef&)> Watcher;
     99 
    100   // An OpRegistry object has only one watcher. This interface is not thread
    101   // safe, as different clients are free to set the watcher any time.
    102   // Clients are expected to atomically perform the following sequence of
    103   // operations :
    104   // SetWatcher(a_watcher);
    105   // Register some ops;
    106   // op_registry->ProcessRegistrations();
    107   // SetWatcher(nullptr);
    108   // Returns a non-OK status if a non-null watcher is over-written by another
    109   // non-null watcher.
    110   Status SetWatcher(const Watcher& watcher);
    111 
    112   // Process the current list of deferred registrations. Note that calls to
    113   // Export, LookUp and DebugString would also implicitly process the deferred
    114   // registrations. Returns the status of the first failed op registration or
    115   // Status::OK() otherwise.
    116   Status ProcessRegistrations() const;
    117 
    118   // Defer the registrations until a later call to a function that processes
    119   // deferred registrations are made. Normally, registrations that happen after
    120   // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
    121   // immediately. Call this to defer future registrations.
    122   void DeferRegistrations();
    123 
    124   // Clear the registrations that have been deferred.
    125   void ClearDeferredRegistrations();
    126 
    127  private:
    128   // Ensures that all the functions in deferred_ get called, their OpDef's
    129   // registered, and returns with deferred_ empty.  Returns true the first
    130   // time it is called. Prints a fatal log if any op registration fails.
    131   bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
    132 
    133   // Calls the functions in deferred_ and registers their OpDef's
    134   // It returns the Status of the first failed op registration or Status::OK()
    135   // otherwise.
    136   Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
    137 
    138   // Add 'def' to the registry with additional data 'data'. On failure, or if
    139   // there is already an OpDef with that name registered, returns a non-okay
    140   // status.
    141   Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
    142       const EXCLUSIVE_LOCKS_REQUIRED(mu_);
    143 
    144   mutable mutex mu_;
    145   // Functions in deferred_ may only be called with mu_ held.
    146   mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
    147   // Values are owned.
    148   mutable std::unordered_map<string, const OpRegistrationData*> registry_
    149       GUARDED_BY(mu_);
    150   mutable bool initialized_ GUARDED_BY(mu_);
    151 
    152   // Registry watcher.
    153   mutable Watcher watcher_ GUARDED_BY(mu_);
    154 };
    155 
    156 // An adapter to allow an OpList to be used as an OpRegistryInterface.
    157 //
    158 // Note that shape inference functions are not passed in to OpListOpRegistry, so
    159 // it will return an unusable shape inference function for every op it supports;
    160 // therefore, it should only be used in contexts where this is okay.
    161 class OpListOpRegistry : public OpRegistryInterface {
    162  public:
    163   // Does not take ownership of op_list, *op_list must outlive *this.
    164   OpListOpRegistry(const OpList* op_list);
    165   ~OpListOpRegistry() override;
    166   Status LookUp(const string& op_type_name,
    167                 const OpRegistrationData** op_reg_data) const override;
    168 
    169  private:
    170   // Values are owned.
    171   std::unordered_map<string, const OpRegistrationData*> index_;
    172 };
    173 
    174 // Support for defining the OpDef (specifying the semantics of the Op and how
    175 // it should be created) and registering it in the OpRegistry::Global()
    176 // registry.  Usage:
    177 //
    178 // REGISTER_OP("my_op_name")
    179 //     .Attr("<name>:<type>")
    180 //     .Attr("<name>:<type>=<default>")
    181 //     .Input("<name>:<type-expr>")
    182 //     .Input("<name>:Ref(<type-expr>)")
    183 //     .Output("<name>:<type-expr>")
    184 //     .Doc(R"(
    185 // <1-line summary>
    186 // <rest of the description (potentially many lines)>
    187 // <name-of-attr-input-or-output>: <description of name>
    188 // <name-of-attr-input-or-output>: <description of name;
    189 //   if long, indent the description on subsequent lines>
    190 // )");
    191 //
    192 // Note: .Doc() should be last.
    193 // For details, see the OpDefBuilder class in op_def_builder.h.
    194 
    195 namespace register_op {
    196 
    197 // OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP
    198 // calls. This allows the result of REGISTER_OP to be used in chaining, as in
    199 // REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective
    200 // registration to turn the entire call-chain into a no-op.
    201 template <bool should_register>
    202 class OpDefBuilderWrapper;
    203 
    204 // Template specialization that forwards all calls to the contained builder.
    205 template <>
    206 class OpDefBuilderWrapper<true> {
    207  public:
    208   OpDefBuilderWrapper(const char name[]) : builder_(name) {}
    209   OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
    210     builder_.Attr(spec);
    211     return *this;
    212   }
    213   OpDefBuilderWrapper<true>& Input(StringPiece spec) {
    214     builder_.Input(spec);
    215     return *this;
    216   }
    217   OpDefBuilderWrapper<true>& Output(StringPiece spec) {
    218     builder_.Output(spec);
    219     return *this;
    220   }
    221   OpDefBuilderWrapper<true>& SetIsCommutative() {
    222     builder_.SetIsCommutative();
    223     return *this;
    224   }
    225   OpDefBuilderWrapper<true>& SetIsAggregate() {
    226     builder_.SetIsAggregate();
    227     return *this;
    228   }
    229   OpDefBuilderWrapper<true>& SetIsStateful() {
    230     builder_.SetIsStateful();
    231     return *this;
    232   }
    233   OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
    234     builder_.SetAllowsUninitializedInput();
    235     return *this;
    236   }
    237   OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
    238     builder_.Deprecated(version, explanation);
    239     return *this;
    240   }
    241   OpDefBuilderWrapper<true>& Doc(StringPiece text) {
    242     builder_.Doc(text);
    243     return *this;
    244   }
    245   OpDefBuilderWrapper<true>& SetShapeFn(
    246       Status (*fn)(shape_inference::InferenceContext*)) {
    247     builder_.SetShapeFn(fn);
    248     return *this;
    249   }
    250   const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
    251 
    252  private:
    253   mutable ::tensorflow::OpDefBuilder builder_;
    254 };
    255 
    256 // Template specialization that turns all calls into no-ops.
    257 template <>
    258 class OpDefBuilderWrapper<false> {
    259  public:
    260   constexpr OpDefBuilderWrapper(const char name[]) {}
    261   OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; }
    262   OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; }
    263   OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; }
    264   OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
    265   OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
    266   OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
    267   OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
    268   OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
    269   OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
    270   OpDefBuilderWrapper<false>& SetShapeFn(
    271       Status (*fn)(shape_inference::InferenceContext*)) {
    272     return *this;
    273   }
    274 };
    275 
    276 struct OpDefBuilderReceiver {
    277   // To call OpRegistry::Global()->Register(...), used by the
    278   // REGISTER_OP macro below.
    279   // Note: These are implicitly converting constructors.
    280   OpDefBuilderReceiver(
    281       const OpDefBuilderWrapper<true>& wrapper);  // NOLINT(runtime/explicit)
    282   constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
    283   }  // NOLINT(runtime/explicit)
    284 };
    285 }  // namespace register_op
    286 
    287 #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
    288 #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
    289 #define REGISTER_OP_UNIQ(ctr, name)                                          \
    290   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr    \
    291       TF_ATTRIBUTE_UNUSED =                                                  \
    292           ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
    293               name)>(name)
    294 
    295 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
    296 // that the op is registered unconditionally even when selective
    297 // registration is used.
    298 #define REGISTER_SYSTEM_OP(name) \
    299   REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name)
    300 #define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \
    301   REGISTER_SYSTEM_OP_UNIQ(ctr, name)
    302 #define REGISTER_SYSTEM_OP_UNIQ(ctr, name)                                \
    303   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
    304       TF_ATTRIBUTE_UNUSED =                                               \
    305           ::tensorflow::register_op::OpDefBuilderWrapper<true>(name)
    306 
    307 }  // namespace tensorflow
    308 
    309 #endif  // TENSORFLOW_FRAMEWORK_OP_H_
    310