1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for PrecisionOp.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import errors_impl 25 from tensorflow.python.ops import nn_ops 26 from tensorflow.python.platform import test 27 28 29 class InTopKTest(test.TestCase): 30 31 def _validateInTopK(self, predictions, target, k, expected): 32 np_ans = np.array(expected) 33 with self.test_session(): 34 precision = nn_ops.in_top_k(predictions, target, k) 35 out = precision.eval() 36 self.assertAllClose(np_ans, out) 37 self.assertShapeEqual(np_ans, precision) 38 39 def testInTop1(self): 40 predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] 41 target = [3, 1] 42 self._validateInTopK(predictions, target, 1, [True, False]) 43 44 def testInTop2(self): 45 predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] 46 target = [0, 2] 47 self._validateInTopK(predictions, target, 2, [False, True]) 48 49 def testInTop2Tie(self): 50 # Class 2 and 3 tie for 2nd, so both are considered in top 2. 51 predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]] 52 target = [2, 3] 53 self._validateInTopK(predictions, target, 2, [True, True]) 54 55 def testInTop2_int64Target(self): 56 predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] 57 target = np.asarray([0, 2]).astype(np.int64) 58 self._validateInTopK(predictions, target, 2, [False, True]) 59 60 def testInTopNan(self): 61 predictions = [[0.1, float("nan"), 0.2, 0.4], [0.1, 0.2, 0.3, float("inf")]] 62 target = [0, 2] 63 self._validateInTopK(predictions, target, 2, [False, False]) 64 65 def testBadTarget(self): 66 predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] 67 target = [0, 80000] 68 with self.test_session(): 69 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 70 "target.*out of range"): 71 nn_ops.in_top_k(predictions, target, 2).eval() 72 73 def testTensorK(self): 74 predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] 75 target = [0, 2] 76 k = constant_op.constant(3) 77 np_ans = np.array([False, True]) 78 with self.test_session(): 79 precision = nn_ops.in_top_k(predictions, target, k) 80 out = precision.eval() 81 self.assertAllClose(np_ans, out) 82 self.assertShapeEqual(np_ans, precision) 83 84 if __name__ == "__main__": 85 test.main() 86