Home | History | Annotate | Download | only in smartreply
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/lite/models/smartreply/predictor.h"
     17 
     18 #include <fstream>
     19 #include <unordered_set>
     20 
     21 #include <gmock/gmock.h>
     22 #include <gtest/gtest.h>
     23 #include "absl/strings/str_cat.h"
     24 #include "absl/strings/str_split.h"
     25 //#include "tensorflow/contrib/lite/models/test_utils.h"
     26 #include "tensorflow/contrib/lite/string_util.h"
     27 #include "tensorflow/core/platform/test.h"
     28 
     29 namespace tflite {
     30 namespace custom {
     31 namespace smartreply {
     32 namespace {
     33 
     34 const char kModelName[] = "smartreply_ondevice_model.bin";
     35 const char kSamples[] = "smartreply_samples.tsv";
     36 
     37 string TestDataPath() {
     38   return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/",
     39                        "contrib/lite/models/testdata/"));
     40 }
     41 
     42 MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") {
     43   bool has_expected_response = false;
     44   for (const auto &item : *arg) {
     45     const string &response = item.GetText();
     46     if (expected_response.find(response) != expected_response.end()) {
     47       has_expected_response = true;
     48       break;
     49     }
     50   }
     51   return has_expected_response;
     52 }
     53 
     54 class PredictorTest : public ::testing::Test {
     55  protected:
     56   PredictorTest() {
     57     model_ = tflite::FlatBufferModel::BuildFromFile(
     58         StrCat(TestDataPath(), "/", kModelName).c_str());
     59     CHECK(model_);
     60   }
     61   ~PredictorTest() override {}
     62 
     63   std::unique_ptr<::tflite::FlatBufferModel> model_;
     64 };
     65 
     66 TEST_F(PredictorTest, GetSegmentPredictions) {
     67   std::vector<PredictorResponse> predictions;
     68 
     69   GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions);
     70   EXPECT_GT(predictions.size(), 0);
     71 
     72   float max = 0;
     73   for (const auto &item : predictions) {
     74     if (item.GetScore() > max) {
     75       max = item.GetScore();
     76     }
     77   }
     78 
     79   EXPECT_GT(max, 0.3);
     80   EXPECT_THAT(
     81       &predictions,
     82       IncludeAnyResponesIn(std::unordered_set<string>({"Thanks very much"})));
     83 }
     84 
     85 TEST_F(PredictorTest, TestTwoSentences) {
     86   std::vector<PredictorResponse> predictions;
     87 
     88   GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}},
     89                         &predictions);
     90   EXPECT_GT(predictions.size(), 0);
     91 
     92   float max = 0;
     93   for (const auto &item : predictions) {
     94     if (item.GetScore() > max) {
     95       max = item.GetScore();
     96     }
     97   }
     98 
     99   EXPECT_GT(max, 0.3);
    100   EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
    101                                 {"Hi, how are you doing?"})));
    102 }
    103 
    104 TEST_F(PredictorTest, TestBackoff) {
    105   std::vector<PredictorResponse> predictions;
    106 
    107   GetSegmentPredictions({""}, *model_, /*config=*/{{}}, &predictions);
    108   EXPECT_EQ(predictions.size(), 0);
    109 
    110   // Backoff responses are returned in order.
    111   GetSegmentPredictions({""}, *model_, /*config=*/{{"Yes", "Ok"}},
    112                         &predictions);
    113   EXPECT_EQ(predictions.size(), 2);
    114   EXPECT_EQ(predictions[0].GetText(), "Yes");
    115   EXPECT_EQ(predictions[1].GetText(), "Ok");
    116 }
    117 
    118 TEST_F(PredictorTest, BatchTest) {
    119   int total_items = 0;
    120   int total_responses = 0;
    121   int total_triggers = 0;
    122 
    123   string line;
    124   std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
    125   while (std::getline(fin, line)) {
    126     const std::vector<string> fields = absl::StrSplit(line, '\t');
    127     if (fields.empty()) {
    128       continue;
    129     }
    130 
    131     // Parse sample file and predict
    132     const string &msg = fields[0];
    133     std::vector<PredictorResponse> predictions;
    134     GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions);
    135 
    136     // Validate response and generate stats.
    137     total_items++;
    138     total_responses += predictions.size();
    139     if (!predictions.empty()) {
    140       total_triggers++;
    141     }
    142     EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
    143                                   fields.begin() + 1, fields.end())));
    144   }
    145 
    146   EXPECT_EQ(total_triggers, total_items);
    147   EXPECT_GE(total_responses, total_triggers);
    148 }
    149 
    150 }  // namespace
    151 }  // namespace smartreply
    152 }  // namespace custom
    153 }  // namespace tflite
    154