/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/guarded_philox_random.h"

namespace tensorflow {

// Number of examples to precalculate.
const int kPrecalc = 3000;
// Number of words to read into a sentence before processing.
const int kSentenceSize = 1000;

namespace {

bool ScanWord(StringPiece* input, string* word) {
  str_util::RemoveLeadingWhitespace(input);
  StringPiece tmp;
  if (str_util::ConsumeNonWhitespace(input, &tmp)) {
    word->assign(tmp.data(), tmp.size());
    return true;
  } else {
    return false;
  }
}

}  // end namespace

class SkipgramOp : public OpKernel {
 public:
  explicit SkipgramOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), rng_(&philox_) {
    string filename;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
    OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));

    mutex_lock l(mu_);
    example_pos_ = corpus_size_;
    label_pos_ = corpus_size_;
    label_limit_ = corpus_size_;
    sentence_index_ = kSentenceSize;
    for (int i = 0; i < kPrecalc; ++i) {
      NextExample(&precalc_examples_[i].input, &precalc_examples_[i].label);
    }
  }
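
  // Emits one batch per call. Outputs, in order: the vocabulary words and
  // their frequencies, epoch bookkeeping (words per epoch, current epoch,
  // total words processed so far), and the (example, label) center/context
  // word-id pairs for this batch, drawn from the precalculated pool.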
  void Compute(OpKernelContext* ctx) override {
    Tensor words_per_epoch(DT_INT64, TensorShape({}));
    Tensor current_epoch(DT_INT32, TensorShape({}));
    Tensor total_words_processed(DT_INT64, TensorShape({}));
    Tensor examples(DT_INT32, TensorShape({batch_size_}));
    auto Texamples = examples.flat<int32>();
    Tensor labels(DT_INT32, TensorShape({batch_size_}));
    auto Tlabels = labels.flat<int32>();
    {
      mutex_lock l(mu_);
      for (int i = 0; i < batch_size_; ++i) {
        Texamples(i) = precalc_examples_[precalc_index_].input;
        Tlabels(i) = precalc_examples_[precalc_index_].label;
        precalc_index_++;
        if (precalc_index_ >= kPrecalc) {
          precalc_index_ = 0;
          for (int j = 0; j < kPrecalc; ++j) {
            NextExample(&precalc_examples_[j].input,
                        &precalc_examples_[j].label);
          }
        }
      }
      words_per_epoch.scalar<int64>()() = corpus_size_;
      current_epoch.scalar<int32>()() = current_epoch_;
      total_words_processed.scalar<int64>()() = total_words_processed_;
    }
    ctx->set_output(0, word_);
    ctx->set_output(1, freq_);
    ctx->set_output(2, words_per_epoch);
    ctx->set_output(3, current_epoch);
    ctx->set_output(4, total_words_processed);
    ctx->set_output(5, examples);
    ctx->set_output(6, labels);
  }

 private:
  struct Example {
    int32 input;
    int32 label;
  };

  int32 batch_size_ = 0;
  int32 window_size_ = 5;
  float subsample_ = 1e-3;
  int min_count_ = 5;
  int32 vocab_size_ = 0;
  Tensor word_;
  Tensor freq_;
  int64 corpus_size_ = 0;
  std::vector<int32> corpus_;
  std::vector<Example> precalc_examples_;
  int precalc_index_ = 0;
  std::vector<int32> sentence_;
  int sentence_index_ = 0;

  mutex mu_;
  random::PhiloxRandom philox_ GUARDED_BY(mu_);
  random::SimplePhilox rng_ GUARDED_BY(mu_);
  int32 current_epoch_ GUARDED_BY(mu_) = -1;
  int64 total_words_processed_ GUARDED_BY(mu_) = 0;
  int32 example_pos_ GUARDED_BY(mu_);
  int32 label_pos_ GUARDED_BY(mu_);
  int32 label_limit_ GUARDED_BY(mu_);

  // {example_pos_, label_pos_} is the cursor for the next example.
  // example_pos_ wraps around at the end of corpus_. For each
  // example, we randomly generate [label_pos_, label_limit_) for
  // labels.
  void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    while (true) {
      if (label_pos_ >= label_limit_) {
        ++total_words_processed_;
        ++sentence_index_;
        if (sentence_index_ >= kSentenceSize) {
          sentence_index_ = 0;
          for (int i = 0; i < kSentenceSize; ++i, ++example_pos_) {
            if (example_pos_ >= corpus_size_) {
              ++current_epoch_;
              example_pos_ = 0;
            }
            if (subsample_ > 0) {
              int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
              // See Eq. 5 in http://arxiv.org/abs/1310.4546
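              // keep_prob = (sqrt(f / (t * N)) + 1) * (t * N) / f, where
              // f is the word's corpus count, t = subsample_, and
              // N = corpus_size_. Frequent words (f >> t * N) are kept
              // with probability ~sqrt(t * N / f), while rare words
              // (f <= t * N) get keep_prob >= 1 and are always kept.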
              float keep_prob =
                  (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
                  (subsample_ * corpus_size_) / word_freq;
              if (rng_.RandFloat() > keep_prob) {
                i--;
                continue;
              }
            }
            sentence_[i] = corpus_[example_pos_];
          }
        }
        const int32 skip = 1 + rng_.Uniform(window_size_);
        label_pos_ = std::max<int32>(0, sentence_index_ - skip);
        label_limit_ =
            std::min<int32>(kSentenceSize, sentence_index_ + skip + 1);
      }
      if (sentence_index_ != label_pos_) {
        break;
      }
      ++label_pos_;
    }
    *example = sentence_[sentence_index_];
    *label = sentence_[label_pos_++];
  }

  Status Init(Env* env, const string& filename) {
    string data;
    TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
    StringPiece input = data;
    string w;
    corpus_size_ = 0;
    std::unordered_map<string, int32> word_freq;
    while (ScanWord(&input, &w)) {
      ++(word_freq[w]);
      ++corpus_size_;
    }
    if (corpus_size_ < window_size_ * 10) {
      return errors::InvalidArgument(
          "The text file ", filename,
          " contains too little data: ", corpus_size_, " words");
    }
    typedef std::pair<string, int32> WordFreq;
    std::vector<WordFreq> ordered;
    for (const auto& p : word_freq) {
      if (p.second >= min_count_) ordered.push_back(p);
    }
    LOG(INFO) << "Data file: " << filename << " contains " << data.size()
              << " bytes, " << corpus_size_ << " words, " << word_freq.size()
              << " unique words, " << ordered.size()
              << " unique frequent words.";
    word_freq.clear();
    std::sort(ordered.begin(), ordered.end(),
              [](const WordFreq& x, const WordFreq& y) {
                return x.second > y.second;
              });
    vocab_size_ = static_cast<int32>(1 + ordered.size());
    Tensor word(DT_STRING, TensorShape({vocab_size_}));
    Tensor freq(DT_INT32, TensorShape({vocab_size_}));
    word.flat<string>()(0) = "UNK";
    static const int32 kUnkId = 0;
    std::unordered_map<string, int32> word_id;
    int64 total_counted = 0;
    for (std::size_t i = 0; i < ordered.size(); ++i) {
      const auto& w = ordered[i].first;
      auto id = i + 1;
      word.flat<string>()(id) = w;
      auto word_count = ordered[i].second;
      freq.flat<int32>()(id) = word_count;
      total_counted += word_count;
      word_id[w] = id;
    }
    freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
    word_ = word;
    freq_ = freq;
    corpus_.reserve(corpus_size_);
    input = data;
    while (ScanWord(&input, &w)) {
      corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
    }
    precalc_examples_.resize(kPrecalc);
    sentence_.resize(kSentenceSize);
    return Status::OK();
  }
};

REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);
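
// NegTrainOp applies one step of negative-sampling SGD (Eq. 4 in
// http://arxiv.org/abs/1310.4546) to the input and output embedding
// matrices in place. Negative samples are drawn from the unigram
// distribution raised to the 0.75 power, per the word2vec paper;
// e.g. counts {8, 1} yield relative weights {8^0.75, 1} ~= {4.76, 1},
// flattening the sampling distribution relative to raw frequency.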
class NegTrainOp : public OpKernel {
 public:
  explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    base_.Init(0, 0);

    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));

    std::vector<int32> vocab_count;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));

    std::vector<float> vocab_weights;
    vocab_weights.reserve(vocab_count.size());
    for (const auto& f : vocab_count) {
      float r = std::pow(static_cast<float>(f), 0.75f);
      vocab_weights.push_back(r);
    }
    sampler_ = new random::DistributionSampler(vocab_weights);
  }

  ~NegTrainOp() override { delete sampler_; }

  void Compute(OpKernelContext* ctx) override {
    Tensor w_in = ctx->mutable_input(0, false);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
                errors::InvalidArgument("Must be a matrix"));
    Tensor w_out = ctx->mutable_input(1, false);
    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
                errors::InvalidArgument("w_in.shape == w_out.shape"));
    const Tensor& examples = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
                errors::InvalidArgument("Must be a vector"));
    const Tensor& labels = ctx->input(3);
    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
                errors::InvalidArgument("examples.shape == labels.shape"));
    const Tensor& learning_rate = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
                errors::InvalidArgument("Must be a scalar"));

    auto Tw_in = w_in.matrix<float>();
    auto Tw_out = w_out.matrix<float>();
    auto Texamples = examples.flat<int32>();
    auto Tlabels = labels.flat<int32>();
    auto lr = learning_rate.scalar<float>()();
    const int64 vocab_size = w_in.dim_size(0);
    const int64 dims = w_in.dim_size(1);
    const int64 batch_size = examples.dim_size(0);
    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
                                        " vs. ", sampler_->num()));

    // Gradient accumulator for v_in.
    Tensor buf(DT_FLOAT, TensorShape({dims}));
    auto Tbuf = buf.flat<float>();

    // Scalar buffer to hold sigmoid(+/- dot).
    Tensor g_buf(DT_FLOAT, TensorShape({}));
    auto g = g_buf.scalar<float>();

    // The following loop needs 2 random 32-bit values per negative
    // sample.  We reserve 8 values per sample just in case the
    // underlying implementation changes.
    auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
    random::SimplePhilox srnd(&rnd);

    for (int64 i = 0; i < batch_size; ++i) {
      const int32 example = Texamples(i);
      DCHECK(0 <= example && example < vocab_size) << example;
      const int32 label = Tlabels(i);
      DCHECK(0 <= label && label < vocab_size) << label;
      auto v_in = Tw_in.chip<0>(example);

      // Positive: example predicts label.
      //   forward: x = v_in' * v_out
      //            l = log(sigmoid(x))
      //   backward: dl/dx = g = sigmoid(-x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      {
        auto v_out = Tw_out.chip<0>(label);
        auto dot = (v_in * v_out).sum();
        g = (dot.exp() + 1.f).inverse();
        Tbuf = v_out * (g() * lr);
        v_out += v_in * (g() * lr);
      }

      // Negative samples:
      //   forward: x = v_in' * v_sample
      //            l = log(sigmoid(-x))
      //   backward: dl/dx = g = -sigmoid(x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      for (int j = 0; j < num_samples_; ++j) {
        const int sample = sampler_->Sample(&srnd);
        if (sample == label) continue;  // Skip.
        auto v_sample = Tw_out.chip<0>(sample);
        auto dot = (v_in * v_sample).sum();
        g = -((-dot).exp() + 1.f).inverse();
        Tbuf += v_sample * (g() * lr);
        v_sample += v_in * (g() * lr);
      }

      // Applies the gradient on v_in.
      v_in += Tbuf;
    }
  }

 private:
  int32 num_samples_ = 0;
  random::DistributionSampler* sampler_ = nullptr;
  GuardedPhiloxRandom base_;
};

REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);

}  // end namespace tensorflow
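
// Usage note: a training loop built on these kernels would call the
// generated wrappers for the two ops registered above, roughly as in the
// sketch below. The wrapper names follow TensorFlow's snake-case op-name
// convention but are not defined in this file; they are illustrative only.
//
//   (words, counts, words_per_epoch, epoch, total, examples, labels) =
//       skipgram(filename, batch_size, window_size, min_count, subsample)
//   neg_train(w_in, w_out, examples, labels, lr, vocab_count, num_samples)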