Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
     17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_
     18 
     19 #include <functional>
     20 #include <unordered_map>
     21 
     22 #include <vector>
     23 #include "tensorflow/core/framework/op_def_builder.h"
     24 #include "tensorflow/core/framework/op_def_util.h"
     25 #include "tensorflow/core/framework/selective_registration.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/lib/strings/str_util.h"
     29 #include "tensorflow/core/lib/strings/strcat.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/platform/macros.h"
     32 #include "tensorflow/core/platform/mutex.h"
     33 #include "tensorflow/core/platform/thread_annotations.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace tensorflow {
     37 
     38 // Users that want to look up an OpDef by type name should take an
     39 // OpRegistryInterface.  Functions accepting a
     40 // (const) OpRegistryInterface* may call LookUp() from multiple threads.
     41 class OpRegistryInterface {
     42  public:
     43   virtual ~OpRegistryInterface();
     44 
     45   // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
     46   // registered under that name, otherwise returns the registered OpDef.
     47   // Caller must not delete the returned pointer.
     48   virtual Status LookUp(const string& op_type_name,
     49                         const OpRegistrationData** op_reg_data) const = 0;
     50 
     51   // Shorthand for calling LookUp to get the OpDef.
     52   Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
     53 };
     54 
     55 // The standard implementation of OpRegistryInterface, along with a
     56 // global singleton used for registering ops via the REGISTER
     57 // macros below.  Thread-safe.
     58 //
     59 // Example registration:
     60 //   OpRegistry::Global()->Register(
     61 //     [](OpRegistrationData* op_reg_data)->Status {
     62 //       // Populate *op_reg_data here.
     63 //       return Status::OK();
     64 //   });
     65 class OpRegistry : public OpRegistryInterface {
     66  public:
     67   typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
     68 
     69   OpRegistry();
     70   ~OpRegistry() override;
     71 
     72   void Register(const OpRegistrationDataFactory& op_data_factory);
     73 
     74   Status LookUp(const string& op_type_name,
     75                 const OpRegistrationData** op_reg_data) const override;
     76 
     77   // Fills *ops with all registered OpDefs (except those with names
     78   // starting with '_' if include_internal == false) sorted in
     79   // ascending alphabetical order.
     80   void Export(bool include_internal, OpList* ops) const;
     81 
     82   // Returns ASCII-format OpList for all registered OpDefs (except
     83   // those with names starting with '_' if include_internal == false).
     84   string DebugString(bool include_internal) const;
     85 
     86   // A singleton available at startup.
     87   static OpRegistry* Global();
     88 
     89   // Get all registered ops.
     90   void GetRegisteredOps(std::vector<OpDef>* op_defs);
     91 
     92   // Get all `OpRegistrationData`s.
     93   void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
     94 
     95   // Watcher, a function object.
     96   // The watcher, if set by SetWatcher(), is called every time an op is
     97   // registered via the Register function. The watcher is passed the Status
     98   // obtained from building and adding the OpDef to the registry, and the OpDef
     99   // itself if it was successfully built. A watcher returns a Status which is in
    100   // turn returned as the final registration status.
    101   typedef std::function<Status(const Status&, const OpDef&)> Watcher;
    102 
    103   // An OpRegistry object has only one watcher. This interface is not thread
    104   // safe, as different clients are free to set the watcher any time.
    105   // Clients are expected to atomically perform the following sequence of
    106   // operations :
    107   // SetWatcher(a_watcher);
    108   // Register some ops;
    109   // op_registry->ProcessRegistrations();
    110   // SetWatcher(nullptr);
    111   // Returns a non-OK status if a non-null watcher is over-written by another
    112   // non-null watcher.
    113   Status SetWatcher(const Watcher& watcher);
    114 
    115   // Process the current list of deferred registrations. Note that calls to
    116   // Export, LookUp and DebugString would also implicitly process the deferred
    117   // registrations. Returns the status of the first failed op registration or
    118   // Status::OK() otherwise.
    119   Status ProcessRegistrations() const;
    120 
    121   // Defer the registrations until a later call to a function that processes
    122   // deferred registrations are made. Normally, registrations that happen after
    123   // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
    124   // immediately. Call this to defer future registrations.
    125   void DeferRegistrations();
    126 
    127   // Clear the registrations that have been deferred.
    128   void ClearDeferredRegistrations();
    129 
    130  private:
    131   // Ensures that all the functions in deferred_ get called, their OpDef's
    132   // registered, and returns with deferred_ empty.  Returns true the first
    133   // time it is called. Prints a fatal log if any op registration fails.
    134   bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
    135 
    136   // Calls the functions in deferred_ and registers their OpDef's
    137   // It returns the Status of the first failed op registration or Status::OK()
    138   // otherwise.
    139   Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
    140 
    141   // Add 'def' to the registry with additional data 'data'. On failure, or if
    142   // there is already an OpDef with that name registered, returns a non-okay
    143   // status.
    144   Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
    145       const EXCLUSIVE_LOCKS_REQUIRED(mu_);
    146 
    147   Status LookUpSlow(const string& op_type_name,
    148                     const OpRegistrationData** op_reg_data) const;
    149 
    150   mutable mutex mu_;
    151   // Functions in deferred_ may only be called with mu_ held.
    152   mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
    153   // Values are owned.
    154   mutable std::unordered_map<string, const OpRegistrationData*> registry_
    155       GUARDED_BY(mu_);
    156   mutable bool initialized_ GUARDED_BY(mu_);
    157 
    158   // Registry watcher.
    159   mutable Watcher watcher_ GUARDED_BY(mu_);
    160 };
    161 
    162 // An adapter to allow an OpList to be used as an OpRegistryInterface.
    163 //
    164 // Note that shape inference functions are not passed in to OpListOpRegistry, so
    165 // it will return an unusable shape inference function for every op it supports;
    166 // therefore, it should only be used in contexts where this is okay.
    167 class OpListOpRegistry : public OpRegistryInterface {
    168  public:
    169   // Does not take ownership of op_list, *op_list must outlive *this.
    170   OpListOpRegistry(const OpList* op_list);
    171   ~OpListOpRegistry() override;
    172   Status LookUp(const string& op_type_name,
    173                 const OpRegistrationData** op_reg_data) const override;
    174 
    175  private:
    176   // Values are owned.
    177   std::unordered_map<string, const OpRegistrationData*> index_;
    178 };
    179 
    180 // Support for defining the OpDef (specifying the semantics of the Op and how
    181 // it should be created) and registering it in the OpRegistry::Global()
    182 // registry.  Usage:
    183 //
    184 // REGISTER_OP("my_op_name")
    185 //     .Attr("<name>:<type>")
    186 //     .Attr("<name>:<type>=<default>")
    187 //     .Input("<name>:<type-expr>")
    188 //     .Input("<name>:Ref(<type-expr>)")
    189 //     .Output("<name>:<type-expr>")
    190 //     .Doc(R"(
    191 // <1-line summary>
    192 // <rest of the description (potentially many lines)>
    193 // <name-of-attr-input-or-output>: <description of name>
    194 // <name-of-attr-input-or-output>: <description of name;
    195 //   if long, indent the description on subsequent lines>
    196 // )");
    197 //
    198 // Note: .Doc() should be last.
    199 // For details, see the OpDefBuilder class in op_def_builder.h.
    200 
    201 namespace register_op {
    202 
    203 // OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP
    204 // calls. This allows the result of REGISTER_OP to be used in chaining, as in
    205 // REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective
    206 // registration to turn the entire call-chain into a no-op.
    207 template <bool should_register>
    208 class OpDefBuilderWrapper;
    209 
    210 // Template specialization that forwards all calls to the contained builder.
    211 template <>
    212 class OpDefBuilderWrapper<true> {
    213  public:
    214   OpDefBuilderWrapper(const char name[]) : builder_(name) {}
    215   OpDefBuilderWrapper<true>& Attr(string spec) {
    216     builder_.Attr(std::move(spec));
    217     return *this;
    218   }
    219   OpDefBuilderWrapper<true>& Input(string spec) {
    220     builder_.Input(std::move(spec));
    221     return *this;
    222   }
    223   OpDefBuilderWrapper<true>& Output(string spec) {
    224     builder_.Output(std::move(spec));
    225     return *this;
    226   }
    227   OpDefBuilderWrapper<true>& SetIsCommutative() {
    228     builder_.SetIsCommutative();
    229     return *this;
    230   }
    231   OpDefBuilderWrapper<true>& SetIsAggregate() {
    232     builder_.SetIsAggregate();
    233     return *this;
    234   }
    235   OpDefBuilderWrapper<true>& SetIsStateful() {
    236     builder_.SetIsStateful();
    237     return *this;
    238   }
    239   OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
    240     builder_.SetAllowsUninitializedInput();
    241     return *this;
    242   }
    243   OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
    244     builder_.Deprecated(version, std::move(explanation));
    245     return *this;
    246   }
    247   OpDefBuilderWrapper<true>& Doc(string text) {
    248     builder_.Doc(std::move(text));
    249     return *this;
    250   }
    251   OpDefBuilderWrapper<true>& SetShapeFn(
    252       Status (*fn)(shape_inference::InferenceContext*)) {
    253     builder_.SetShapeFn(fn);
    254     return *this;
    255   }
    256   const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
    257 
    258  private:
    259   mutable ::tensorflow::OpDefBuilder builder_;
    260 };
    261 
    262 // Template specialization that turns all calls into no-ops.
    263 template <>
    264 class OpDefBuilderWrapper<false> {
    265  public:
    266   constexpr OpDefBuilderWrapper(const char name[]) {}
    267   OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; }
    268   OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; }
    269   OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; }
    270   OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
    271   OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
    272   OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
    273   OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
    274   OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
    275   OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
    276   OpDefBuilderWrapper<false>& SetShapeFn(
    277       Status (*fn)(shape_inference::InferenceContext*)) {
    278     return *this;
    279   }
    280 };
    281 
    282 struct OpDefBuilderReceiver {
    283   // To call OpRegistry::Global()->Register(...), used by the
    284   // REGISTER_OP macro below.
    285   // Note: These are implicitly converting constructors.
    286   OpDefBuilderReceiver(
    287       const OpDefBuilderWrapper<true>& wrapper);  // NOLINT(runtime/explicit)
    288   constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
    289   }  // NOLINT(runtime/explicit)
    290 };
    291 }  // namespace register_op
    292 
    293 #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
    294 #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
    295 #define REGISTER_OP_UNIQ(ctr, name)                                          \
    296   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr    \
    297       TF_ATTRIBUTE_UNUSED =                                                  \
    298           ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
    299               name)>(name)
    300 
    301 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
    302 // that the op is registered unconditionally even when selective
    303 // registration is used.
    304 #define REGISTER_SYSTEM_OP(name) \
    305   REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name)
    306 #define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \
    307   REGISTER_SYSTEM_OP_UNIQ(ctr, name)
    308 #define REGISTER_SYSTEM_OP_UNIQ(ctr, name)                                \
    309   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
    310       TF_ATTRIBUTE_UNUSED =                                               \
    311           ::tensorflow::register_op::OpDefBuilderWrapper<true>(name)
    312 
    313 }  // namespace tensorflow
    314 
    315 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_H_
    316