1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Test cases for operators with no arguments.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.compiler.tests.xla_test import XLATestCase 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.ops import control_flow_ops 26 from tensorflow.python.platform import googletest 27 28 29 class NullaryOpsTest(XLATestCase): 30 31 def _testNullary(self, op, expected): 32 with self.test_session() as session: 33 with self.test_scope(): 34 output = op() 35 result = session.run(output) 36 self.assertAllClose(result, expected, rtol=1e-3) 37 38 def testNoOp(self): 39 with self.test_session(): 40 with self.test_scope(): 41 output = control_flow_ops.no_op() 42 # This should not crash. 43 output.run() 44 45 def testConstants(self): 46 constants = [ 47 np.float32(42), 48 np.array([], dtype=np.float32), 49 np.array([1, 2], dtype=np.float32), 50 np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), 51 np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], 52 dtype=np.float32), 53 np.array([[[]], [[]]], dtype=np.float32), 54 np.array([[[[1]]]], dtype=np.float32), 55 ] 56 for c in constants: 57 self._testNullary(lambda c=c: constant_op.constant(c), expected=c) 58 59 60 if __name__ == "__main__": 61 googletest.main() 62