Home | History | Annotate | Download | only in smartreply
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/contrib/lite/models/smartreply/predictor.h"

#include <algorithm>
#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "re2/re2.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
     25 
     26 void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
     27 
     28 namespace tflite {
     29 namespace custom {
     30 namespace smartreply {
     31 
     32 // Split sentence into segments (using punctuation).
     33 std::vector<std::string> SplitSentence(const std::string& input) {
     34   string result(input);
     35 
     36   RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
     37   RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t");
     38   RE2::GlobalReplace(&result, "[ ]+", " ");
     39   RE2::GlobalReplace(&result, "\t+$", "");
     40 
     41   return absl::StrSplit(result, '\t');
     42 }
     43 
     44 // Predict with TfLite model.
     45 void ExecuteTfLite(const std::string& sentence,
     46                    ::tflite::Interpreter* interpreter,
     47                    std::map<std::string, float>* response_map) {
     48   {
     49     TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
     50     tflite::DynamicBuffer buf;
     51     buf.AddString(sentence.data(), sentence.length());
     52     buf.WriteToTensor(input);
     53     interpreter->AllocateTensors();
     54 
     55     interpreter->Invoke();
     56 
     57     TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
     58     TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
     59 
     60     for (int i = 0; i < confidence->dims->data[0]; i++) {
     61       float weight = confidence->data.f[i];
     62       auto response_text = tflite::GetString(messages, i);
     63       if (response_text.len > 0) {
     64         (*response_map)[string(response_text.str, response_text.len)] += weight;
     65       }
     66     }
     67   }
     68 }
     69 
     70 void GetSegmentPredictions(
     71     const std::vector<std::string>& input,
     72     const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config,
     73     std::vector<PredictorResponse>* predictor_responses) {
     74   // Initialize interpreter
     75   std::unique_ptr<::tflite::Interpreter> interpreter;
     76   ::tflite::MutableOpResolver resolver;
     77   RegisterSelectedOps(&resolver);
     78   ::tflite::InterpreterBuilder(model, resolver)(&interpreter);
     79 
     80   if (!model.initialized()) {
     81     fprintf(stderr, "Failed to mmap model \n");
     82     return;
     83   }
     84 
     85   // Execute Tflite Model
     86   std::map<std::string, float> response_map;
     87   std::vector<std::string> sentences;
     88   for (const std::string& str : input) {
     89     std::vector<std::string> splitted_str = SplitSentence(str);
     90     sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
     91   }
     92   for (const auto& sentence : sentences) {
     93     ExecuteTfLite(sentence, interpreter.get(), &response_map);
     94   }
     95 
     96   // Generate the result.
     97   for (const auto& iter : response_map) {
     98     PredictorResponse prediction(iter.first, iter.second);
     99     predictor_responses->emplace_back(prediction);
    100   }
    101   std::sort(predictor_responses->begin(), predictor_responses->end(),
    102             [](const PredictorResponse& a, const PredictorResponse& b) {
    103               return a.GetScore() > b.GetScore();
    104             });
    105 
    106   // Add backoff response.
    107   for (const string& backoff : config.backoff_responses) {
    108     if (predictor_responses->size() >= config.num_response) {
    109       break;
    110     }
    111     predictor_responses->push_back({backoff, config.backoff_confidence});
    112   }
    113 }
    114 
    115 }  // namespace smartreply
    116 }  // namespace custom
    117 }  // namespace tflite
    118