Home | History | Annotate | Download | only in framework
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
     17 #define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
     18 
#include <set>
#include <string>

#include "tensorflow/core/lib/core/status.h"
     20 
     21 namespace tensorflow {
     22 namespace dataset {
     23 // Registry for stateful ops that need to be used in dataset functions.
     24 // See below macro for usage details.
     25 class WhitelistedStatefulOpRegistry {
     26  public:
     27   Status Add(StringPiece op_name) {
     28     op_names_.insert(op_name);
     29     return Status::OK();
     30   }
     31 
     32   bool Contains(StringPiece op_name) {
     33     return op_names_.find(op_name) != op_names_.end();
     34   }
     35 
     36   static WhitelistedStatefulOpRegistry* Global() {
     37     static WhitelistedStatefulOpRegistry* reg =
     38         new WhitelistedStatefulOpRegistry;
     39     return reg;
     40   }
     41 
     42  private:
     43   WhitelistedStatefulOpRegistry() {}
     44   WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy);
     45   WhitelistedStatefulOpRegistry operator=(
     46       WhitelistedStatefulOpRegistry const& copy);
     47   std::set<StringPiece> op_names_;
     48 };
     49 
     50 }  // namespace dataset
     51 
     52 // Use this macro to whitelist an op that is marked stateful but needs to be
     53 // used inside a map_fn in an input pipeline. This is only needed if you wish
     54 // to be able to checkpoint the state of the input pipeline. We currently
     55 // do not allow stateful ops to be defined inside of map_fns since it is not
     56 // possible to save their state.
// Note that the state of the whitelisted ops inside functions will not be
// saved during checkpointing, hence this should only be used if the op is
// marked stateful for other reasons (e.g. to avoid constant folding during
// graph optimization) but is not actually stateful.
     61 // If possible, try to remove the stateful flag on the op first.
     62 // Example usage:
     63 //
     64 //   WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader");
     65 //
     66 #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \
     67   WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name)
     68 #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
     69   WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
     70 #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)        \
     71   static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED =      \
     72       ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \
     73           name)
     74 
     75 }  // namespace tensorflow
     76 
     77 #endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
     78