1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Clustering Operations.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 from tensorflow.contrib.factorization.python.ops import gen_clustering_ops 22 # go/tf-wildcard-import 23 # pylint: disable=wildcard-import 24 from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import * 25 # pylint: enable=wildcard-import 26 from tensorflow.contrib.util import loader 27 from tensorflow.python.framework import constant_op 28 from tensorflow.python.framework import dtypes 29 from tensorflow.python.framework import ops 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import check_ops 32 from tensorflow.python.ops import control_flow_ops 33 from tensorflow.python.ops import math_ops 34 from tensorflow.python.ops import nn_impl 35 from tensorflow.python.ops import random_ops 36 from tensorflow.python.ops import state_ops 37 from tensorflow.python.ops import variable_scope 38 from tensorflow.python.ops.embedding_ops import embedding_lookup 39 from tensorflow.python.platform import resource_loader 40 41 _clustering_ops = loader.load_op_library( 42 resource_loader.get_path_to_datafile('_clustering_ops.so')) 43 44 # Euclidean distance between vectors U 
# and V is defined as ||U - V||_F, i.e. the square root of the sum of the
# squared element-wise differences. SQUARED_EUCLIDEAN_DISTANCE selects the
# square of that quantity, ||U - V||_F^2.
SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
# Cosine distance between vectors U and V is defined as
# 1 - (U \dot V) / (||U||_F ||V||_F)
COSINE_DISTANCE = 'cosine'

RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'


