Home | History | Annotate | Download | only in utils
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for np_utils."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.keras._impl import keras
     24 from tensorflow.python.platform import test
     25 
     26 
     27 class TestNPUtils(test.TestCase):
     28 
     29   def test_to_categorical(self):
     30     num_classes = 5
     31     shapes = [(1,), (3,), (4, 3), (5, 4, 3), (3, 1), (3, 2, 1)]
     32     expected_shapes = [(1, num_classes),
     33                        (3, num_classes),
     34                        (4, 3, num_classes),
     35                        (5, 4, 3, num_classes),
     36                        (3, num_classes)]
     37     labels = [np.random.randint(0, num_classes, shape) for shape in shapes]
     38     one_hots = [
     39         keras.utils.to_categorical(label, num_classes) for label in labels]
     40     for label, one_hot, expected_shape in zip(labels,
     41                                               one_hots,
     42                                               expected_shapes):
     43       # Check shape
     44       self.assertEqual(one_hot.shape, expected_shape)
     45       # Make sure there is only one 1 in a row
     46       self.assertTrue(np.all(one_hot.sum(axis=-1) == 1))
     47       # Get original labels back from one hots
     48       self.assertTrue(np.all(
     49           np.argmax(one_hot, -1).reshape(label.shape) == label))
     50 
     51 
     52 if __name__ == '__main__':
     53   test.main()
     54