Home | History | Annotate | Download | only in smartreply
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
     17 #define TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
     18 
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/contrib/lite/model.h"
     23 
     24 namespace tflite {
     25 namespace custom {
     26 namespace smartreply {
     27 
     28 const int kDefaultNumResponse = 10;
     29 const float kDefaultBackoffConfidence = 1e-4;
     30 
     31 class PredictorResponse;  // Defined later in this header.
     32 struct SmartReplyConfig;  // Defined later in this header.
     33 
     34 // Predicts responses for the given input strings with a TFLite model.
     35 // When config.backoff_responses is not empty, predictor_responses will be
     36 // filled with messages from the backoff responses.
     37 void GetSegmentPredictions(const std::vector<std::string>& input,
     38                            const ::tflite::FlatBufferModel& model,
     39                            const SmartReplyConfig& config,
     40                            std::vector<PredictorResponse>* predictor_responses);
     41 
     42 // Data object used to hold a single predictor response.
     43 // It includes messages, and confidence.
     44 class PredictorResponse {
     45  public:
     46   PredictorResponse(const std::string& response_text, float score) {
     47     response_text_ = response_text;
     48     prediction_score_ = score;
     49   }
     50 
     51   // Accessor methods.
     52   const std::string& GetText() const { return response_text_; }
     53   float GetScore() const { return prediction_score_; }
     54 
     55  private:
     56   std::string response_text_ = "";
     57   float prediction_score_ = 0.0;
     58 };
     59 
     60 // Configurations for SmartReply.
     61 struct SmartReplyConfig {
     62   // Maximum responses to return.
     63   int num_response;
     64   // Default confidence for backoff responses.
     65   float backoff_confidence;
     66   // Backoff responses are used when predicted responses cannot fulfill the
     67   // list.
     68   const std::vector<std::string>& backoff_responses;
     69 
     70   SmartReplyConfig(std::vector<std::string> backoff_responses)
     71       : num_response(kDefaultNumResponse),
     72         backoff_confidence(kDefaultBackoffConfidence),
     73         backoff_responses(backoff_responses) {}
     74 };
     75 
     76 }  // namespace smartreply
     77 }  // namespace custom
     78 }  // namespace tflite
     79 
     80 #endif  // TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
     81