class KMeans(object):
  """Creates the graph for k-means clustering."""

  def __init__(self,
               inputs,
               num_clusters,
               initial_clusters=RANDOM_INIT,
               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
               use_mini_batch=False,
               mini_batch_steps_per_iteration=1,
               random_seed=0,
               kmeans_plus_plus_num_retries=2,
               kmc2_chain_length=200):
    """Creates an object for generating KMeans clustering graph.

    This class implements the following variants of K-means algorithm:

    If use_mini_batch is False, it runs standard full batch K-means. Each step
    runs a single iteration of K-Means. This step can be run sharded across
    multiple workers by passing a list of sharded inputs to this class. Note
    however that a single step needs to process the full input at once.

    If use_mini_batch is True, it runs a generalization of the mini-batch
    K-means algorithm. It runs multiple iterations, where each iteration is
    composed of mini_batch_steps_per_iteration steps. Two copies of cluster
    centers are maintained: one that is updated at the end of each iteration,
    and one that is updated every step. The first copy is used to compute
    cluster allocations for each step, and for inference, while the second copy
    is the one updated each step using the mini-batch update rule. After each
    iteration is complete, this second copy is copied back to the first copy.

    Note that for use_mini_batch=True, when mini_batch_steps_per_iteration=1,
    the algorithm reduces to the standard mini-batch algorithm. Also by setting
    mini_batch_steps_per_iteration = num_inputs / batch_size, the algorithm
    becomes an asynchronous version of the full-batch algorithm. Note however
    that there is no guarantee by this implementation that each input is seen
    exactly once per iteration. Also, different updates are applied
    asynchronously without locking. So this asynchronous version may not behave
    exactly like a full-batch version.

    Args:
      inputs: An input tensor or list of input tensors. It is assumed that the
        data points have been previously randomly permuted.
      num_clusters: An integer tensor specifying the number of clusters. This
        argument is ignored if initial_clusters is a tensor or numpy array.
      initial_clusters: Specifies the clusters used during initialization. One
        of the following:
        - a tensor or numpy array with the initial cluster centers.
        - a function f(inputs, k) that returns up to k centers from `inputs`.
        - "random": Choose centers randomly from `inputs`.
        - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
        - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
        In the last three cases, one batch of `inputs` may not yield
        `num_clusters` centers, in which case initialization will require
        multiple batches until enough centers are chosen. In the case of
        "random" or "kmeans_plus_plus", if the input size is <= `num_clusters`
        then the entire batch is chosen to be cluster centers.
      distance_metric: Distance metric used for clustering. Supported options:
        "squared_euclidean", "cosine".
      use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
        full batch.
      mini_batch_steps_per_iteration: Number of steps after which the updated
        cluster centers are synced back to a master copy.
      random_seed: Seed for PRNG used to initialize seeds.
      kmeans_plus_plus_num_retries: For each point that is sampled during
        kmeans++ initialization, this parameter specifies the number of
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
      kmc2_chain_length: Determines how many candidate points are used by the
        k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
        contains fewer points, one new cluster center is generated from the
        (mini-)batch.

    Raises:
      ValueError: An invalid argument was passed to initial_clusters or
        distance_metric.
    """
    if isinstance(initial_clusters, str) and initial_clusters not in [
        RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
    ]:
      raise ValueError(
          "Unsupported initialization algorithm '%s'" % initial_clusters)
    if distance_metric not in [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]:
      raise ValueError("Unsupported distance metric '%s'" % distance_metric)
    self._inputs = inputs if isinstance(inputs, list) else [inputs]
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._use_mini_batch = use_mini_batch
    self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
    self._random_seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length

  @classmethod
  def _distance_graph(cls, inputs, clusters, distance_metric):
    """Computes distance between each input and each cluster center.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.
      distance_metric: distance metric used for clustering

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the distance of each row to all the cluster centers.
      Currently only squared Euclidean distance and cosine distance are
      supported.

    Raises:
      ValueError: if distance_metric is not a supported metric.
    """
    assert isinstance(inputs, list)
    if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
      return cls._compute_euclidean_distance(inputs, clusters)
    elif distance_metric == COSINE_DISTANCE:
      return cls._compute_cosine_distance(
          inputs, clusters, inputs_normalized=True)
    else:
      # An explicit exception rather than `assert False`: asserts are stripped
      # under `python -O`, which would make this silently return None.
      raise ValueError("Unsupported distance metric '%s'" % distance_metric)

  @classmethod
  def _compute_euclidean_distance(cls, inputs, clusters):
    """Computes squared Euclidean distance between inputs and cluster centers.

    Uses the expansion ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, so no square
    root is taken.

    Args:
      inputs: list of input Tensors.
      clusters: cluster Tensor.

    Returns:
      list of Tensors, where each element corresponds to each element in inputs.
      The value is the squared Euclidean distance of each row to all the
      cluster centers.
    """
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        # Computes Euclidean distance. Note the first and third terms are
        # broadcast additions.
        squared_distance = (
            math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
            array_ops.transpose(
                math_ops.reduce_sum(
                    math_ops.square(clusters), 1, keepdims=True)))
        output.append(squared_distance)

    return output

  @classmethod
  def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
    """Computes cosine distance between each input and each cluster center.

    Args:
      inputs: list of input Tensor.
      clusters: cluster Tensor
      inputs_normalized: if True, it assumes that the inputs and clusters are
        already L2-normalized, so the cosine distance is computed directly as
        1 - dot product. Else both are L2-normalized first.

    Returns:
      list of Tensors, where each element corresponds to each element in inp.
      The value is the cosine distance (1 - dot product) of each row to all
      the cluster centers.
    """
    output = []
    if not inputs_normalized:
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, dim=1)
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        if not inputs_normalized:
          inp = nn_impl.l2_normalize(inp, dim=1)
        output.append(1 - math_ops.matmul(inp, clusters, transpose_b=True))
    return output

  def _infer_graph(self, inputs, clusters):
    """Maps input to closest cluster and the score.

    Args:
      inputs: list of input Tensors.
      clusters: Tensor of cluster centers.

    Returns:
      A list of three tuples (all_scores, score, cluster_idx). Each tuple has
      one element per entry of `inputs`:
        all_scores: distance of each input to each cluster center.
        score: distance of each input to closest cluster center.
        cluster_idx: index of cluster center closest to the corresponding input.
    """
    assert isinstance(inputs, list)
    # Pairwise distances are used only by transform(). In all other cases, this
    # sub-graph is not evaluated.
    scores = self._distance_graph(inputs, clusters, self._distance_metric)
    output = []
    if (self._distance_metric == COSINE_DISTANCE and
        not self._clusters_l2_normalized()):
      # For L2-normalized vectors x and y, the squared Euclidean distance is
      # 2 * cosine_distance(x, y), i.e. the cosine distance is half the squared
      # Euclidean distance. We use this fact to reuse the nearest_neighbors op
      # and scale its distances by 0.5 below (this assumes nearest_neighbors
      # reports squared Euclidean distances, as that scaling implies).
      # TODO(ands): Support COSINE distance in nearest_neighbors and remove
      # this.
      with ops.colocate_with(clusters, ignore_existing=True):
        clusters = nn_impl.l2_normalize(clusters, dim=1)
    for inp, score in zip(inputs, scores):
      with ops.colocate_with(inp, ignore_existing=True):
        (indices, distances) = gen_clustering_ops.nearest_neighbors(
            inp, clusters, 1)
        if self._distance_metric == COSINE_DISTANCE:
          distances *= 0.5
        output.append((score, array_ops.squeeze(distances, [-1]),
                       array_ops.squeeze(indices, [-1])))
    # Materialize the transpose so the result is a reusable list on both
    # Python 2 and Python 3 (where zip() returns a one-shot iterator).
    return list(zip(*output))

  def _clusters_l2_normalized(self):
    """Returns True if cluster centers are kept L2-normalized.

    This holds for cosine distance, except in mini-batch mode with
    mini_batch_steps_per_iteration == 1, where per-step updates are applied
    directly to the centers without re-normalization.
    """
    return (self._distance_metric == COSINE_DISTANCE and
            (not self._use_mini_batch or
             self._mini_batch_steps_per_iteration > 1))

  def _create_variables(self, num_clusters):
    """Creates variables.

    Args:
      num_clusters: an integer Tensor providing the number of clusters.

    Returns:
      Tuple with following elements:
      - cluster_centers: a Tensor for storing cluster centers
      - cluster_centers_initialized: bool Variable indicating whether clusters
            are initialized.
      - cluster_counts: a Tensor for storing counts of points assigned to this
            cluster. This is used by mini-batch training; it is None in
            full-batch mode.
      - cluster_centers_updated: Tensor representing copy of cluster centers
            that are updated every step. It is the same variable as
            cluster_centers unless mini-batch mode is used with
            mini_batch_steps_per_iteration > 1.
      - update_in_steps: number of steps left before we sync
            cluster_centers_updated back to cluster_centers (None when no
            separate copy exists).
    """
    init_value = array_ops.constant([], dtype=dtypes.float32)
    cluster_centers = variable_scope.variable(
        init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
    cluster_centers_initialized = variable_scope.variable(
        False, dtype=dtypes.bool, name='initialized')

    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      # Copy of cluster centers actively updated each step according to
      # mini-batch update rule.
      cluster_centers_updated = variable_scope.variable(
          init_value, name='clusters_updated', validate_shape=False)
      # How many steps till we copy the updated clusters to cluster_centers.
      update_in_steps = variable_scope.variable(
          self._mini_batch_steps_per_iteration,
          dtype=dtypes.int64,
          name='update_in_steps')
      # Count of points assigned to cluster_centers_updated.
      cluster_counts = variable_scope.variable(
          array_ops.zeros([num_clusters], dtype=dtypes.int64))
    else:
      cluster_centers_updated = cluster_centers
      update_in_steps = None
      cluster_counts = (
          variable_scope.variable(
              array_ops.ones([num_clusters], dtype=dtypes.int64))
          if self._use_mini_batch else None)
    return (cluster_centers, cluster_centers_initialized, cluster_counts,
            cluster_centers_updated, update_in_steps)

  @classmethod
  def _l2_normalize_data(cls, inputs):
    """Returns a list with each input Tensor L2-normalized along dimension 1."""
    output = []
    for inp in inputs:
      with ops.colocate_with(inp, ignore_existing=True):
        output.append(nn_impl.l2_normalize(inp, dim=1))
    return output

  def training_graph(self):
    """Generate a training graph for kmeans algorithm.

    This returns, among other things, an op that chooses initial centers
    (init_op), a boolean variable that is set to True when the initial centers
    are chosen (cluster_centers_initialized), and an op to perform either an
    entire Lloyd iteration or a mini-batch of a Lloyd iteration (training_op).
    The caller should use these components as follows. A single worker should
    execute init_op multiple times until cluster_centers_initialized becomes
    True. Then multiple workers may execute training_op any number of times.

    Returns:
      A tuple consisting of:
      all_scores: A matrix (or list of matrices) of dimensions (num_input,
        num_clusters) where the value is the distance of an input vector and a
        cluster center.
      cluster_idx: A vector (or list of vectors). Each element in the vector
        corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      scores: Similar to cluster_idx but specifies the distance to the
        assigned cluster instead.
      cluster_centers_initialized: scalar indicating whether clusters have been
        initialized.
      init_op: an op to initialize the clusters.
      training_op: an op that runs an iteration of training.
    """
    # Implementation of kmeans.
    if (isinstance(self._initial_clusters, str) or
        callable(self._initial_clusters)):
      initial_clusters = self._initial_clusters
      num_clusters = ops.convert_to_tensor(self._num_clusters)
    else:
      initial_clusters = ops.convert_to_tensor(self._initial_clusters)
      num_clusters = array_ops.shape(initial_clusters)[0]

    inputs = self._inputs
    (cluster_centers_var, cluster_centers_initialized, total_counts,
     cluster_centers_updated,
     update_in_steps) = self._create_variables(num_clusters)
    init_op = _InitializeClustersOpFactory(
        self._inputs, num_clusters, initial_clusters, self._distance_metric,
        self._random_seed, self._kmeans_plus_plus_num_retries,
        self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
        cluster_centers_initialized).op()
    cluster_centers = cluster_centers_var

    if self._distance_metric == COSINE_DISTANCE:
      inputs = self._l2_normalize_data(inputs)
      if not self._clusters_l2_normalized():
        cluster_centers = nn_impl.l2_normalize(cluster_centers, dim=1)

    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
    if self._use_mini_batch:
      sync_updates_op = self._mini_batch_sync_updates_op(
          update_in_steps, cluster_centers_var, cluster_centers_updated,
          total_counts)
      assert sync_updates_op is not None
      with ops.control_dependencies([sync_updates_op]):
        training_op = self._mini_batch_training_op(
            inputs, cluster_idx, cluster_centers_updated, total_counts)
    else:
      # Full-batch mode never replaces cluster_centers with a normalized copy
      # (_clusters_l2_normalized() is True here for cosine distance). Compare
      # by identity: `==` on Variables is not an identity check under newer
      # TensorFlow equality semantics.
      assert cluster_centers is cluster_centers_var
      training_op = self._full_batch_training_op(
          inputs, num_clusters, cluster_idx, cluster_centers_var)

    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
            init_op, training_op)

  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
                                  cluster_centers_updated, total_counts):
    """Creates the op that periodically syncs the mini-batch center copy.

    Only has an effect when use_mini_batch is True and
    mini_batch_steps_per_iteration > 1. In that case each run of the op either
    decrements update_in_steps, or, once update_in_steps is <= 0, resets it to
    mini_batch_steps_per_iteration - 1, copies cluster_centers_updated
    (L2-normalized when using cosine distance) into cluster_centers_var, and
    zeroes total_counts.

    Args:
      update_in_steps: int64 Variable counting steps until the next sync, or
        None when there is no separate copy to sync.
      cluster_centers_var: Variable holding the cluster centers used for
        cluster assignment and inference.
      cluster_centers_updated: Variable holding the copy of the centers that is
        updated every step.
      total_counts: Variable holding per-cluster counts of assigned points.

    Returns:
      The sync op, or a no_op when syncing does not apply.
    """
    if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
      assert update_in_steps is not None
      with ops.colocate_with(update_in_steps, ignore_existing=True):

        def _f():
          # Note that there is a race condition here, so we do best-effort
          # updates here. We reset update_in_steps first so that other workers
          # don't duplicate the updates. Also we update cluster_center_vars
          # before resetting total_counts to avoid large updates to
          # cluster_centers_updated based on partially updated
          # cluster_center_vars.
          with ops.control_dependencies([
              state_ops.assign(update_in_steps,
                               self._mini_batch_steps_per_iteration - 1)
          ]):
            with ops.colocate_with(
                cluster_centers_updated, ignore_existing=True):
              if self._distance_metric == COSINE_DISTANCE:
                cluster_centers = nn_impl.l2_normalize(
                    cluster_centers_updated, dim=1)
              else:
                cluster_centers = cluster_centers_updated
            with ops.colocate_with(cluster_centers_var, ignore_existing=True):
              with ops.control_dependencies(
                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
                with ops.colocate_with(None, ignore_existing=True):
                  with ops.control_dependencies([
                      state_ops.assign(total_counts,
                                       array_ops.zeros_like(total_counts))
                  ]):
                    return array_ops.identity(update_in_steps)

        return control_flow_ops.cond(
            update_in_steps <= 0, _f,
            lambda: state_ops.assign_sub(update_in_steps, 1))
    else:
      return control_flow_ops.no_op()

  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              total_counts):
    """Creates an op for training for mini batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
    update_ops = []
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        assert total_counts is not None
        cluster_idx = array_ops.reshape(cluster_idx, [-1])
        # Dedupe the unique ids of cluster_centers being updated so that updates
        # can be locally aggregated.
        unique_ids, unique_idx = array_ops.unique(cluster_idx)
        num_unique_cluster_idx = array_ops.size(unique_ids)
        # Fetch the old values of counts and cluster_centers.
        with ops.colocate_with(total_counts, ignore_existing=True):
          old_counts = array_ops.gather(total_counts, unique_ids)
        # TODO(agarwal): This colocation seems to run into problems. Fix it.
        with ops.colocate_with(cluster_centers, ignore_existing=True):
          old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
        # Locally aggregate the increment to counts.
        count_updates = math_ops.unsorted_segment_sum(
            array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
            unique_idx, num_unique_cluster_idx)
        # Locally compute the sum of inputs mapped to each id.
        # For a cluster with old cluster value x, old count n, and with data
        # d_1,...d_k newly assigned to it, we recompute the new value as
        # x += (sum_i(d_i) - k * x) / (n + k).
        # Compute sum_i(d_i), see comment above.
        cluster_center_updates = math_ops.unsorted_segment_sum(
            inp, unique_idx, num_unique_cluster_idx)
        # Shape to enable broadcasting count_updates and learning_rate to inp.
        # It extends the shape with 1's to match the rank of inp.
        broadcast_shape = array_ops.concat([
            array_ops.reshape(num_unique_cluster_idx, [1]),
            array_ops.ones(
                array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                dtype=dtypes.int32)
        ], 0)
        # Subtract k * x, see comment above.
        cluster_center_updates -= math_ops.cast(
            array_ops.reshape(count_updates, broadcast_shape),
            inp.dtype) * old_cluster_centers
        learning_rate = math_ops.reciprocal(
            math_ops.cast(old_counts + count_updates, inp.dtype))
        learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
        # scale by 1 / (n + k), see comment above.
        cluster_center_updates *= learning_rate
      # Apply the updates.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(
          cluster_centers, unique_ids, cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
    return control_flow_ops.group(*update_ops)

  def _full_batch_training_op(self, inputs, num_clusters, cluster_idx_list,
                              cluster_centers):
    """Creates an op for training for full batch case.

    Each cluster center is replaced by the mean of the inputs assigned to it
    (a small epsilon in the denominator guards empty clusters), followed by
    L2 normalization when centers are kept normalized.

    Args:
      inputs: list of input Tensors.
      num_clusters: an integer Tensor providing the number of clusters.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      An op for doing an update of full-batch k-means.
    """
    cluster_sums = []
    cluster_counts = []
    epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp, ignore_existing=True):
        cluster_sums.append(
            math_ops.unsorted_segment_sum(inp, cluster_idx, num_clusters))
        cluster_counts.append(
            math_ops.unsorted_segment_sum(
                array_ops.reshape(
                    array_ops.ones(
                        array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                    [-1, 1]), cluster_idx, num_clusters))
    with ops.colocate_with(cluster_centers, ignore_existing=True):
      new_clusters_centers = math_ops.add_n(cluster_sums) / (
          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
          epsilon)
      if self._clusters_l2_normalized():
        new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
    return state_ops.assign(cluster_centers, new_clusters_centers)


class _InitializeClustersOpFactory(object):
  """Internal class to create the op to initialize the clusters.

  The op performs this algorithm (see constructor args):

  num_remaining = num_clusters - length(cluster_centers)
  if num_remaining == 0:
    assert that cluster_centers_initialized is true
  else:
    assert that num_remaining > 0
    new_centers = choose up to num_remaining initial centers
    l2-normalize new_centers if using cosine distance
    all_centers = concat(cluster_centers, new_centers)
    cluster_centers := all_centers
    if there is a cluster_centers_updated variable:
      cluster_centers_updated := cluster_centers
    num_now_remaining = num_clusters - length(cluster_centers)
    if num_now_remaining == 0:
      cluster_centers_initialized := true
  """

  # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.

  def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
               cluster_centers, cluster_centers_updated,
               cluster_centers_initialized):
    """Creates an op factory.

    Args:
      inputs: See KMeans constructor.
      num_clusters: An integer Tensor providing the number of clusters.
      initial_clusters: See KMeans constructor.
      distance_metric: See KMeans constructor.
      random_seed: See KMeans constructor.
      kmeans_plus_plus_num_retries: See KMeans constructor.
      kmc2_chain_length: See KMeans constructor.
      cluster_centers: The TF variable holding the initial centers. It may
          already contain some centers when the op is executed.
      cluster_centers_updated: A second TF variable to hold a copy of the
          initial centers, used in mini-batch mode with
          mini_batch_steps_per_iteration > 1. Otherwise (full-batch mode, or
          mini-batch mode with one step per iteration),
          cluster_centers_updated is the same variable as cluster_centers.
      cluster_centers_initialized: A boolean TF variable that will be set
          to true when all the initial centers have been chosen.
    """
    # All of these instance variables are constants.
    self._inputs = inputs
    self._num_clusters = num_clusters
    self._initial_clusters = initial_clusters
    self._distance_metric = distance_metric
    self._random_seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
    self._kmc2_chain_length = kmc2_chain_length
    self._cluster_centers = cluster_centers
    self._cluster_centers_updated = cluster_centers_updated
    self._cluster_centers_initialized = cluster_centers_initialized

    self._num_selected = array_ops.shape(self._cluster_centers)[0]
    self._num_remaining = self._num_clusters - self._num_selected
    self._num_data = math_ops.add_n(
        [array_ops.shape(i)[0] for i in self._inputs])

  def _random(self):
    """Samples num_remaining input rows uniformly at random as new centers."""
    # Indices are drawn independently, so the same row may be chosen more
    # than once.
    indices = random_ops.random_uniform(
        array_ops.reshape(self._num_remaining, [-1]),
        minval=0,
        maxval=math_ops.cast(self._num_data, dtypes.int64),
        seed=self._random_seed,
        dtype=dtypes.int64)
    # NOTE(review): confirm that the 'div' partition strategy maps these global
    # indices onto the input shards as intended when shard sizes are uneven.
    return embedding_lookup(self._inputs, indices, partition_strategy='div')

  def _kmeans_plus_plus(self):
    """Chooses num_remaining new centers from the first shard via k-means++."""
    # Points from only the first shard are used for initializing centers.
    # TODO(ands): Use all points.
    inp = self._inputs[0]
    if self._distance_metric == COSINE_DISTANCE:
      inp = nn_impl.l2_normalize(inp, dim=1)
    return gen_clustering_ops.kmeans_plus_plus_initialization(
        inp,
        math_ops.to_int64(self._num_remaining), self._random_seed,
        self._kmeans_plus_plus_num_retries)

  def _kmc2_multiple_centers(self):
    """Adds new initial cluster centers using the k-MC2 algorithm.

    In each call to the op, the provided batch is split into subsets based on
    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
    the k-MC2 algorithm is used to add *one* new cluster center. If there
    are less than `kmc2_chain_length` points in the subset, a single center is
    added using one Markov chain on the full input. It is assumed that the
    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
    return suboptimal centers.

    Returns:
      An op that adds new cluster centers.
    """
    # The op only operates on the first shard of data.
    first_shard = self._inputs[0]
    # Number of points in the input that can be used.
    batch_size = array_ops.shape(first_shard)[0]
    # Maximum number of subsets such that the size of each subset is at least
    # `kmc2_chain_length`. Final subsets may be larger.
    max_to_sample = math_ops.cast(
        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
    # We sample at least one new center and at most all remaining centers.
    num_to_sample = math_ops.maximum(
        math_ops.minimum(self._num_remaining, max_to_sample), 1)

    def _cond(i, _):
      """Stopping condition for the while loop."""
      return math_ops.less(i, num_to_sample)

    def _body(i, _):
      """Body that adds a single new center based on a subset."""

      def _sample_random():
        """Returns a random point as a cluster center."""
        # By assumption the batch is reshuffled and _sample_random is always
        # called for i=0. Hence, we simply return the first point.
        new_center = array_ops.reshape(first_shard[0], [1, -1])
        if self._distance_metric == COSINE_DISTANCE:
          new_center = nn_impl.l2_normalize(new_center, dim=1)
        return new_center

      def _sample_kmc2_chain():
        """Returns previous centers as well as a new center sampled using k-MC2.
        """
        # Extract the subset from the underlying batch.
        start = i * self._kmc2_chain_length
        end = start + self._kmc2_chain_length
        subset = first_shard[start:end]
        # Compute the distances from points in the subset to previous centers.
        _, distances = gen_clustering_ops.nearest_neighbors(
            subset, self._cluster_centers, 1)
        # Sample index of new center using k-MC2 Markov chain.
        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
            array_ops.squeeze(distances), self._random_seed)
        # Extract actual new center.
        newly_sampled_center = array_ops.reshape(subset[new_center_index],
                                                 [1, -1])
        # Return concatenation with previously sampled centers.
        if self._distance_metric == COSINE_DISTANCE:
          newly_sampled_center = nn_impl.l2_normalize(
              newly_sampled_center, dim=1)
        return array_ops.concat([self._cluster_centers, newly_sampled_center],
                                0)

      # Obtain a random point if there are no previously sampled centers.
      # Otherwise, construct a k-MC2 Markov chain.
      # NOTE(review): self._num_selected is computed in __init__, outside this
      # while_loop, so it keeps its pre-loop value for every iteration of one
      # run; confirm that later iterations add distinct centers as intended.
      new_centers = control_flow_ops.cond(
          math_ops.equal(self._num_selected, 0), _sample_random,
          _sample_kmc2_chain)
      # Assign new cluster centers to underlying variable.
      assigned_centers = state_ops.assign(
          self._cluster_centers, new_centers, validate_shape=False)
      if self._cluster_centers_updated is not self._cluster_centers:
        assigned_centers = state_ops.assign(
            self._cluster_centers_updated,
            assigned_centers,
            validate_shape=False)
      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]

    # Add num_to_sample new data points.
    _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
    return num_remaining

  def _greedy_batch_sampler(self, sampler):
    """Returns all inputs if they fit in num_remaining, else sampler()."""
    # If the input dataset size is smaller than the number of centers
    # remaining, choose the entire input dataset as centers. This can happen
    # with mini-batch. Otherwise, sample the batch according to the provided
    # sampler.
    return control_flow_ops.cond(self._num_data <= self._num_remaining,
                                 lambda: array_ops.concat(self._inputs, 0),
                                 sampler)

  def _single_batch_sampler(self, sampler):
    """Returns sampler(), guarded by num_data >= num_remaining."""
    # Enforce that there are at least as many data points as centers
    # remaining. This gives the provided sampler the chance to select all
    # remaining centers from a single batch.
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(self._num_data, self._num_remaining)]):
      return sampler()

  def _choose_initial_centers(self):
    """Returns the new centers chosen according to initial_clusters."""
    if isinstance(self._initial_clusters, str):
      if self._initial_clusters == RANDOM_INIT:
        return self._greedy_batch_sampler(self._random)
      else:  # self._initial_clusters == KMEANS_PLUS_PLUS_INIT
        # KMC2_INIT never reaches here; _initialize handles it separately.
        return self._single_batch_sampler(self._kmeans_plus_plus)
    elif callable(self._initial_clusters):
      return self._initial_clusters(self._inputs, self._num_remaining)
    else:
      with ops.control_dependencies([
          check_ops.assert_equal(self._num_remaining,
                                 array_ops.shape(self._initial_clusters)[0])
      ]):
        return self._initial_clusters

  def _add_new_centers(self):
    """Adds some centers and returns the number of centers remaining."""
    new_centers = self._choose_initial_centers()
    if self._distance_metric == COSINE_DISTANCE:
      new_centers = nn_impl.l2_normalize(new_centers, dim=1)
    # If cluster_centers is empty, it doesn't have the right shape for concat.
    all_centers = control_flow_ops.cond(
        math_ops.equal(self._num_selected, 0), lambda: new_centers,
        lambda: array_ops.concat([self._cluster_centers, new_centers], 0))
    # TODO(ccolby): De-dupe all_centers?
    a = state_ops.assign(
        self._cluster_centers, all_centers, validate_shape=False)
    if self._cluster_centers_updated is not self._cluster_centers:
      a = state_ops.assign(
          self._cluster_centers_updated, a, validate_shape=False)
    return self._num_clusters - array_ops.shape(a)[0]

  def _initialize(self):
    """Adds new centers; sets the initialized flag once none remain."""
    with ops.control_dependencies([
        check_ops.assert_positive(self._num_remaining),
    ]):
      # Check the type first: initial_clusters may be a Tensor or a callable,
      # and comparing a Tensor with a string is not a plain boolean test under
      # newer TensorFlow equality semantics.
      if (isinstance(self._initial_clusters, str) and
          self._initial_clusters == KMC2_INIT):
        num_now_remaining = self._kmc2_multiple_centers()
      else:
        num_now_remaining = self._add_new_centers()
      return control_flow_ops.cond(
          math_ops.equal(num_now_remaining, 0),
          lambda: state_ops.assign(self._cluster_centers_initialized, True),
          control_flow_ops.no_op)

  def op(self):
    """Returns the cluster initializer op."""
    return control_flow_ops.cond(
        math_ops.equal(self._num_remaining, 0),
        lambda: check_ops.assert_equal(self._cluster_centers_initialized, True),
        self._initialize)