1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for np_utils.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.keras._impl import keras 24 from tensorflow.python.platform import test 25 26 27 class TestNPUtils(test.TestCase): 28 29 def test_to_categorical(self): 30 num_classes = 5 31 shapes = [(1,), (3,), (4, 3), (5, 4, 3), (3, 1), (3, 2, 1)] 32 expected_shapes = [(1, num_classes), 33 (3, num_classes), 34 (4, 3, num_classes), 35 (5, 4, 3, num_classes), 36 (3, num_classes)] 37 labels = [np.random.randint(0, num_classes, shape) for shape in shapes] 38 one_hots = [ 39 keras.utils.to_categorical(label, num_classes) for label in labels] 40 for label, one_hot, expected_shape in zip(labels, 41 one_hots, 42 expected_shapes): 43 # Check shape 44 self.assertEqual(one_hot.shape, expected_shape) 45 # Make sure there is only one 1 in a row 46 self.assertTrue(np.all(one_hot.sum(axis=-1) == 1)) 47 # Get original labels back from one hots 48 self.assertTrue(np.all( 49 np.argmax(one_hot, -1).reshape(label.shape) == label)) 50 51 52 if __name__ == '__main__': 53 test.main() 54