Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <algorithm>
     17 #include <string>
     18 #include <vector>
     19 
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/tensor.h"
     22 #include "tensorflow/core/framework/types.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/lib/random/philox_random.h"
     25 #include "tensorflow/core/lib/random/simple_philox.h"
     26 #include "tensorflow/core/util/guarded_philox_random.h"
     27 
     28 namespace tensorflow {
     29 
     30 template <typename T>
     31 class SkipGramGenerateCandidatesOp : public OpKernel {
     32  public:
     33   explicit SkipGramGenerateCandidatesOp(OpKernelConstruction* context)
     34       : OpKernel(context) {
     35     OP_REQUIRES_OK(context, generator_.Init(context));
     36   }
     37 
     38   void Compute(OpKernelContext* context) override {
     39     const Tensor* input_tensor;
     40     OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
     41     const auto input = input_tensor->flat<T>();
     42 
     43     const Tensor* min_skips_tensor;
     44     OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
     45     const int min_skips = *(min_skips_tensor->scalar<int>().data());
     46     const Tensor* max_skips_tensor;
     47     OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
     48     const int max_skips = *(max_skips_tensor->scalar<int>().data());
     49 
     50     OP_REQUIRES(
     51         context, min_skips >= 0 && max_skips >= 0,
     52         errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
     53     OP_REQUIRES(context, min_skips <= max_skips,
     54                 errors::InvalidArgument("min_skips must be <= max_skips."));
     55 
     56     const Tensor* start_tensor;
     57     OP_REQUIRES_OK(context, context->input("start", &start_tensor));
     58     const int start = *(start_tensor->scalar<int>().data());
     59     const Tensor* limit_tensor;
     60     OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
     61     const int limit = *(limit_tensor->scalar<int>().data());
     62     const int end =
     63         limit < 0 ? input.size()
     64                   : std::min(start + limit, static_cast<int>(input.size()));
     65 
     66     const Tensor* emit_self_tensor;
     67     OP_REQUIRES_OK(context,
     68                    context->input("emit_self_as_target", &emit_self_tensor));
     69     const bool emit_self_as_target = *(emit_self_tensor->scalar<bool>().data());
     70 
     71     std::vector<T> tokens;
     72     std::vector<T> labels;
     73 
     74     // Reserve the number of random numbers we will use - we use one for each
     75     // token between start and end.
     76     random::PhiloxRandom local_gen =
     77         generator_.ReserveSamples32(end - start + 1);
     78     random::SimplePhilox rng(&local_gen);
     79 
     80     // For each token in the sentence, pick a random skip, then generates
     81     // (token, label) pairs for all labels whose distances from the token are
     82     // within the range [-skip, skip].
     83     for (int i = start; i < end; ++i) {
     84       const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
     85       for (int j = -skips; j <= skips; ++j) {
     86         if ((i + j < start) || (i + j >= end) ||
     87             (j == 0 && !emit_self_as_target)) {
     88           continue;
     89         }
     90         tokens.push_back(input(i));
     91         labels.push_back(input(i + j));
     92       }
     93     }
     94 
     95     Tensor* tokens_output = nullptr;
     96     OP_REQUIRES_OK(context,
     97                    context->allocate_output(
     98                        "tokens", TensorShape({static_cast<int>(tokens.size())}),
     99                        &tokens_output));
    100     Tensor* labels_output = nullptr;
    101     OP_REQUIRES_OK(context,
    102                    context->allocate_output(
    103                        "labels", TensorShape({static_cast<int>(labels.size())}),
    104                        &labels_output));
    105     OP_REQUIRES(
    106         context, tokens_output->IsSameSize(*labels_output),
    107         errors::Internal(strings::StrCat(
    108             "Mismatch between tokens_output shape of ",
    109             tokens_output->shape().DebugString(),
    110             " and labels_output shape of ",
    111             labels_output->shape().DebugString(),
    112             ". This should never happen - contact ami-team@ if it does.")));
    113 
    114     // Copies results to output tensors.
    115     for (int i = 0; i < tokens.size(); ++i) {
    116       tokens_output->vec<T>()(i) = tokens[i];
    117       labels_output->vec<T>()(i) = labels[i];
    118     }
    119   }
    120 
    121  private:
    122   GuardedPhiloxRandom generator_;
    123 };
    124 
// Registers the CPU kernel of SkipGramGenerateCandidatesOp for one element
// type, constraining the op's "T" type attribute to that type.
#define REGISTER_KERNEL(type)                                \
  REGISTER_KERNEL_BUILDER(Name("SkipGramGenerateCandidates") \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<type>("T"),    \
                          SkipGramGenerateCandidatesOp<type>)

// Supported element types: string tokens and the common integer id widths.
REGISTER_KERNEL(string);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(int16);
// TODO(weiho): Add other types if the need arises.

#undef REGISTER_KERNEL
    138 
    139 }  // namespace tensorflow
    140