Home | History | Annotate | Download | only in converters
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Converter for logical expressions.
     16 
     17 e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
     18 """
     19 
     20 from __future__ import absolute_import
     21 from __future__ import division
     22 from __future__ import print_function
     23 
     24 import gast
     25 
     26 from tensorflow.contrib.py2tf.pyct import parser
     27 
     28 
     29 class LogicalExpressionTransformer(gast.NodeTransformer):
     30   """Converts logical expressions to corresponding TF calls."""
     31 
     32   def __init__(self):
     33     # TODO(mdan): Look into replacing with bitwise operators instead.
     34     self.op_mapping = {
     35         gast.And: 'tf.logical_and',
     36         gast.Or: 'tf.logical_or',
     37         gast.Not: 'tf.logical_not',
     38         gast.Eq: 'tf.equal',
     39     }
     40 
     41   def visit_Compare(self, node):
     42     node = self.generic_visit(node)
     43     if len(node.ops) > 1:
     44       raise NotImplementedError()
     45     cmp_type = type(node.ops[0])
     46     if cmp_type in self.op_mapping:
     47       tf_function = parser.parse_str(self.op_mapping[cmp_type]).body[0].value
     48       return gast.Call(
     49           func=tf_function, args=[node.left, node.comparators[0]], keywords=[])
     50     return node
     51 
     52   def visit_UnaryOp(self, node):
     53     node = self.generic_visit(node)
     54     if isinstance(node.op, gast.Not):
     55       tf_function = parser.parse_str(self.op_mapping[type(
     56           node.op)]).body[0].value
     57       node = gast.Call(func=tf_function, args=[node.operand], keywords=[])
     58     return node
     59 
     60   def visit_BoolOp(self, node):
     61     # TODO(mdan): A normalizer may be useful here. Use ANF?
     62     node = self.generic_visit(node)
     63     tf_function = parser.parse_str(self.op_mapping[type(node.op)]).body[0].value
     64     left = node.values[0]
     65     for i in range(1, len(node.values)):
     66       left = gast.Call(
     67           func=tf_function, args=[left, node.values[i]], keywords=[])
     68     return left
     69 
     70 
     71 def transform(node):
     72   transformer = LogicalExpressionTransformer()
     73   node = transformer.visit(node)
     74   return node
     75