1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ 17 #define TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/contrib/lite/model.h" 23 24 namespace tflite { 25 namespace custom { 26 namespace smartreply { 27 28 const int kDefaultNumResponse = 10; 29 const float kDefaultBackoffConfidence = 1e-4; 30 31 class PredictorResponse; 32 struct SmartReplyConfig; 33 34 // With a given string as input, predict the response with a Tflite model. 35 // When config.backoff_response is not empty, predictor_responses will be filled 36 // with messagees from backoff response. 37 void GetSegmentPredictions(const std::vector<std::string>& input, 38 const ::tflite::FlatBufferModel& model, 39 const SmartReplyConfig& config, 40 std::vector<PredictorResponse>* predictor_responses); 41 42 // Data object used to hold a single predictor response. 43 // It includes messages, and confidence. 44 class PredictorResponse { 45 public: 46 PredictorResponse(const std::string& response_text, float score) { 47 response_text_ = response_text; 48 prediction_score_ = score; 49 } 50 51 // Accessor methods. 52 const std::string& GetText() const { return response_text_; } 53 float GetScore() const { return prediction_score_; } 54 55 private: 56 std::string response_text_ = ""; 57 float prediction_score_ = 0.0; 58 }; 59 60 // Configurations for SmartReply. 61 struct SmartReplyConfig { 62 // Maximum responses to return. 63 int num_response; 64 // Default confidence for backoff responses. 65 float backoff_confidence; 66 // Backoff responses are used when predicted responses cannot fulfill the 67 // list. 68 const std::vector<std::string>& backoff_responses; 69 70 SmartReplyConfig(std::vector<std::string> backoff_responses) 71 : num_response(kDefaultNumResponse), 72 backoff_confidence(kDefaultBackoffConfidence), 73 backoff_responses(backoff_responses) {} 74 }; 75 76 } // namespace smartreply 77 } // namespace custom 78 } // namespace tflite 79 80 #endif // TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ 81