# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converter for logical expressions.

e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF:
Python does not allow `and`, `or` and `not` to be overloaded, so they must be
rewritten into explicit TF calls.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.contrib.py2tf.pyct import parser


class LogicalExpressionTransformer(gast.NodeTransformer):
  """Converts logical expressions to corresponding TF calls."""

  def __init__(self):
    # TODO(mdan): Look into replacing with bitwise operators instead.
    # Maps gast operator node types to the dotted name of the TF function
    # that replaces them.
    self.op_mapping = {
        gast.And: 'tf.logical_and',
        gast.Or: 'tf.logical_or',
        gast.Not: 'tf.logical_not',
        gast.Eq: 'tf.equal',
    }

  def _tf_function(self, op_type):
    """Returns a new AST expression referring to the TF function for op_type.

    The expression is re-parsed on every call so that each generated call node
    gets its own `func` subtree; sharing a single node object in several places
    of the tree can break later passes that annotate or mutate nodes.

    Args:
      op_type: a gast operator node type that is a key of `self.op_mapping`.

    Returns:
      The gast expression node for the mapped dotted name, e.g.
      `tf.logical_and`.
    """
    return parser.parse_str(self.op_mapping[op_type]).body[0].value

  def visit_Compare(self, node):
    """Replaces a comparison whose operator is in `op_mapping` with a TF call.

    Comparisons whose operator is not mapped are returned unchanged.

    Args:
      node: a gast.Compare node.

    Returns:
      A gast.Call node for mapped comparisons, otherwise the (visited) node.

    Raises:
      NotImplementedError: if the comparison is chained, e.g. `a == b == c`.
    """
    node = self.generic_visit(node)
    if len(node.ops) > 1:
      raise NotImplementedError(
          'Chained comparisons are not supported; found %d operators.' %
          len(node.ops))
    cmp_type = type(node.ops[0])
    if cmp_type not in self.op_mapping:
      return node
    return gast.Call(
        func=self._tf_function(cmp_type),
        args=[node.left, node.comparators[0]],
        keywords=[])

  def visit_UnaryOp(self, node):
    """Replaces `not x` with the mapped TF call; other unary ops are kept.

    Args:
      node: a gast.UnaryOp node.

    Returns:
      A gast.Call node for `not`, otherwise the (visited) node.
    """
    node = self.generic_visit(node)
    if isinstance(node.op, gast.Not):
      node = gast.Call(
          func=self._tf_function(type(node.op)),
          args=[node.operand],
          keywords=[])
    return node

  def visit_BoolOp(self, node):
    """Replaces `and`/`or` chains with nested binary TF calls.

    The operands are folded from the left, so `a and b and c` becomes
    `tf.logical_and(tf.logical_and(a, b), c)`. Every operand becomes a call
    argument, so Python's short-circuit evaluation is not preserved.

    Args:
      node: a gast.BoolOp node.

    Returns:
      The outermost gast.Call node of the fold.
    """
    # TODO(mdan): A normalizer may be useful here. Use ANF?
    node = self.generic_visit(node)
    op_type = type(node.op)
    result = node.values[0]
    for operand in node.values[1:]:
      # A fresh `func` node per call keeps the generated tree free of shared
      # nodes.
      result = gast.Call(
          func=self._tf_function(op_type), args=[result, operand], keywords=[])
    return result


def transform(node):
  """Rewrites the logical expressions in `node` into TF function calls.

  Args:
    node: the AST to transform (presumably a gast tree, since the transformer
      matches gast node types).

  Returns:
    The transformed AST.
  """
  return LogicalExpressionTransformer().visit(node)