Home | History | Annotate | Download | only in test
      1 import imghdr
      2 import io
      3 import os
      4 import pathlib
      5 import unittest
      6 import warnings
      7 from test.support import findfile, TESTFN, unlink
      8 
# Sample images exercised by the tests. Each entry pairs a file name, which
# is resolved with findfile() inside the 'imghdrdata' test-data directory,
# with the type name that imghdr.what() is expected to return for it.
TEST_FILES = (
    ('python.png', 'png'),
    ('python.gif', 'gif'),
    ('python.bmp', 'bmp'),
    ('python.ppm', 'ppm'),
    ('python.pgm', 'pgm'),
    ('python.pbm', 'pbm'),
    ('python.jpg', 'jpeg'),
    ('python.ras', 'rast'),
    ('python.sgi', 'rgb'),
    ('python.tiff', 'tiff'),
    ('python.xbm', 'xbm'),
    ('python.webp', 'webp'),
    ('python.exr', 'exr'),
)
     24 
     25 class UnseekableIO(io.FileIO):
     26     def tell(self):
     27         raise io.UnsupportedOperation
     28 
     29     def seek(self, *args, **kwargs):
     30         raise io.UnsupportedOperation
     31 
     32 class TestImghdr(unittest.TestCase):
     33     @classmethod
     34     def setUpClass(cls):
     35         cls.testfile = findfile('python.png', subdir='imghdrdata')
     36         with open(cls.testfile, 'rb') as stream:
     37             cls.testdata = stream.read()
     38 
     39     def tearDown(self):
     40         unlink(TESTFN)
     41 
     42     def test_data(self):
     43         for filename, expected in TEST_FILES:
     44             filename = findfile(filename, subdir='imghdrdata')
     45             self.assertEqual(imghdr.what(filename), expected)
     46             with open(filename, 'rb') as stream:
     47                 self.assertEqual(imghdr.what(stream), expected)
     48             with open(filename, 'rb') as stream:
     49                 data = stream.read()
     50             self.assertEqual(imghdr.what(None, data), expected)
     51             self.assertEqual(imghdr.what(None, bytearray(data)), expected)
     52 
     53     def test_pathlike_filename(self):
     54         for filename, expected in TEST_FILES:
     55             with self.subTest(filename=filename):
     56                 filename = findfile(filename, subdir='imghdrdata')
     57                 self.assertEqual(imghdr.what(pathlib.Path(filename)), expected)
     58 
     59     def test_register_test(self):
     60         def test_jumbo(h, file):
     61             if h.startswith(b'eggs'):
     62                 return 'ham'
     63         imghdr.tests.append(test_jumbo)
     64         self.addCleanup(imghdr.tests.pop)
     65         self.assertEqual(imghdr.what(None, b'eggs'), 'ham')
     66 
     67     def test_file_pos(self):
     68         with open(TESTFN, 'wb') as stream:
     69             stream.write(b'ababagalamaga')
     70             pos = stream.tell()
     71             stream.write(self.testdata)
     72         with open(TESTFN, 'rb') as stream:
     73             stream.seek(pos)
     74             self.assertEqual(imghdr.what(stream), 'png')
     75             self.assertEqual(stream.tell(), pos)
     76 
     77     def test_bad_args(self):
     78         with self.assertRaises(TypeError):
     79             imghdr.what()
     80         with self.assertRaises(AttributeError):
     81             imghdr.what(None)
     82         with self.assertRaises(TypeError):
     83             imghdr.what(self.testfile, 1)
     84         with self.assertRaises(AttributeError):
     85             imghdr.what(os.fsencode(self.testfile))
     86         with open(self.testfile, 'rb') as f:
     87             with self.assertRaises(AttributeError):
     88                 imghdr.what(f.fileno())
     89 
     90     def test_invalid_headers(self):
     91         for header in (b'\211PN\r\n',
     92                        b'\001\331',
     93                        b'\x59\xA6',
     94                        b'cutecat',
     95                        b'000000JFI',
     96                        b'GIF80'):
     97             self.assertIsNone(imghdr.what(None, header))
     98 
     99     def test_string_data(self):
    100         with warnings.catch_warnings():
    101             warnings.simplefilter("ignore", BytesWarning)
    102             for filename, _ in TEST_FILES:
    103                 filename = findfile(filename, subdir='imghdrdata')
    104                 with open(filename, 'rb') as stream:
    105                     data = stream.read().decode('latin1')
    106                 with self.assertRaises(TypeError):
    107                     imghdr.what(io.StringIO(data))
    108                 with self.assertRaises(TypeError):
    109                     imghdr.what(None, data)
    110 
    111     def test_missing_file(self):
    112         with self.assertRaises(FileNotFoundError):
    113             imghdr.what('missing')
    114 
    115     def test_closed_file(self):
    116         stream = open(self.testfile, 'rb')
    117         stream.close()
    118         with self.assertRaises(ValueError) as cm:
    119             imghdr.what(stream)
    120         stream = io.BytesIO(self.testdata)
    121         stream.close()
    122         with self.assertRaises(ValueError) as cm:
    123             imghdr.what(stream)
    124 
    125     def test_unseekable(self):
    126         with open(TESTFN, 'wb') as stream:
    127             stream.write(self.testdata)
    128         with UnseekableIO(TESTFN, 'rb') as stream:
    129             with self.assertRaises(io.UnsupportedOperation):
    130                 imghdr.what(stream)
    131 
    132     def test_output_stream(self):
    133         with open(TESTFN, 'wb') as stream:
    134             stream.write(self.testdata)
    135             stream.seek(0)
    136             with self.assertRaises(OSError) as cm:
    137                 imghdr.what(stream)
    138 
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
    141