Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Clustering Operations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.factorization.python.ops import gen_clustering_ops
     22 # go/tf-wildcard-import
     23 # pylint: disable=wildcard-import
     24 from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import *
     25 # pylint: enable=wildcard-import
     26 from tensorflow.contrib.util import loader
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import check_ops
     32 from tensorflow.python.ops import control_flow_ops
     33 from tensorflow.python.ops import math_ops
     34 from tensorflow.python.ops import nn_impl
     35 from tensorflow.python.ops import random_ops
     36 from tensorflow.python.ops import state_ops
     37 from tensorflow.python.ops import variable_scope
     38 from tensorflow.python.ops.embedding_ops import embedding_lookup
     39 from tensorflow.python.platform import resource_loader
     40 
     41 _clustering_ops = loader.load_op_library(
     42     resource_loader.get_path_to_datafile('_clustering_ops.so'))
     43 
     44 # Euclidean distance between vectors U and V is defined as ||U - V||_F which is
     45 # the square root of the sum of the absolute squares of the elements difference.
     46 SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
     47 # Cosine distance between vectors U and V is defined as
     48 # 1 - (U \dot V) / (||U||_F ||V||_F)
     49 COSINE_DISTANCE = 'cosine'
     50 
     51 RANDOM_INIT = 'random'
     52 KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
     53 KMC2_INIT = 'kmc2'
     54 
     55 # The name of the variable holding the cluster centers. Used by the Estimator.
     56 CLUSTERS_VAR_NAME = 'clusters'
     57 
     58 
     59 class KMeans(object):
     60   """Creates the graph for k-means clustering."""
     61 
     62   def __init__(self,
     63                inputs,
     64                num_clusters,
     65                initial_clusters=RANDOM_INIT,
     66                distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
     67                use_mini_batch=False,
     68                mini_batch_steps_per_iteration=1,
     69                random_seed=0,
     70                kmeans_plus_plus_num_retries=2,
     71                kmc2_chain_length=200):
     72     """Creates an object for generating KMeans clustering graph.
     73 
     74     This class implements the following variants of K-means algorithm:
     75 
     76     If use_mini_batch is False, it runs standard full batch K-means. Each step
     77     runs a single iteration of K-Means. This step can be run sharded across
     78     multiple workers by passing a list of sharded inputs to this class. Note
     79     however that a single step needs to process the full input at once.
     80 
     81     If use_mini_batch is True, it runs a generalization of the mini-batch
     82     K-means algorithm. It runs multiple iterations, where each iteration is
     83     composed of mini_batch_steps_per_iteration steps. Two copies of cluster
     84     centers are maintained: one that is updated at the end of each iteration,
     85     and one that is updated every step. The first copy is used to compute
     86     cluster allocations for each step, and for inference, while the second copy
     87     is the one updated each step using the mini-batch update rule. After each
     88     iteration is complete, this second copy is copied back the first copy.
     89 
     90     Note that for use_mini_batch=True, when mini_batch_steps_per_iteration=1,
     91     the algorithm reduces to the standard mini-batch algorithm. Also by setting
     92     mini_batch_steps_per_iteration = num_inputs / batch_size, the algorithm
     93     becomes an asynchronous version of the full-batch algorithm. Note however
     94     that there is no guarantee by this implementation that each input is seen
     95     exactly once per iteration. Also, different updates are applied
     96     asynchronously without locking. So this asynchronous version may not behave
     97     exactly like a full-batch version.
     98 
     99     Args:
    100       inputs: An input tensor or list of input tensors. It is assumed that the
    101         data points have been previously randomly permuted.
    102       num_clusters: An integer tensor specifying the number of clusters. This
    103         argument is ignored if initial_clusters is a tensor or numpy array.
    104       initial_clusters: Specifies the clusters used during initialization. One
    105         of the following:
    106         - a tensor or numpy array with the initial cluster centers.
    107         - a function f(inputs, k) that returns up to k centers from `inputs`.
    108         - "random": Choose centers randomly from `inputs`.
    109         - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
    110         - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
    111         In the last three cases, one batch of `inputs` may not yield
    112         `num_clusters` centers, in which case initialization will require
    113         multiple batches until enough centers are chosen. In the case of
    114         "random" or "kmeans_plus_plus", if the input size is <= `num_clusters`
    115         then the entire batch is chosen to be cluster centers.
    116       distance_metric: Distance metric used for clustering. Supported options:
    117         "squared_euclidean", "cosine".
    118       use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
    119         full batch.
    120       mini_batch_steps_per_iteration: Number of steps after which the updated
    121         cluster centers are synced back to a master copy.
    122       random_seed: Seed for PRNG used to initialize seeds.
    123       kmeans_plus_plus_num_retries: For each point that is sampled during
    124         kmeans++ initialization, this parameter specifies the number of
    125         additional points to draw from the current distribution before selecting
    126         the best. If a negative value is specified, a heuristic is used to
    127         sample O(log(num_to_sample)) additional points.
    128       kmc2_chain_length: Determines how many candidate points are used by the
    129         k-MC2 algorithm to produce one new cluster centers. If a (mini-)batch
    130         contains less points, one new cluster center is generated from the
    131         (mini-)batch.
    132 
    133     Raises:
    134       ValueError: An invalid argument was passed to initial_clusters or
    135         distance_metric.
    136     """
    137     if isinstance(initial_clusters, str) and initial_clusters not in [
    138         RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
    139     ]:
    140       raise ValueError(
    141           "Unsupported initialization algorithm '%s'" % initial_clusters)
    142     if distance_metric not in [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]:
    143       raise ValueError("Unsupported distance metric '%s'" % distance_metric)
    144     self._inputs = inputs if isinstance(inputs, list) else [inputs]
    145     self._num_clusters = num_clusters
    146     self._initial_clusters = initial_clusters
    147     self._distance_metric = distance_metric
    148     self._use_mini_batch = use_mini_batch
    149     self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
    150     self._random_seed = random_seed
    151     self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    152     self._kmc2_chain_length = kmc2_chain_length
    153 
    154   @classmethod
    155   def _distance_graph(cls, inputs, clusters, distance_metric):
    156     """Computes distance between each input and each cluster center.
    157 
    158     Args:
    159       inputs: list of input Tensors.
    160       clusters: cluster Tensor.
    161       distance_metric: distance metric used for clustering
    162 
    163     Returns:
    164       list of Tensors, where each element corresponds to each element in inputs.
    165       The value is the distance of each row to all the cluster centers.
    166       Currently only Euclidean distance and cosine distance are supported.
    167     """
    168     assert isinstance(inputs, list)
    169     if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
    170       return cls._compute_euclidean_distance(inputs, clusters)
    171     elif distance_metric == COSINE_DISTANCE:
    172       return cls._compute_cosine_distance(
    173           inputs, clusters, inputs_normalized=True)
    174     else:
    175       assert False, str(distance_metric)
    176 
    177   @classmethod
    178   def _compute_euclidean_distance(cls, inputs, clusters):
    179     """Computes Euclidean distance between each input and each cluster center.
    180 
    181     Args:
    182       inputs: list of input Tensors.
    183       clusters: cluster Tensor.
    184 
    185     Returns:
    186       list of Tensors, where each element corresponds to each element in inputs.
    187       The value is the distance of each row to all the cluster centers.
    188     """
    189     output = []
    190     for inp in inputs:
    191       with ops.colocate_with(inp, ignore_existing=True):
    192         # Computes Euclidean distance. Note the first and third terms are
    193         # broadcast additions.
    194         squared_distance = (
    195             math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
    196             2 * math_ops.matmul(inp, clusters, transpose_b=True) +
    197             array_ops.transpose(
    198                 math_ops.reduce_sum(
    199                     math_ops.square(clusters), 1, keepdims=True)))
    200         output.append(squared_distance)
    201 
    202     return output
    203 
    204   @classmethod
    205   def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
    206     """Computes cosine distance between each input and each cluster center.
    207 
    208     Args:
    209       inputs: list of input Tensor.
    210       clusters: cluster Tensor
    211       inputs_normalized: if True, it assumes that inp and clusters are
    212       normalized and computes the dot product which is equivalent to the cosine
    213       distance. Else it L2 normalizes the inputs first.
    214 
    215     Returns:
    216       list of Tensors, where each element corresponds to each element in inp.
    217       The value is the distance of each row to all the cluster centers.
    218     """
    219     output = []
    220     if not inputs_normalized:
    221       with ops.colocate_with(clusters, ignore_existing=True):
    222         clusters = nn_impl.l2_normalize(clusters, dim=1)
    223     for inp in inputs:
    224       with ops.colocate_with(inp, ignore_existing=True):
    225         if not inputs_normalized:
    226           inp = nn_impl.l2_normalize(inp, dim=1)
    227         output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
    228     return output
    229 
    230   def _infer_graph(self, inputs, clusters):
    231     """Maps input to closest cluster and the score.
    232 
    233     Args:
    234       inputs: list of input Tensors.
    235       clusters: Tensor of cluster centers.
    236 
    237     Returns:
    238       List of tuple, where each value in tuple corresponds to a value in inp.
    239       The tuple has following three elements:
    240       all_scores: distance of each input to each cluster center.
    241       score: distance of each input to closest cluster center.
    242       cluster_idx: index of cluster center closest to the corresponding input.
    243     """
    244     assert isinstance(inputs, list)
    245     # Pairwise distances are used only by transform(). In all other cases, this
    246     # sub-graph is not evaluated.
    247     scores = self._distance_graph(inputs, clusters, self._distance_metric)
    248     output = []
    249     if (self._distance_metric == COSINE_DISTANCE and
    250         not self._clusters_l2_normalized()):
    251       # The cosine distance between normalized vectors x and y is the same as
    252       # 2 * squared_euclidean_distance. We are using this fact and reusing the
    253       # nearest_neighbors op.
    254       # TODO(ands): Support COSINE distance in nearest_neighbors and remove
    255       # this.
    256       with ops.colocate_with(clusters, ignore_existing=True):
    257         clusters = nn_impl.l2_normalize(clusters, dim=1)
    258     for inp, score in zip(inputs, scores):
    259       with ops.colocate_with(inp, ignore_existing=True):
    260         (indices, distances) = gen_clustering_ops.nearest_neighbors(
    261             inp, clusters, 1)
    262         if self._distance_metric == COSINE_DISTANCE:
    263           distances *= 0.5
    264         output.append((score, array_ops.squeeze(distances, [-1]),
    265                        array_ops.squeeze(indices, [-1])))
    266     return zip(*output)
    267 
    268   def _clusters_l2_normalized(self):
    269     """Returns True if clusters centers are kept normalized."""
    270     return (self._distance_metric == COSINE_DISTANCE and
    271             (not self._use_mini_batch or
    272              self._mini_batch_steps_per_iteration > 1))
    273 
    274   def _create_variables(self, num_clusters):
    275     """Creates variables.
    276 
    277     Args:
    278       num_clusters: an integer Tensor providing the number of clusters.
    279 
    280     Returns:
    281       Tuple with following elements:
    282       - cluster_centers: a Tensor for storing cluster centers
    283       - cluster_centers_initialized: bool Variable indicating whether clusters
    284             are initialized.
    285       - cluster_counts: a Tensor for storing counts of points assigned to this
    286             cluster. This is used by mini-batch training.
    287       - cluster_centers_updated: Tensor representing copy of cluster centers
    288             that are updated every step.
    289       - update_in_steps: numbers of steps left before we sync
    290             cluster_centers_updated back to cluster_centers.
    291     """
    292     init_value = array_ops.constant([], dtype=dtypes.float32)
    293     cluster_centers = variable_scope.variable(
    294         init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
    295     cluster_centers_initialized = variable_scope.variable(
    296         False, dtype=dtypes.bool, name='initialized')
    297 
    298     if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
    299       # Copy of cluster centers actively updated each step according to
    300       # mini-batch update rule.
    301       cluster_centers_updated = variable_scope.variable(
    302           init_value, name='clusters_updated', validate_shape=False)
    303       # How many steps till we copy the updated clusters to cluster_centers.
    304       update_in_steps = variable_scope.variable(
    305           self._mini_batch_steps_per_iteration,
    306           dtype=dtypes.int64,
    307           name='update_in_steps')
    308       # Count of points assigned to cluster_centers_updated.
    309       cluster_counts = variable_scope.variable(
    310           array_ops.zeros([num_clusters], dtype=dtypes.int64))
    311     else:
    312       cluster_centers_updated = cluster_centers
    313       update_in_steps = None
    314       cluster_counts = (
    315           variable_scope.variable(
    316               array_ops.ones([num_clusters], dtype=dtypes.int64))
    317           if self._use_mini_batch else None)
    318     return (cluster_centers, cluster_centers_initialized, cluster_counts,
    319             cluster_centers_updated, update_in_steps)
    320 
    321   @classmethod
    322   def _l2_normalize_data(cls, inputs):
    323     """Normalized the input data."""
    324     output = []
    325     for inp in inputs:
    326       with ops.colocate_with(inp, ignore_existing=True):
    327         output.append(nn_impl.l2_normalize(inp, dim=1))
    328     return output
    329 
    330   def training_graph(self):
    331     """Generate a training graph for kmeans algorithm.
    332 
    333     This returns, among other things, an op that chooses initial centers
    334     (init_op), a boolean variable that is set to True when the initial centers
    335     are chosen (cluster_centers_initialized), and an op to perform either an
    336     entire Lloyd iteration or a mini-batch of a Lloyd iteration (training_op).
    337     The caller should use these components as follows. A single worker should
    338     execute init_op multiple times until cluster_centers_initialized becomes
    339     True. Then multiple workers may execute training_op any number of times.
    340 
    341     Returns:
    342       A tuple consisting of:
    343       all_scores: A matrix (or list of matrices) of dimensions (num_input,
    344         num_clusters) where the value is the distance of an input vector and a
    345         cluster center.
    346       cluster_idx: A vector (or list of vectors). Each element in the vector
    347         corresponds to an input row in 'inp' and specifies the cluster id
    348         corresponding to the input.
    349       scores: Similar to cluster_idx but specifies the distance to the
    350         assigned cluster instead.
    351       cluster_centers_initialized: scalar indicating whether clusters have been
    352         initialized.
    353       init_op: an op to initialize the clusters.
    354       training_op: an op that runs an iteration of training.
    355     """
    356     # Implementation of kmeans.
    357     if (isinstance(self._initial_clusters, str) or
    358         callable(self._initial_clusters)):
    359       initial_clusters = self._initial_clusters
    360       num_clusters = ops.convert_to_tensor(self._num_clusters)
    361     else:
    362       initial_clusters = ops.convert_to_tensor(self._initial_clusters)
    363       num_clusters = array_ops.shape(initial_clusters)[0]
    364 
    365     inputs = self._inputs
    366     (cluster_centers_var, cluster_centers_initialized, total_counts,
    367      cluster_centers_updated,
    368      update_in_steps) = self._create_variables(num_clusters)
    369     init_op = _InitializeClustersOpFactory(
    370         self._inputs, num_clusters, initial_clusters, self._distance_metric,
    371         self._random_seed, self._kmeans_plus_plus_num_retries,
    372         self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
    373         cluster_centers_initialized).op()
    374     cluster_centers = cluster_centers_var
    375 
    376     if self._distance_metric == COSINE_DISTANCE:
    377       inputs = self._l2_normalize_data(inputs)
    378       if not self._clusters_l2_normalized():
    379         cluster_centers = nn_impl.l2_normalize(cluster_centers, dim=1)
    380 
    381     all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
    382     if self._use_mini_batch:
    383       sync_updates_op = self._mini_batch_sync_updates_op(
    384           update_in_steps, cluster_centers_var, cluster_centers_updated,
    385           total_counts)
    386       assert sync_updates_op is not None
    387       with ops.control_dependencies([sync_updates_op]):
    388         training_op = self._mini_batch_training_op(
    389             inputs, cluster_idx, cluster_centers_updated, total_counts)
    390     else:
    391       assert cluster_centers == cluster_centers_var
    392       training_op = self._full_batch_training_op(
    393           inputs, num_clusters, cluster_idx, cluster_centers_var)
    394 
    395     return (all_scores, cluster_idx, scores, cluster_centers_initialized,
    396             init_op, training_op)
    397 
    398   def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
    399                                   cluster_centers_updated, total_counts):
    400     if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
    401       assert update_in_steps is not None
    402       with ops.colocate_with(update_in_steps, ignore_existing=True):
    403 
    404         def _f():
    405           # Note that there is a race condition here, so we do a best effort
    406           # updates here. We reset update_in_steps first so that other workers
    407           # don't duplicate the updates. Also we update cluster_center_vars
    408           # before resetting total_counts to avoid large updates to
    409           # cluster_centers_updated based on partially updated
    410           # cluster_center_vars.
    411           with ops.control_dependencies([
    412               state_ops.assign(update_in_steps,
    413                                self._mini_batch_steps_per_iteration - 1)
    414           ]):
    415             with ops.colocate_with(
    416                 cluster_centers_updated, ignore_existing=True):
    417               if self._distance_metric == COSINE_DISTANCE:
    418                 cluster_centers = nn_impl.l2_normalize(
    419                     cluster_centers_updated, dim=1)
    420               else:
    421                 cluster_centers = cluster_centers_updated
    422             with ops.colocate_with(cluster_centers_var, ignore_existing=True):
    423               with ops.control_dependencies(
    424                   [state_ops.assign(cluster_centers_var, cluster_centers)]):
    425                 with ops.colocate_with(None, ignore_existing=True):
    426                   with ops.control_dependencies([
    427                       state_ops.assign(total_counts,
    428                                        array_ops.zeros_like(total_counts))
    429                   ]):
    430                     return array_ops.identity(update_in_steps)
    431 
    432         return control_flow_ops.cond(
    433             update_in_steps <= 0, _f,
    434             lambda: state_ops.assign_sub(update_in_steps, 1))
    435     else:
    436       return control_flow_ops.no_op()
    437 
    438   def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
    439                               total_counts):
    440     """Creates an op for training for mini batch case.
    441 
    442     Args:
    443       inputs: list of input Tensors.
    444       cluster_idx_list: A vector (or list of vectors). Each element in the
    445         vector corresponds to an input row in 'inp' and specifies the cluster id
    446         corresponding to the input.
    447       cluster_centers: Tensor Ref of cluster centers.
    448       total_counts: Tensor Ref of cluster counts.
    449 
    450     Returns:
    451       An op for doing an update of mini-batch k-means.
    452     """
    453     update_ops = []
    454     for inp, cluster_idx in zip(inputs, cluster_idx_list):
    455       with ops.colocate_with(inp, ignore_existing=True):
    456         assert total_counts is not None
    457         cluster_idx = array_ops.reshape(cluster_idx, [-1])
    458         # Dedupe the unique ids of cluster_centers being updated so that updates
    459         # can be locally aggregated.
    460         unique_ids, unique_idx = array_ops.unique(cluster_idx)
    461         num_unique_cluster_idx = array_ops.size(unique_ids)
    462         # Fetch the old values of counts and cluster_centers.
    463         with ops.colocate_with(total_counts, ignore_existing=True):
    464           old_counts = array_ops.gather(total_counts, unique_ids)
    465         # TODO(agarwal): This colocation seems to run into problems. Fix it.
    466         with ops.colocate_with(cluster_centers, ignore_existing=True):
    467           old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
    468         # Locally aggregate the increment to counts.
    469         count_updates = math_ops.unsorted_segment_sum(
    470             array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
    471             unique_idx, num_unique_cluster_idx)
    472         # Locally compute the sum of inputs mapped to each id.
    473         # For a cluster with old cluster value x, old count n, and with data
    474         # d_1,...d_k newly assigned to it, we recompute the new value as
    475         # x += (sum_i(d_i) - k * x) / (n + k).
    476         # Compute sum_i(d_i), see comment above.
    477         cluster_center_updates = math_ops.unsorted_segment_sum(
    478             inp, unique_idx, num_unique_cluster_idx)
    479         # Shape to enable broadcasting count_updates and learning_rate to inp.
    480         # It extends the shape with 1's to match the rank of inp.
    481         broadcast_shape = array_ops.concat([
    482             array_ops.reshape(num_unique_cluster_idx, [1]),
    483             array_ops.ones(
    484                 array_ops.reshape(array_ops.rank(inp) - 1, [1]),
    485                 dtype=dtypes.int32)
    486         ], 0)
    487         # Subtract k * x, see comment above.
    488         cluster_center_updates -= math_ops.cast(
    489             array_ops.reshape(count_updates, broadcast_shape),
    490             inp.dtype) * old_cluster_centers
    491         learning_rate = math_ops.reciprocal(
    492             math_ops.cast(old_counts + count_updates, inp.dtype))
    493         learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
    494         # scale by 1 / (n + k), see comment above.
    495         cluster_center_updates *= learning_rate
    496         # Apply the updates.
    497       update_counts = state_ops.scatter_add(total_counts, unique_ids,
    498                                             count_updates)
    499       update_cluster_centers = state_ops.scatter_add(
    500           cluster_centers, unique_ids, cluster_center_updates)
    501       update_ops.extend([update_counts, update_cluster_centers])
    502     return control_flow_ops.group(*update_ops)
    503 
    504   def _full_batch_training_op(self, inputs, num_clusters, cluster_idx_list,
    505                               cluster_centers):
    506     """Creates an op for training for full batch case.
    507 
    508     Args:
    509       inputs: list of input Tensors.
    510       num_clusters: an integer Tensor providing the number of clusters.
    511       cluster_idx_list: A vector (or list of vectors). Each element in the
    512         vector corresponds to an input row in 'inp' and specifies the cluster id
    513         corresponding to the input.
    514       cluster_centers: Tensor Ref of cluster centers.
    515 
    516     Returns:
    517       An op for doing an update of mini-batch k-means.
    518     """
    519     cluster_sums = []
    520     cluster_counts = []
    521     epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
    522     for inp, cluster_idx in zip(inputs, cluster_idx_list):
    523       with ops.colocate_with(inp, ignore_existing=True):
    524         cluster_sums.append(
    525             math_ops.unsorted_segment_sum(inp, cluster_idx, num_clusters))
    526         cluster_counts.append(
    527             math_ops.unsorted_segment_sum(
    528                 array_ops.reshape(
    529                     array_ops.ones(
    530                         array_ops.reshape(array_ops.shape(inp)[0], [-1])),
    531                     [-1, 1]), cluster_idx, num_clusters))
    532     with ops.colocate_with(cluster_centers, ignore_existing=True):
    533       new_clusters_centers = math_ops.add_n(cluster_sums) / (
    534           math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
    535           epsilon)
    536       if self._clusters_l2_normalized():
    537         new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
    538     return state_ops.assign(cluster_centers, new_clusters_centers)
    539 
    540 
    541 class _InitializeClustersOpFactory(object):
    542   """Internal class to create the op to initialize the clusters.
    543 
    544     The op performs this algorithm (see constructor args):
    545 
    546     num_remaining = num_clusters - length(cluster_centers)
    547     if num_remaining == 0:
    548       assert that cluster_centers_initialized is true
    549     else:
    550       assert that num_remaining > 0
    551       new_centers = choose up to num_remaining initial centers
    552       l2-normalize new_centers if using cosine distance
    553       all_centers = concat(cluster_centers, new_centers)
    554       cluster_centers := all_centers
    555       if there is a cluster_centers_updated variable:
    556         cluster_centers_updated := cluster_centers
    557       num_now_remaining = num_clusters - length(cluster_centers)
    558       if num_now_remaining == 0:
    559         cluster_centers_initialized := true
    560   """
    561 
    562   # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.
    563 
    564   def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
    565                random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
    566                cluster_centers, cluster_centers_updated,
    567                cluster_centers_initialized):
    568     """Creates an op factory.
    569 
    570     Args:
    571       inputs: See KMeans constructor.
    572       num_clusters: An integer Tensor providing the number of clusters.
    573       initial_clusters: See KMeans constructor.
    574       distance_metric: See KMeans constructor.
    575       random_seed: See KMeans constructor.
    576       kmeans_plus_plus_num_retries: See KMeans constructor.
    577       kmc2_chain_length: See KMeans constructor.
    578       cluster_centers: The TF variable holding the initial centers. It may
    579           already contain some centers when the op is executed.
    580       cluster_centers_updated: A second TF variable to hold a copy of the
    581           initial centers, used for full-batch mode. In mini-batch mode,
    582           cluster_centers_updated is the same variable as cluster_centers.
    583       cluster_centers_initialized: A boolean TF variable that will be set
    584           to true when all the initial centers have been chosen.
    585     """
    586     # All of these instance variables are constants.
    587     self._inputs = inputs
    588     self._num_clusters = num_clusters
    589     self._initial_clusters = initial_clusters
    590     self._distance_metric = distance_metric
    591     self._random_seed = random_seed
    592     self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    593     self._kmc2_chain_length = kmc2_chain_length
    594     self._cluster_centers = cluster_centers
    595     self._cluster_centers_updated = cluster_centers_updated
    596     self._cluster_centers_initialized = cluster_centers_initialized
    597 
    598     self._num_selected = array_ops.shape(self._cluster_centers)[0]
    599     self._num_remaining = self._num_clusters - self._num_selected
    600     self._num_data = math_ops.add_n(
    601         [array_ops.shape(i)[0] for i in self._inputs])
    602 
    603   def _random(self):
    604     indices = random_ops.random_uniform(
    605         array_ops.reshape(self._num_remaining, [-1]),
    606         minval=0,
    607         maxval=math_ops.cast(self._num_data, dtypes.int64),
    608         seed=self._random_seed,
    609         dtype=dtypes.int64)
    610     return embedding_lookup(self._inputs, indices, partition_strategy='div')
    611 
    612   def _kmeans_plus_plus(self):
    613     # Points from only the first shard are used for initializing centers.
    614     # TODO(ands): Use all points.
    615     inp = self._inputs[0]
    616     if self._distance_metric == COSINE_DISTANCE:
    617       inp = nn_impl.l2_normalize(inp, dim=1)
    618     return gen_clustering_ops.kmeans_plus_plus_initialization(
    619         inp,
    620         math_ops.to_int64(self._num_remaining), self._random_seed,
    621         self._kmeans_plus_plus_num_retries)
    622 
  def _kmc2_multiple_centers(self):
    """Adds new initial cluster centers using the k-MC2 algorithm.

    Only the first input shard is used. It is split into consecutive subsets
    of exactly `kmc2_chain_length` points. On each subset, a single Markov
    chain of the k-MC2 algorithm adds *one* new cluster center. If the shard
    has fewer than `kmc2_chain_length` points, a single center is added using
    one Markov chain on the full shard. If no centers exist yet, the first
    point of the shard is taken as a center instead of running a chain. It is
    assumed that the provided batch has previously been randomly permuted;
    otherwise, k-MC2 may return suboptimal centers.

    Returns:
      A scalar tensor holding `num_clusters` minus the number of rows of the
      centers assigned in the last loop iteration, i.e. how many centers are
      still missing. Evaluating it runs the while loop that assigns the new
      centers to the variable(s).
    """
    # The op only operates on the first shard of data.
    first_shard = self._inputs[0]
    # Number of points in the input that can be used.
    batch_size = array_ops.shape(first_shard)[0]
    # Maximum number of disjoint subsets of `kmc2_chain_length` points; the
    # true-division result is truncated by the int32 cast. Points after the
    # last full subset are not used.
    max_to_sample = math_ops.cast(
        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
    # We sample at least one new center and at most all remaining centers.
    num_to_sample = math_ops.maximum(
        math_ops.minimum(self._num_remaining, max_to_sample), 1)

    def _cond(i, _):
      """Stopping condition: run until `num_to_sample` iterations are done."""
      return math_ops.less(i, num_to_sample)

    def _body(i, _):
      """Adds a single new center based on the i-th subset of the shard."""

      def _sample_random():
        """Returns the first point of the shard as a [1, d] cluster center."""
        # By assumption the batch is reshuffled and _sample_random is always
        # called for i=0. Hence, we simply return the first point.
        # NOTE(review): see the note on the `cond` below -- when the variable
        # starts empty this branch appears to be taken for every i, not only
        # for i=0.
        new_center = array_ops.reshape(first_shard[0], [1, -1])
        if self._distance_metric == COSINE_DISTANCE:
          new_center = nn_impl.l2_normalize(new_center, dim=1)
        return new_center

      def _sample_kmc2_chain():
        """Returns the previous centers with one k-MC2 sampled center appended.
        """
        # Extract the i-th subset of `kmc2_chain_length` points from the batch.
        start = i * self._kmc2_chain_length
        end = start + self._kmc2_chain_length
        subset = first_shard[start:end]
        # Compute the distances from points in the subset to previous centers.
        _, distances = gen_clustering_ops.nearest_neighbors(
            subset, self._cluster_centers, 1)
        # Sample index of new center using k-MC2 Markov chain.
        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
            array_ops.squeeze(distances), self._random_seed)
        # Extract actual new center as a [1, d] row.
        newly_sampled_center = array_ops.reshape(subset[new_center_index],
                                                 [1, -1])
        # Return concatenation with previously sampled centers.
        if self._distance_metric == COSINE_DISTANCE:
          newly_sampled_center = nn_impl.l2_normalize(
              newly_sampled_center, dim=1)
        return array_ops.concat([self._cluster_centers, newly_sampled_center],
                                0)

      # Obtain a random point if there are no previously sampled centers.
      # Otherwise, construct a k-MC2 Markov chain.
      # NOTE(review): `self._num_selected` is built in __init__, outside this
      # while loop, so this predicate appears to be loop-invariant. If the
      # variable starts empty, every iteration would take `_sample_random`
      # and re-assign the same single point, so one call adds only one
      # center. Confirm whether that is intended.
      new_centers = control_flow_ops.cond(
          math_ops.equal(self._num_selected, 0), _sample_random,
          _sample_kmc2_chain)
      # Assign new cluster centers to underlying variable (shape changes, so
      # shape validation is disabled).
      # NOTE(review): the next iteration's body depends only on `i`, not on
      # this assignment (its second loop argument is ignored), so it may not
      # be guaranteed to read the updated variable when iterations run in
      # parallel -- verify.
      assigned_centers = state_ops.assign(
          self._cluster_centers, new_centers, validate_shape=False)
      # Mirror the centers into the full-batch copy when it is a distinct
      # variable.
      if self._cluster_centers_updated is not self._cluster_centers:
        assigned_centers = state_ops.assign(
            self._cluster_centers_updated,
            assigned_centers,
            validate_shape=False)
      # Second loop value: number of centers still missing after this step.
      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]

    # Run `num_to_sample` iterations; only the final remaining count is kept.
    _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
    return num_remaining
    706 
    707   def _greedy_batch_sampler(self, sampler):
    708     # If the input dataset size is smaller than the number of centers
    709     # remaining, choose the entire input dataset as centers. This can happen
    710     # with mini-batch. Otherwise, sample the batch according to the provided
    711     # sampler.
    712     return control_flow_ops.cond(self._num_data <= self._num_remaining,
    713                                  lambda: array_ops.concat(self._inputs, 0),
    714                                  sampler)
    715 
    716   def _single_batch_sampler(self, sampler):
    717     # Enforce that there are at least as many data points as centers
    718     # remaining. This gives the provided sampler the chance to select all
    719     # remaining centers from a single batch.
    720     with ops.control_dependencies(
    721         [check_ops.assert_greater_equal(self._num_data, self._num_remaining)]):
    722       return sampler()
    723 
    724   def _choose_initial_centers(self):
    725     if isinstance(self._initial_clusters, str):
    726       if self._initial_clusters == RANDOM_INIT:
    727         return self._greedy_batch_sampler(self._random)
    728       else:  # self._initial_clusters == KMEANS_PLUS_PLUS_INIT
    729         return self._single_batch_sampler(self._kmeans_plus_plus)
    730     elif callable(self._initial_clusters):
    731       return self._initial_clusters(self._inputs, self._num_remaining)
    732     else:
    733       with ops.control_dependencies([
    734           check_ops.assert_equal(self._num_remaining,
    735                                  array_ops.shape(self._initial_clusters)[0])
    736       ]):
    737         return self._initial_clusters
    738 
    739   def _add_new_centers(self):
    740     """Adds some centers and returns the number of centers remaining."""
    741     new_centers = self._choose_initial_centers()
    742     if self._distance_metric == COSINE_DISTANCE:
    743       new_centers = nn_impl.l2_normalize(new_centers, dim=1)
    744     # If cluster_centers is empty, it doesn't have the right shape for concat.
    745     all_centers = control_flow_ops.cond(
    746         math_ops.equal(self._num_selected, 0), lambda: new_centers,
    747         lambda: array_ops.concat([self._cluster_centers, new_centers], 0))
    748     # TODO(ccolby): De-dupe all_centers?
    749     a = state_ops.assign(
    750         self._cluster_centers, all_centers, validate_shape=False)
    751     if self._cluster_centers_updated is not self._cluster_centers:
    752       a = state_ops.assign(
    753           self._cluster_centers_updated, a, validate_shape=False)
    754     return self._num_clusters - array_ops.shape(a)[0]
    755 
    756   def _initialize(self):
    757     with ops.control_dependencies([
    758         check_ops.assert_positive(self._num_remaining),
    759     ]):
    760       if self._initial_clusters == KMC2_INIT:
    761         num_now_remaining = self._kmc2_multiple_centers()
    762       else:
    763         num_now_remaining = self._add_new_centers()
    764       return control_flow_ops.cond(
    765           math_ops.equal(num_now_remaining, 0),
    766           lambda: state_ops.assign(self._cluster_centers_initialized, True),
    767           control_flow_ops.no_op)
    768 
    769   def op(self):
    770     """Returns the cluster initializer op."""
    771     return control_flow_ops.cond(
    772         math_ops.equal(self._num_remaining, 0),
    773         lambda: check_ops.assert_equal(self._cluster_centers_initialized, True),
    774         self._initialize)
    775