# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Symbol naming utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.py2tf.pyct import qual_names


class Namer(object):
  """Implementation of the namer interfaces required by various converters.

  This implementation performs additional tasks like keeping track of the
  function calls that have been encountered and replaced with calls to their
  corresponding compiled counterparts.

  Interfaces currently implemented:
    * call_trees.FunctionNamer
    * control_flow.SymbolNamer
    * side_effect_guards.SymbolNamer

  Attributes:
    global_namespace: container of names already in use at global scope.
        Only membership (`in`) is tested; generated names avoid all of them.
    recursive: if false, compiled_function_name declines to rename anything.
    partial_types: owner types whose members are renamed individually by
        compiled_function_name; members of any other owner type are not.
    renamed_calls: dict mapping a live entity to the name generated for it,
        seeded from the constructor's name_map.
    generated_names: set of every name this namer has handed out.
  """

  def __init__(self, global_namespace, recursive, name_map, partial_types):
    self.global_namespace = global_namespace
    self.recursive = recursive
    self.partial_types = partial_types

    # Pre-seeded entity -> name assignments, extended as names are generated.
    self.renamed_calls = {}
    if name_map is not None:
      self.renamed_calls.update(name_map)

    self.generated_names = set()

  def _fqn_to_name(self, original_fqn):
    """Joins a tuple fqn with '__'; returns any other value unchanged."""
    if isinstance(original_fqn, tuple):
      return '__'.join(original_fqn)
    return original_fqn

  def _unique_global_name(self, name_root):
    """Returns name_root, or name_root_<n> (smallest n >= 1) not in globals."""
    new_name = name_root
    n = 0
    while new_name in self.global_namespace:
      n += 1
      new_name = '%s_%d' % (name_root, n)
    return new_name

  def _register_name(self, new_name, live_entity):
    """Records new_name as generated and, if given, as live_entity's name."""
    if live_entity is not None:
      self.renamed_calls[live_entity] = new_name
    self.generated_names.add(new_name)

  # NOTE(review): the two compiled_*_name methods below only avoid names in
  # global_namespace, not names in generated_names, so two distinct entities
  # with the same original name receive the same compiled name. Confirm
  # whether callers rely on this before tightening it.

  def compiled_class_name(self, original_fqn, live_entity=None):
    """See call_trees.FunctionNamer.compiled_class_name.

    Args:
      original_fqn: str, or tuple of str joined with '__', naming the class.
      live_entity: optional class object. When given, the generated name is
          cached in renamed_calls and returned again on later calls.

    Returns:
      str, 'Tf<name>' (with an '_<n>' suffix if needed to avoid a name in
      global_namespace), or the name previously cached for live_entity.
    """
    if live_entity is not None and live_entity in self.renamed_calls:
      return self.renamed_calls[live_entity]

    original_name = self._fqn_to_name(original_fqn)
    new_name = self._unique_global_name('Tf%s' % original_name)
    self._register_name(new_name, live_entity)
    return new_name

  def compiled_function_name(self,
                             original_fqn,
                             live_entity=None,
                             owner_type=None):
    """See call_trees.FunctionNamer.compiled_function_name.

    Args:
      original_fqn: str, or tuple of str joined with '__', naming the
          function.
      live_entity: optional function object. When given, the generated name
          is cached in renamed_calls and returned again on later calls.
      owner_type: optional type owning the function, if it is a member.

    Returns:
      A (name, renamed) tuple. (None, False) when recursive conversion is
      disabled, or when owner_type is given and is not in partial_types.
      Otherwise ('tf__<name>' possibly suffixed with '_<n>', True), or the
      name previously cached for live_entity.
    """
    if not self.recursive:
      return None, False

    if owner_type is not None and owner_type not in self.partial_types:
      # Members are not renamed when part of an entire converted class.
      return None, False

    original_name = self._fqn_to_name(original_fqn)

    if live_entity is not None and live_entity in self.renamed_calls:
      return self.renamed_calls[live_entity], True

    new_name = self._unique_global_name('tf__%s' % original_name)
    self._register_name(new_name, live_entity)
    return new_name, True

  def new_symbol(self, name_root, reserved_locals):
    """See control_flow.SymbolNamer.new_symbol.

    Args:
      name_root: str, the desired name. A trailing '_<digits>' suffix is
          treated as a counter: for 'x_3' the candidates are 'x', then 'x_4',
          'x_5', and so on. A root that would become empty after stripping
          (e.g. '_1') is used as-is.
      reserved_locals: iterable of str or qual_names.QN. For a QN, each
          element of its `qn` attribute is reserved.

    Returns:
      str, the first candidate not in global_namespace, not reserved and not
      previously generated by this namer. It is added to generated_names.

    Raises:
      ValueError: if reserved_locals contains anything other than str or QN.
    """
    # reserved_locals may contain QNs.
    all_reserved_locals = set()
    for s in reserved_locals:
      if isinstance(s, qual_names.QN):
        all_reserved_locals.update(s.qn)
      elif isinstance(s, str):
        all_reserved_locals.add(s)
      else:
        raise ValueError('Unexpected symbol type "%s"' % type(s))

    # Strip a trailing numeric counter so repeated requests produce x_1, x_2,
    # ... instead of x_1_1. Only strip when a non-empty stem remains;
    # otherwise inputs like '_1' or '3' would collapse to the empty string.
    n = 0
    stem, sep, suffix = name_root.rpartition('_')
    if sep and stem and suffix.isdigit():
      name_root = stem
      n = int(suffix)
    new_name = name_root

    while (new_name in self.global_namespace or
           new_name in all_reserved_locals or new_name in self.generated_names):
      n += 1
      new_name = '%s_%d' % (name_root, n)

    self.generated_names.add(new_name)
    return new_name