Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Registry mechanism for "registering" classes/functions for general use.
     17 
     18 This is typically used with a decorator that calls Register for adding
     19 a class or function to a registry.
     20 """
     21 
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import traceback
     27 
     28 from tensorflow.python.platform import tf_logging as logging
     29 from tensorflow.python.util import compat
     30 
     31 
# Registry mechanism below is based on mapreduce.python.mrpython.Register.
# Keys of the per-entry dict that Registry stores for each registered name:
# _LOCATION_TAG maps to the stack frame record captured at registration time.
_LOCATION_TAG = "location"
# _TYPE_TAG maps to the registered object itself.
_TYPE_TAG = "type"
     35 
     36 
     37 class Registry(object):
     38   """Provides a registry for saving objects."""
     39 
     40   def __init__(self, name):
     41     """Creates a new registry."""
     42     self._name = name
     43     self._registry = dict()
     44 
     45   def register(self, candidate, name=None):
     46     """Registers a Python object "candidate" for the given "name".
     47 
     48     Args:
     49       candidate: The candidate object to add to the registry.
     50       name: An optional string specifying the registry key for the candidate.
     51             If None, candidate.__name__ will be used.
     52     Raises:
     53       KeyError: If same name is used twice.
     54     """
     55     if not name:
     56       name = candidate.__name__
     57     if name in self._registry:
     58       (filename, line_number, function_name, _) = (
     59           self._registry[name][_LOCATION_TAG])
     60       raise KeyError("Registering two %s with name '%s' !"
     61                      "(Previous registration was in %s %s:%d)" %
     62                      (self._name, name, function_name, filename, line_number))
     63 
     64     logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
     65     # stack trace is [this_function, Register(), user_function,...]
     66     # so the user function is #2.
     67     stack = traceback.extract_stack()
     68     self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: stack[2]}
     69 
     70   def list(self):
     71     """Lists registered items.
     72 
     73     Returns:
     74       A list of names of registered objects.
     75     """
     76     return self._registry.keys()
     77 
     78   def lookup(self, name):
     79     """Looks up "name".
     80 
     81     Args:
     82       name: a string specifying the registry key for the candidate.
     83     Returns:
     84       Registered object if found
     85     Raises:
     86       LookupError: if "name" has not been registered.
     87     """
     88     name = compat.as_str(name)
     89     if name in self._registry:
     90       return self._registry[name][_TYPE_TAG]
     91     else:
     92       raise LookupError(
     93           "%s registry has no entry for: %s" % (self._name, name))
     94