# encoding: utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Text processor tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os

from tensorflow.contrib.learn.python.learn.preprocessing import CategoricalVocabulary
from tensorflow.contrib.learn.python.learn.preprocessing import text
from tensorflow.python.platform import test


class TextTest(test.TestCase):
  """Text processor tests."""

  def testTokenizer(self):
    """Tokenizer splits documents into word lists.

    Punctuation such as "," is dropped, "-" is kept as its own token, and
    non-ASCII (Cyrillic, CJK) words must survive tokenization intact.
    """
    # NOTE: the Cyrillic/CJK inputs were garbled to bare spaces in a prior
    # revision, which made the expected [["", ""], ...] rows unsatisfiable
    # (the tokenizer yields [] for whitespace-only input). Restored here.
    words = text.tokenizer(
        ["a b c", "a\nb\nc", "a, b - c", "фыв фыв", "你好 怎么样"])
    self.assertEqual(
        list(words), [["a", "b", "c"], ["a", "b", "c"], ["a", "b", "-", "c"],
                      ["фыв", "фыв"], ["你好", "怎么样"]])

  def testByteProcessor(self):
    """ByteProcessor maps documents to fixed-length UTF-8 byte id rows.

    Rows shorter than max_document_length are zero-padded; longer ones are
    truncated. `reverse` must round-trip back to (possibly truncated) text.
    """
    processor = text.ByteProcessor(max_document_length=8)
    # Both a plain and an explicitly-unicode Cyrillic string are exercised
    # (identical under unicode_literals, but kept to document intent).
    # "фыва" encodes in UTF-8 to exactly the 8 bytes asserted below:
    # [209, 132, 209, 139, 208, 178, 208, 176].
    inp = ["abc", "фыва", u"фыва", b"abc", "12345678901234567890"]
    res = list(processor.fit_transform(inp))
    self.assertAllEqual(res, [[97, 98, 99, 0, 0, 0, 0, 0],
                              [209, 132, 209, 139, 208, 178, 208, 176],
                              [209, 132, 209, 139, 208, 178, 208, 176],
                              [97, 98, 99, 0, 0, 0, 0, 0],
                              [49, 50, 51, 52, 53, 54, 55, 56]])
    res = list(processor.reverse(res))
    # "12345678901234567890" is truncated to the first 8 bytes on reverse.
    self.assertAllEqual(res, ["abc", "фыва", "фыва", "abc", "12345678"])

  def testVocabularyProcessor(self):
    """Words are mapped to vocabulary ids; rows padded/trimmed to length 4."""
    vocab_processor = text.VocabularyProcessor(
        max_document_length=4, min_frequency=1)
    # min_frequency=1 trims tokens seen only once, so the lone "-" maps to
    # the out-of-vocabulary id 0 in the third row.
    tokens = vocab_processor.fit_transform(["a b c", "a\nb\nc", "a, b - c"])
    self.assertAllEqual(
        list(tokens), [[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 3]])

  def testVocabularyProcessorSaveRestore(self):
    """A saved vocabulary restores to an equivalent transform."""
    # os.path.join keeps the vocab file inside the temp dir; plain string
    # concatenation would create a sibling of the temp dir instead.
    filename = os.path.join(test.get_temp_dir(), "test.vocab")
    vocab_processor = text.VocabularyProcessor(
        max_document_length=4, min_frequency=1)
    tokens = vocab_processor.fit_transform(["a b c", "a\nb\nc", "a, b - c"])
    vocab_processor.save(filename)
    new_vocab = text.VocabularyProcessor.restore(filename)
    tokens = new_vocab.transform(["a b c"])
    self.assertAllEqual(list(tokens), [[1, 2, 3, 0]])

  def testExistingVocabularyProcessor(self):
    """A frozen, pre-built vocabulary maps unknown characters to id 0."""
    vocab = CategoricalVocabulary()
    vocab.get("A")  # id 1
    vocab.get("B")  # id 2
    vocab.freeze()  # unknown tokens ("C", "F") now map to 0
    # tokenizer_fn=list splits each document into single characters.
    vocab_processor = text.VocabularyProcessor(
        max_document_length=4, vocabulary=vocab, tokenizer_fn=list)
    tokens = vocab_processor.fit_transform(["ABC", "CBABAF"])
    self.assertAllEqual(list(tokens), [[1, 2, 0, 0], [0, 2, 1, 2]])


if __name__ == "__main__":
  test.main()