1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy 5 // of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 // ============================================================================== 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 19 namespace tensorflow { 20 21 REGISTER_OP("KmeansPlusPlusInitialization") 22 .Input("points: float32") 23 .Input("num_to_sample: int64") 24 .Input("seed: int64") 25 .Input("num_retries_per_sample: int64") 26 .Output("samples: float32") 27 .SetShapeFn(shape_inference::UnknownShape) 28 .Doc(R"( 29 Selects num_to_sample rows of input using the KMeans++ criterion. 30 31 Rows of points are assumed to be input points. One row is selected at random. 32 Subsequent rows are sampled with probability proportional to the squared L2 33 distance from the nearest row selected thus far till num_to_sample rows have 34 been sampled. 35 36 points: Matrix of shape (n, d). Rows are assumed to be input points. 37 num_to_sample: Scalar. The number of rows to sample. This value must not be 38 larger than n. 39 seed: Scalar. Seed for initializing the random number generator. 40 num_retries_per_sample: Scalar. For each row that is sampled, this parameter 41 specifies the number of additional points to draw from the current 42 distribution before selecting the best. If a negative value is specified, a 43 heuristic is used to sample O(log(num_to_sample)) additional points. 44 samples: Matrix of shape (num_to_sample, d). The sampled rows. 45 )"); 46 47 REGISTER_OP("KMC2ChainInitialization") 48 .Input("distances: float32") 49 .Input("seed: int64") 50 .Output("index: int64") 51 .SetShapeFn(shape_inference::ScalarShape) 52 .Doc(R"( 53 Returns the index of a data point that should be added to the seed set. 54 55 Entries in distances are assumed to be squared distances of candidate points to 56 the already sampled centers in the seed set. The op constructs one Markov chain 57 of the k-MC^2 algorithm and returns the index of one candidate point to be added 58 as an additional cluster center. 59 60 distances: Vector with squared distances to the closest previously sampled 61 cluster center for each candidate point. 62 seed: Scalar. Seed for initializing the random number generator. 63 index: Scalar with the index of the sampled point. 64 )"); 65 66 REGISTER_OP("NearestNeighbors") 67 .Input("points: float32") 68 .Input("centers: float32") 69 .Input("k: int64") 70 .Output("nearest_center_indices: int64") 71 .Output("nearest_center_distances: float32") 72 .SetShapeFn(shape_inference::UnknownShape) 73 .Doc(R"( 74 Selects the k nearest centers for each point. 75 76 Rows of points are assumed to be input points. Rows of centers are assumed to be 77 the list of candidate centers. For each point, the k centers that have least L2 78 distance to it are computed. 79 80 points: Matrix of shape (n, d). Rows are assumed to be input points. 81 centers: Matrix of shape (m, d). Rows are assumed to be centers. 82 k: Scalar. Number of nearest centers to return for each point. If k is larger 83 than m, then only m centers are returned. 84 nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the 85 indices of the centers closest to the corresponding point, ordered by 86 increasing distance. 87 nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the 88 squared L2 distance to the corresponding center in nearest_center_indices. 89 )"); 90 91 } // namespace tensorflow 92