Home | History | Annotate | Download | only in ops
      1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
      4 // use this file except in compliance with the License.  You may obtain a copy
      5 // of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
     11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
     12 // License for the specific language governing permissions and limitations under
     13 // the License.
     14 // ==============================================================================
     15 
     16 #include "tensorflow/core/framework/common_shape_fns.h"
     17 #include "tensorflow/core/framework/op.h"
     18 
     19 namespace tensorflow {
     20 
     21 REGISTER_OP("KmeansPlusPlusInitialization")
     22     .Input("points: float32")
     23     .Input("num_to_sample: int64")
     24     .Input("seed: int64")
     25     .Input("num_retries_per_sample: int64")
     26     .Output("samples: float32")
     27     .SetShapeFn(shape_inference::UnknownShape)
     28     .Doc(R"(
     29 Selects num_to_sample rows of input using the KMeans++ criterion.
     30 
     31 Rows of points are assumed to be input points. One row is selected at random.
     32 Subsequent rows are sampled with probability proportional to the squared L2
     33 distance from the nearest row selected thus far till num_to_sample rows have
     34 been sampled.
     35 
     36 points: Matrix of shape (n, d). Rows are assumed to be input points.
     37 num_to_sample: Scalar. The number of rows to sample. This value must not be
     38   larger than n.
     39 seed: Scalar. Seed for initializing the random number generator.
     40 num_retries_per_sample: Scalar. For each row that is sampled, this parameter
     41   specifies the number of additional points to draw from the current
     42   distribution before selecting the best. If a negative value is specified, a
     43   heuristic is used to sample O(log(num_to_sample)) additional points.
     44 samples: Matrix of shape (num_to_sample, d). The sampled rows.
     45 )");
     46 
     47 REGISTER_OP("KMC2ChainInitialization")
     48     .Input("distances: float32")
     49     .Input("seed: int64")
     50     .Output("index: int64")
     51     .SetShapeFn(shape_inference::ScalarShape)
     52     .Doc(R"(
     53 Returns the index of a data point that should be added to the seed set.
     54 
     55 Entries in distances are assumed to be squared distances of candidate points to
     56 the already sampled centers in the seed set. The op constructs one Markov chain
     57 of the k-MC^2 algorithm and returns the index of one candidate point to be added
     58 as an additional cluster center.
     59 
     60 distances: Vector with squared distances to the closest previously sampled
     61   cluster center for each candidate point.
     62 seed: Scalar. Seed for initializing the random number generator.
     63 index: Scalar with the index of the sampled point.
     64 )");
     65 
     66 REGISTER_OP("NearestNeighbors")
     67     .Input("points: float32")
     68     .Input("centers: float32")
     69     .Input("k: int64")
     70     .Output("nearest_center_indices: int64")
     71     .Output("nearest_center_distances: float32")
     72     .SetShapeFn(shape_inference::UnknownShape)
     73     .Doc(R"(
     74 Selects the k nearest centers for each point.
     75 
     76 Rows of points are assumed to be input points. Rows of centers are assumed to be
     77 the list of candidate centers. For each point, the k centers that have least L2
     78 distance to it are computed.
     79 
     80 points: Matrix of shape (n, d). Rows are assumed to be input points.
     81 centers: Matrix of shape (m, d). Rows are assumed to be centers.
     82 k: Scalar. Number of nearest centers to return for each point. If k is larger
     83   than m, then only m centers are returned.
     84 nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the
     85   indices of the centers closest to the corresponding point, ordered by
     86   increasing distance.
     87 nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the
     88   squared L2 distance to the corresponding center in nearest_center_indices.
     89 )");
     90 
     91 }  // namespace tensorflow
     92