Home | History | Annotate | Download | only in preprocessing
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Categorical vocabulary classes to map categories to indexes.
     17 
     18 Can be used for categorical variables, sparse variables and words.
     19 """
     20 
     21 from __future__ import absolute_import
     22 from __future__ import division
     23 from __future__ import print_function
     24 
     25 import collections
     26 import six
     27 
     28 
     29 class CategoricalVocabulary(object):
     30   """Categorical variables vocabulary class.
     31 
     32   Accumulates and provides mapping from classes to indexes.
     33   Can be easily used for words.
     34   """
     35 
     36   def __init__(self, unknown_token="<UNK>", support_reverse=True):
     37     self._unknown_token = unknown_token
     38     self._mapping = {unknown_token: 0}
     39     self._support_reverse = support_reverse
     40     if support_reverse:
     41       self._reverse_mapping = [unknown_token]
     42     self._freq = collections.defaultdict(int)
     43     self._freeze = False
     44 
     45   def __len__(self):
     46     """Returns total count of mappings. Including unknown token."""
     47     return len(self._mapping)
     48 
     49   def freeze(self, freeze=True):
     50     """Freezes the vocabulary, after which new words return unknown token id.
     51 
     52     Args:
     53       freeze: True to freeze, False to unfreeze.
     54     """
     55     self._freeze = freeze
     56 
     57   def get(self, category):
     58     """Returns word's id in the vocabulary.
     59 
     60     If category is new, creates a new id for it.
     61 
     62     Args:
     63       category: string or integer to lookup in vocabulary.
     64 
     65     Returns:
     66       interger, id in the vocabulary.
     67     """
     68     if category not in self._mapping:
     69       if self._freeze:
     70         return 0
     71       self._mapping[category] = len(self._mapping)
     72       if self._support_reverse:
     73         self._reverse_mapping.append(category)
     74     return self._mapping[category]
     75 
     76   def add(self, category, count=1):
     77     """Adds count of the category to the frequency table.
     78 
     79     Args:
     80       category: string or integer, category to add frequency to.
     81       count: optional integer, how many to add.
     82     """
     83     category_id = self.get(category)
     84     if category_id <= 0:
     85       return
     86     self._freq[category] += count
     87 
     88   def trim(self, min_frequency, max_frequency=-1):
     89     """Trims vocabulary for minimum frequency.
     90 
     91     Remaps ids from 1..n in sort frequency order.
     92     where n - number of elements left.
     93 
     94     Args:
     95       min_frequency: minimum frequency to keep.
     96       max_frequency: optional, maximum frequency to keep.
     97         Useful to remove very frequent categories (like stop words).
     98     """
     99     # Sort by alphabet then reversed frequency.
    100     self._freq = sorted(
    101         sorted(
    102             six.iteritems(self._freq),
    103             key=lambda x: (isinstance(x[0], str), x[0])),
    104         key=lambda x: x[1],
    105         reverse=True)
    106     self._mapping = {self._unknown_token: 0}
    107     if self._support_reverse:
    108       self._reverse_mapping = [self._unknown_token]
    109     idx = 1
    110     for category, count in self._freq:
    111       if max_frequency > 0 and count >= max_frequency:
    112         continue
    113       if count <= min_frequency:
    114         break
    115       self._mapping[category] = idx
    116       idx += 1
    117       if self._support_reverse:
    118         self._reverse_mapping.append(category)
    119     self._freq = dict(self._freq[:idx - 1])
    120 
    121   def reverse(self, class_id):
    122     """Given class id reverse to original class name.
    123 
    124     Args:
    125       class_id: Id of the class.
    126 
    127     Returns:
    128       Class name.
    129 
    130     Raises:
    131       ValueError: if this vocabulary wasn't initialized with support_reverse.
    132     """
    133     if not self._support_reverse:
    134       raise ValueError("This vocabulary wasn't initialized with "
    135                        "support_reverse to support reverse() function.")
    136     return self._reverse_mapping[class_id]
    137