Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for DecodeRaw op from parsing_ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import gzip
     22 import zlib
     23 
     24 from six import BytesIO
     25 
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import parsing_ops
     29 from tensorflow.python.platform import test
     30 
     31 
     32 class DecodeCompressedOpTest(test.TestCase):
     33 
     34   def _compress(self, bytes_in, compression_type):
     35     if not compression_type:
     36       return bytes_in
     37     elif compression_type == "ZLIB":
     38       return zlib.compress(bytes_in)
     39     else:
     40       out = BytesIO()
     41       with gzip.GzipFile(fileobj=out, mode="wb") as f:
     42         f.write(bytes_in)
     43       return out.getvalue()
     44 
     45   def testDecompress(self):
     46     for compression_type in ["ZLIB", "GZIP", ""]:
     47       with self.test_session():
     48         in_bytes = array_ops.placeholder(dtypes.string, shape=[2])
     49         decompressed = parsing_ops.decode_compressed(
     50             in_bytes, compression_type=compression_type)
     51         self.assertEqual([2], decompressed.get_shape().as_list())
     52 
     53         result = decompressed.eval(
     54             feed_dict={in_bytes: [self._compress(b"AaAA", compression_type),
     55                                   self._compress(b"bBbb", compression_type)]})
     56         self.assertAllEqual([b"AaAA", b"bBbb"], result)
     57 
     58   def testDecompressWithRaw(self):
     59     for compression_type in ["ZLIB", "GZIP", ""]:
     60       with self.test_session():
     61         in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
     62         decompressed = parsing_ops.decode_compressed(
     63             in_bytes, compression_type=compression_type)
     64         decode = parsing_ops.decode_raw(decompressed, out_type=dtypes.int16)
     65 
     66         result = decode.eval(
     67             feed_dict={in_bytes: [self._compress(b"AaBC", compression_type)]})
     68         self.assertAllEqual(
     69             [[ord("A") + ord("a") * 256, ord("B") + ord("C") * 256]], result)
     70 
     71 
     72 if __name__ == "__main__":
     73   test.main()
     74