1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Registry mechanism for "registering" classes/functions for general use. 17 18 This is typically used with a decorator that calls Register for adding 19 a class or function to a registry. 20 """ 21 22 from __future__ import absolute_import 23 from __future__ import division 24 from __future__ import print_function 25 26 import traceback 27 28 from tensorflow.python.platform import tf_logging as logging 29 from tensorflow.python.util import compat 30 31 32 # Registry mechanism below is based on mapreduce.python.mrpython.Register. 33 _LOCATION_TAG = "location" 34 _TYPE_TAG = "type" 35 36 37 class Registry(object): 38 """Provides a registry for saving objects.""" 39 40 def __init__(self, name): 41 """Creates a new registry.""" 42 self._name = name 43 self._registry = dict() 44 45 def register(self, candidate, name=None): 46 """Registers a Python object "candidate" for the given "name". 47 48 Args: 49 candidate: The candidate object to add to the registry. 50 name: An optional string specifying the registry key for the candidate. 51 If None, candidate.__name__ will be used. 52 Raises: 53 KeyError: If same name is used twice. 54 """ 55 if not name: 56 name = candidate.__name__ 57 if name in self._registry: 58 (filename, line_number, function_name, _) = ( 59 self._registry[name][_LOCATION_TAG]) 60 raise KeyError("Registering two %s with name '%s' !" 61 "(Previous registration was in %s %s:%d)" % 62 (self._name, name, function_name, filename, line_number)) 63 64 logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name) 65 # stack trace is [this_function, Register(), user_function,...] 66 # so the user function is #2. 67 stack = traceback.extract_stack() 68 self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: stack[2]} 69 70 def list(self): 71 """Lists registered items. 72 73 Returns: 74 A list of names of registered objects. 75 """ 76 return self._registry.keys() 77 78 def lookup(self, name): 79 """Looks up "name". 80 81 Args: 82 name: a string specifying the registry key for the candidate. 83 Returns: 84 Registered object if found 85 Raises: 86 LookupError: if "name" has not been registered. 87 """ 88 name = compat.as_str(name) 89 if name in self._registry: 90 return self._registry[name][_TYPE_TAG] 91 else: 92 raise LookupError( 93 "%s registry has no entry for: %s" % (self._name, name)) 94