Home | History | Annotate | Download | only in src
      1 #include <fstream>
      2 #include <iostream>
      3 #include <sstream>
      4 
      5 #include "gtest/gtest.h"
      6 #include "gflags/gflags.h"
      7 #include "nugget/app/protoapi/control.pb.h"
      8 #include "nugget/app/protoapi/header.pb.h"
      9 #include "nugget/app/protoapi/testing_api.pb.h"
     10 #include "src/macros.h"
     11 #include "src/util.h"
     12 
     13 using nugget::app::protoapi::AesGcmEncryptTest;
     14 using nugget::app::protoapi::AesGcmEncryptTestResult;
     15 using nugget::app::protoapi::APImessageID;
     16 using nugget::app::protoapi::DcryptError;
     17 using nugget::app::protoapi::Notice;
     18 using nugget::app::protoapi::NoticeCode;
     19 using nugget::app::protoapi::OneofTestParametersCase;
     20 using nugget::app::protoapi::OneofTestResultsCase;
     21 using std::cout;
     22 using std::stringstream;
     23 using std::unique_ptr;
     24 
     25 DEFINE_bool(nos_test_dump_protos, false, "Dump binary protobufs to a file.");
     26 DEFINE_int32(test_input_number, -1, "Run a specific test input.");
     27 
     28 #define ASSERT_MSG_TYPE(msg, type_) \
     29 do{if(type_ != APImessageID::NOTICE && msg.type == APImessageID::NOTICE){ \
     30   Notice received; \
     31   received.ParseFromArray(reinterpret_cast<char *>(msg.data), msg.data_len); \
     32   ASSERT_EQ(msg.type, type_) \
     33       << msg.type << " is " << APImessageID_Name((APImessageID) msg.type) \
     34       << "\n" << received.DebugString(); \
     35 }else{ \
     36   ASSERT_EQ(msg.type, type_) \
     37       << msg.type << " is " << APImessageID_Name((APImessageID) msg.type); \
     38 }}while(0)
     39 
     40 #define ASSERT_SUBTYPE(msg, type_) \
     41   EXPECT_GT(msg.data_len, 2); \
     42   uint16_t subtype = (msg.data[0] << 8) | msg.data[1]; \
     43   EXPECT_EQ(subtype, type_)
     44 
     45 namespace {
     46 
     47 using test_harness::BYTE_TIME;
     48 
     49 class NuggetOsTest: public testing::Test {
     50  protected:
     51   static void SetUpTestCase();
     52   static void TearDownTestCase();
     53 
     54  public:
     55   static unique_ptr<test_harness::TestHarness> harness;
     56 };
     57 
     58 unique_ptr<test_harness::TestHarness> NuggetOsTest::harness;
     59 
     60 void NuggetOsTest::SetUpTestCase() {
     61   harness = TestHarness::MakeUnique();
     62 
     63   if (!harness->UsingSpi()) {
     64     EXPECT_TRUE(harness->SwitchFromConsoleToProtoApi());
     65     EXPECT_TRUE(harness->ttyState());
     66   }
     67 }
     68 
     69 void NuggetOsTest::TearDownTestCase() {
     70   harness->ReadUntil(test_harness::BYTE_TIME * 1024);
     71   if (!harness->UsingSpi()) {
     72     EXPECT_TRUE(harness->SwitchFromProtoApiToConsole(NULL));
     73   }
     74   harness = unique_ptr<test_harness::TestHarness>();
     75 }
     76 
     77 #include "src/test-data/NIST-CAVP/aes-gcm-cavp.h"
     78 
     79 TEST_F(NuggetOsTest, AesGcm) {
     80   const int verbosity = harness->getVerbosity();
     81   harness->setVerbosity(verbosity - 1);
     82   harness->ReadUntil(test_harness::BYTE_TIME * 1024);
     83 
     84   size_t i = 0;
     85   size_t test_input_count = ARRAYSIZE(NIST_GCM_DATA);
     86   if (FLAGS_test_input_number != -1) {
     87     i = FLAGS_test_input_number;
     88     test_input_count = FLAGS_test_input_number + 1;
     89   }
     90   for (; i < test_input_count; i++) {
     91     const gcm_data *test_case = &NIST_GCM_DATA[i];
     92 
     93     AesGcmEncryptTest request;
     94     request.set_key(test_case->key, test_case->key_len / 8);
     95     request.set_iv(test_case->IV, test_case->IV_len / 8);
     96     request.set_plain_text(test_case->PT, test_case->PT_len / 8);
     97     request.set_aad(test_case->AAD, test_case->AAD_len / 8);
     98     request.set_tag_len(test_case->tag_len / 8);
     99 
    100     if (FLAGS_nos_test_dump_protos) {
    101       std::ofstream outfile;
    102       outfile.open("AesGcmEncryptTest_" + std::to_string(test_case->key_len) +
    103                    ".proto.bin", std::ios_base::binary);
    104       outfile << request.SerializeAsString();
    105       outfile.close();
    106     }
    107 
    108     ASSERT_NO_ERROR(harness->SendOneofProto(
    109         APImessageID::TESTING_API_CALL,
    110         OneofTestParametersCase::kAesGcmEncryptTest,
    111         request), "");
    112 
    113     test_harness::raw_message msg;
    114     ASSERT_NO_ERROR(harness->GetData(&msg, 4096 * BYTE_TIME), "");
    115     ASSERT_MSG_TYPE(msg, APImessageID::TESTING_API_RESPONSE);
    116     ASSERT_SUBTYPE(msg, OneofTestResultsCase::kAesGcmEncryptTestResult);
    117 
    118     AesGcmEncryptTestResult result;
    119     ASSERT_TRUE(result.ParseFromArray(reinterpret_cast<char *>(msg.data + 2),
    120                                       msg.data_len - 2));
    121     EXPECT_EQ(result.result_code(), DcryptError::DE_NO_ERROR)
    122         << result.result_code() << " is "
    123         << DcryptError_Name(result.result_code());
    124 
    125     ASSERT_EQ(result.cipher_text().size(), test_case->PT_len / 8)
    126             << "\n" << result.DebugString();
    127     const uint8_t *CT = (const uint8_t *)test_case->CT;
    128     stringstream ct_ss;
    129     for (size_t j = 0; j < test_case->PT_len / 8; j++) {
    130       if (CT[j] < 16) {
    131         ct_ss << '0';
    132       }
    133       ct_ss << std::hex << (unsigned int)CT[j];
    134     }
    135     for (size_t j = 0; j < test_case->PT_len / 8; j++) {
    136       ASSERT_EQ(result.cipher_text()[j] & 0x00FF, CT[j] & 0x00FF)
    137               << "\n"
    138               << "test_case: " << i << "\n"
    139               << "result   : " << result.DebugString()
    140               << "CT       : " << ct_ss.str() << "\n"
    141               << "mis-match: " << j;
    142     }
    143 
    144     ASSERT_EQ(result.tag().size(), test_case->tag_len / 8)
    145             << "\n" << result.DebugString();
    146     const uint8_t *tag = (const uint8_t *)test_case->tag;
    147     stringstream tag_ss;
    148     for (size_t j = 0; j < test_case->tag_len / 8; j++) {
    149       if (tag[j] < 16) {
    150         tag_ss << '0';
    151       }
    152       tag_ss << std::hex << (unsigned int)tag[j];
    153     }
    154     for (size_t j = 0; j < test_case->tag_len / 8; j++) {
    155       ASSERT_EQ(result.tag()[j] & 0x00ff, tag[j] & 0x00ff)
    156               << "\n"
    157               << "test_case: " << i << "\n"
    158               << "result   : " << result.DebugString()
    159               << "TAG      : " << tag_ss.str() << "\n"
    160               << "mis-match: " << j;
    161     }
    162   }
    163 
    164   harness->ReadUntil(test_harness::BYTE_TIME * 1024);
    165   harness->setVerbosity(verbosity);
    166 }
    167 
    168 }  // namespace
    169