Home | History | Annotate | Download | only in impl
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Symbol naming utilities."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.py2tf.pyct import qual_names
     22 
     23 
     24 class Namer(object):
     25   """Implementation of the namer interfaces required by various converters.
     26 
     27   This implementation performs additional tasks like keeping track of the
     28   function calls that have been encountered and replaced with calls to their
     29   corresponding compiled counterparts.
     30 
     31   Interfaces currently implemented:
     32     * call_trees.FunctionNamer
     33     * control_flow.SymbolNamer
     34     * side_effect_guards.SymbolNamer
     35   """
     36 
     37   def __init__(self, global_namespace, recursive, name_map, partial_types):
     38     self.global_namespace = global_namespace
     39     self.recursive = recursive
     40     self.partial_types = partial_types
     41 
     42     self.renamed_calls = {}
     43     if name_map is not None:
     44       self.renamed_calls.update(name_map)
     45 
     46     self.generated_names = set()
     47 
     48   def compiled_class_name(self, original_fqn, live_entity=None):
     49     """See call_trees.FunctionNamer.compiled_class_name."""
     50     if live_entity is not None and live_entity in self.renamed_calls:
     51       return self.renamed_calls[live_entity]
     52 
     53     if isinstance(original_fqn, tuple):
     54       original_name = '__'.join(original_fqn)
     55     else:
     56       original_name = original_fqn
     57 
     58     new_name_root = 'Tf%s' % original_name
     59     new_name = new_name_root
     60     n = 0
     61     while new_name in self.global_namespace:
     62       n += 1
     63       new_name = '%s_%d' % (new_name_root, n)
     64 
     65     if live_entity is not None:
     66       self.renamed_calls[live_entity] = new_name
     67     self.generated_names.add(new_name)
     68     if live_entity is not None:
     69       self.renamed_calls[live_entity] = new_name
     70     return new_name
     71 
     72   def compiled_function_name(self,
     73                              original_fqn,
     74                              live_entity=None,
     75                              owner_type=None):
     76     """See call_trees.FunctionNamer.compiled_function_name."""
     77 
     78     if not self.recursive:
     79       return None, False
     80 
     81     if owner_type is not None and owner_type not in self.partial_types:
     82       # Members are not renamed when part of an entire converted class.
     83       return None, False
     84 
     85     if isinstance(original_fqn, tuple):
     86       original_name = '__'.join(original_fqn)
     87     else:
     88       original_name = original_fqn
     89 
     90     if live_entity is not None and live_entity in self.renamed_calls:
     91       return self.renamed_calls[live_entity], True
     92 
     93     new_name_root = 'tf__%s' % original_name
     94     new_name = new_name_root
     95     n = 0
     96     while new_name in self.global_namespace:
     97       n += 1
     98       new_name = '%s_%d' % (new_name_root, n)
     99 
    100     if live_entity is not None:
    101       self.renamed_calls[live_entity] = new_name
    102     self.generated_names.add(new_name)
    103 
    104     return new_name, True
    105 
    106   def new_symbol(self, name_root, reserved_locals):
    107     """See control_flow.SymbolNamer.new_symbol."""
    108     # reserved_locals may contain QNs.
    109     all_reserved_locals = set()
    110     for s in reserved_locals:
    111       if isinstance(s, qual_names.QN):
    112         all_reserved_locals.update(s.qn)
    113       elif isinstance(s, str):
    114         all_reserved_locals.add(s)
    115       else:
    116         raise ValueError('Unexpected symbol type "%s"' % type(s))
    117 
    118     pieces = name_root.split('_')
    119     if pieces[-1].isdigit():
    120       name_root = '_'.join(pieces[:-1])
    121       n = int(pieces[-1])
    122     else:
    123       n = 0
    124     new_name = name_root
    125 
    126     while (new_name in self.global_namespace or
    127            new_name in all_reserved_locals or new_name in self.generated_names):
    128       n += 1
    129       new_name = '%s_%d' % (name_root, n)
    130 
    131     self.generated_names.add(new_name)
    132     return new_name
    133