Home | History | Annotate | Download | only in test
      1 import imghdr
      2 import io
      3 import sys
      4 import unittest
      5 from test.test_support import findfile, TESTFN, unlink, run_unittest
      6 
      7 TEST_FILES = (
      8     ('python.png', 'png'),
      9     ('python.gif', 'gif'),
     10     ('python.bmp', 'bmp'),
     11     ('python.ppm', 'ppm'),
     12     ('python.pgm', 'pgm'),
     13     ('python.pbm', 'pbm'),
     14     ('python.jpg', 'jpeg'),
     15     ('python.ras', 'rast'),
     16     ('python.sgi', 'rgb'),
     17     ('python.tiff', 'tiff'),
     18     ('python.xbm', 'xbm')
     19 )
     20 
     21 class UnseekableIO(io.FileIO):
     22     def tell(self):
     23         raise io.UnsupportedOperation
     24 
     25     def seek(self, *args, **kwargs):
     26         raise io.UnsupportedOperation
     27 
     28 class TestImghdr(unittest.TestCase):
     29     @classmethod
     30     def setUpClass(cls):
     31         cls.testfile = findfile('python.png', subdir='imghdrdata')
     32         with open(cls.testfile, 'rb') as stream:
     33             cls.testdata = stream.read()
     34 
     35     def tearDown(self):
     36         unlink(TESTFN)
     37 
     38     def test_data(self):
     39         for filename, expected in TEST_FILES:
     40             filename = findfile(filename, subdir='imghdrdata')
     41             self.assertEqual(imghdr.what(filename), expected)
     42             ufilename = filename.decode(sys.getfilesystemencoding())
     43             self.assertEqual(imghdr.what(ufilename), expected)
     44             with open(filename, 'rb') as stream:
     45                 self.assertEqual(imghdr.what(stream), expected)
     46             with open(filename, 'rb') as stream:
     47                 data = stream.read()
     48             self.assertEqual(imghdr.what(None, data), expected)
     49 
     50     def test_register_test(self):
     51         def test_jumbo(h, file):
     52             if h.startswith(b'eggs'):
     53                 return 'ham'
     54         imghdr.tests.append(test_jumbo)
     55         self.addCleanup(imghdr.tests.pop)
     56         self.assertEqual(imghdr.what(None, b'eggs'), 'ham')
     57 
     58     def test_file_pos(self):
     59         with open(TESTFN, 'wb') as stream:
     60             stream.write(b'ababagalamaga')
     61             pos = stream.tell()
     62             stream.write(self.testdata)
     63         with open(TESTFN, 'rb') as stream:
     64             stream.seek(pos)
     65             self.assertEqual(imghdr.what(stream), 'png')
     66             self.assertEqual(stream.tell(), pos)
     67 
     68     def test_bad_args(self):
     69         with self.assertRaises(TypeError):
     70             imghdr.what()
     71         with self.assertRaises(AttributeError):
     72             imghdr.what(None)
     73         with self.assertRaises(TypeError):
     74             imghdr.what(self.testfile, 1)
     75         with open(self.testfile, 'rb') as f:
     76             with self.assertRaises(AttributeError):
     77                 imghdr.what(f.fileno())
     78 
     79     def test_invalid_headers(self):
     80         for header in (b'\211PN\r\n',
     81                        b'\001\331',
     82                        b'\x59\xA6',
     83                        b'cutecat',
     84                        b'000000JFI',
     85                        b'GIF80'):
     86             self.assertIsNone(imghdr.what(None, header))
     87 
     88     def test_missing_file(self):
     89         with self.assertRaises(IOError):
     90             imghdr.what('missing')
     91 
     92     def test_closed_file(self):
     93         stream = open(self.testfile, 'rb')
     94         stream.close()
     95         with self.assertRaises(ValueError) as cm:
     96             imghdr.what(stream)
     97         stream = io.BytesIO(self.testdata)
     98         stream.close()
     99         with self.assertRaises(ValueError) as cm:
    100             imghdr.what(stream)
    101 
    102     def test_unseekable(self):
    103         with open(TESTFN, 'wb') as stream:
    104             stream.write(self.testdata)
    105         with UnseekableIO(TESTFN, 'rb') as stream:
    106             with self.assertRaises(io.UnsupportedOperation):
    107                 imghdr.what(stream)
    108 
    109     def test_output_stream(self):
    110         with open(TESTFN, 'wb') as stream:
    111             stream.write(self.testdata)
    112             stream.seek(0)
    113             with self.assertRaises(IOError) as cm:
    114                 imghdr.what(stream)
    115 
    116 def test_main():
    117     run_unittest(TestImghdr)
    118 
    119 if __name__ == '__main__':
    120     test_main()
    121