1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for DecodeRaw op from parsing_ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import gzip 22 import zlib 23 24 from six import BytesIO 25 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import parsing_ops 29 from tensorflow.python.platform import test 30 31 32 class DecodeCompressedOpTest(test.TestCase): 33 34 def _compress(self, bytes_in, compression_type): 35 if not compression_type: 36 return bytes_in 37 elif compression_type == "ZLIB": 38 return zlib.compress(bytes_in) 39 else: 40 out = BytesIO() 41 with gzip.GzipFile(fileobj=out, mode="wb") as f: 42 f.write(bytes_in) 43 return out.getvalue() 44 45 def testDecompress(self): 46 for compression_type in ["ZLIB", "GZIP", ""]: 47 with self.test_session(): 48 in_bytes = array_ops.placeholder(dtypes.string, shape=[2]) 49 decompressed = parsing_ops.decode_compressed( 50 in_bytes, compression_type=compression_type) 51 self.assertEqual([2], decompressed.get_shape().as_list()) 52 53 result = decompressed.eval( 54 feed_dict={in_bytes: [self._compress(b"AaAA", compression_type), 55 self._compress(b"bBbb", compression_type)]}) 56 self.assertAllEqual([b"AaAA", b"bBbb"], result) 57 58 def testDecompressWithRaw(self): 59 for compression_type in ["ZLIB", "GZIP", ""]: 60 with self.test_session(): 61 in_bytes = array_ops.placeholder(dtypes.string, shape=[None]) 62 decompressed = parsing_ops.decode_compressed( 63 in_bytes, compression_type=compression_type) 64 decode = parsing_ops.decode_raw(decompressed, out_type=dtypes.int16) 65 66 result = decode.eval( 67 feed_dict={in_bytes: [self._compress(b"AaBC", compression_type)]}) 68 self.assertAllEqual( 69 [[ord("A") + ord("a") * 256, ord("B") + ord("C") * 256]], result) 70 71 72 if __name__ == "__main__": 73 test.main() 74