1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/lite/models/smartreply/predictor.h" 17 18 #include "absl/strings/str_split.h" 19 #include "re2/re2.h" 20 #include "tensorflow/contrib/lite/interpreter.h" 21 #include "tensorflow/contrib/lite/kernels/register.h" 22 #include "tensorflow/contrib/lite/model.h" 23 #include "tensorflow/contrib/lite/string_util.h" 24 #include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" 25 26 void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); 27 28 namespace tflite { 29 namespace custom { 30 namespace smartreply { 31 32 // Split sentence into segments (using punctuation). 33 std::vector<std::string> SplitSentence(const std::string& input) { 34 string result(input); 35 36 RE2::GlobalReplace(&result, "([?.!,])+", " \\1"); 37 RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t"); 38 RE2::GlobalReplace(&result, "[ ]+", " "); 39 RE2::GlobalReplace(&result, "\t+$", ""); 40 41 return absl::StrSplit(result, '\t'); 42 } 43 44 // Predict with TfLite model. 45 void ExecuteTfLite(const std::string& sentence, 46 ::tflite::Interpreter* interpreter, 47 std::map<std::string, float>* response_map) { 48 { 49 TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]); 50 tflite::DynamicBuffer buf; 51 buf.AddString(sentence.data(), sentence.length()); 52 buf.WriteToTensor(input); 53 interpreter->AllocateTensors(); 54 55 interpreter->Invoke(); 56 57 TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]); 58 TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]); 59 60 for (int i = 0; i < confidence->dims->data[0]; i++) { 61 float weight = confidence->data.f[i]; 62 auto response_text = tflite::GetString(messages, i); 63 if (response_text.len > 0) { 64 (*response_map)[string(response_text.str, response_text.len)] += weight; 65 } 66 } 67 } 68 } 69 70 void GetSegmentPredictions( 71 const std::vector<std::string>& input, 72 const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config, 73 std::vector<PredictorResponse>* predictor_responses) { 74 // Initialize interpreter 75 std::unique_ptr<::tflite::Interpreter> interpreter; 76 ::tflite::MutableOpResolver resolver; 77 RegisterSelectedOps(&resolver); 78 ::tflite::InterpreterBuilder(model, resolver)(&interpreter); 79 80 if (!model.initialized()) { 81 fprintf(stderr, "Failed to mmap model \n"); 82 return; 83 } 84 85 // Execute Tflite Model 86 std::map<std::string, float> response_map; 87 std::vector<std::string> sentences; 88 for (const std::string& str : input) { 89 std::vector<std::string> splitted_str = SplitSentence(str); 90 sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end()); 91 } 92 for (const auto& sentence : sentences) { 93 ExecuteTfLite(sentence, interpreter.get(), &response_map); 94 } 95 96 // Generate the result. 97 for (const auto& iter : response_map) { 98 PredictorResponse prediction(iter.first, iter.second); 99 predictor_responses->emplace_back(prediction); 100 } 101 std::sort(predictor_responses->begin(), predictor_responses->end(), 102 [](const PredictorResponse& a, const PredictorResponse& b) { 103 return a.GetScore() > b.GetScore(); 104 }); 105 106 // Add backoff response. 107 for (const string& backoff : config.backoff_responses) { 108 if (predictor_responses->size() >= config.num_response) { 109 break; 110 } 111 predictor_responses->push_back({backoff, config.backoff_confidence}); 112 } 113 } 114 115 } // namespace smartreply 116 } // namespace custom 117 } // namespace tflite 118