Home | History | Annotate | Download | only in tests
      1 #include <sstream>
      2 
      3 #include <marisa_alpha.h>
      4 
      5 #include "assert.h"
      6 
      7 namespace {
      8 
      9 class FindCallback {
     10  public:
     11   FindCallback(std::vector<marisa_alpha::UInt32> *key_ids,
     12       std::vector<std::size_t> *key_lengths)
     13       : key_ids_(key_ids), key_lengths_(key_lengths) {}
     14   FindCallback(const FindCallback &callback)
     15       : key_ids_(callback.key_ids_), key_lengths_(callback.key_lengths_) {}
     16 
     17   bool operator()(marisa_alpha::UInt32 key_id, std::size_t key_length) const {
     18     key_ids_->push_back(key_id);
     19     key_lengths_->push_back(key_length);
     20     return true;
     21   }
     22 
     23  private:
     24   std::vector<marisa_alpha::UInt32> *key_ids_;
     25   std::vector<std::size_t> *key_lengths_;
     26 
     27   // Disallows assignment.
     28   FindCallback &operator=(const FindCallback &);
     29 };
     30 
     31 class PredictCallback {
     32  public:
     33   PredictCallback(std::vector<marisa_alpha::UInt32> *key_ids,
     34       std::vector<std::string> *keys)
     35       : key_ids_(key_ids), keys_(keys) {}
     36   PredictCallback(const PredictCallback &callback)
     37       : key_ids_(callback.key_ids_), keys_(callback.keys_) {}
     38 
     39   bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) const {
     40     key_ids_->push_back(key_id);
     41     keys_->push_back(key);
     42     return true;
     43   }
     44 
     45  private:
     46   std::vector<marisa_alpha::UInt32> *key_ids_;
     47   std::vector<std::string> *keys_;
     48 
     49   // Disallows assignment.
     50   PredictCallback &operator=(const PredictCallback &);
     51 };
     52 
     53 void TestTrie() {
     54   TEST_START();
     55 
     56   marisa_alpha::Trie trie;
     57 
     58   ASSERT(trie.num_tries() == 0);
     59   ASSERT(trie.num_keys() == 0);
     60   ASSERT(trie.num_nodes() == 0);
     61   ASSERT(trie.total_size() == (sizeof(marisa_alpha::UInt32) * 23));
     62 
     63   std::vector<std::string> keys;
     64   trie.build(keys);
     65   ASSERT(trie.num_tries() == 1);
     66   ASSERT(trie.num_keys() == 0);
     67   ASSERT(trie.num_nodes() == 1);
     68 
     69   keys.push_back("apple");
     70   keys.push_back("and");
     71   keys.push_back("Bad");
     72   keys.push_back("apple");
     73   keys.push_back("app");
     74 
     75   std::vector<marisa_alpha::UInt32> key_ids;
     76   trie.build(keys, &key_ids,
     77       1 | MARISA_ALPHA_WITHOUT_TAIL | MARISA_ALPHA_LABEL_ORDER);
     78 
     79   ASSERT(trie.num_tries() == 1);
     80   ASSERT(trie.num_keys() == 4);
     81   ASSERT(trie.num_nodes() == 11);
     82 
     83   ASSERT(key_ids.size() == 5);
     84   ASSERT(key_ids[0] == 3);
     85   ASSERT(key_ids[1] == 1);
     86   ASSERT(key_ids[2] == 0);
     87   ASSERT(key_ids[3] == 3);
     88   ASSERT(key_ids[4] == 2);
     89 
     90   char key_buf[256];
     91   std::size_t key_length;
     92   for (std::size_t i = 0; i < keys.size(); ++i) {
     93     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
     94 
     95     ASSERT(trie[keys[i]] == key_ids[i]);
     96     ASSERT(trie[key_ids[i]] == keys[i]);
     97     ASSERT(key_length == keys[i].length());
     98     ASSERT(keys[i] == key_buf);
     99   }
    100 
    101   trie.clear();
    102 
    103   ASSERT(trie.num_tries() == 0);
    104   ASSERT(trie.num_keys() == 0);
    105   ASSERT(trie.num_nodes() == 0);
    106   ASSERT(trie.total_size() == (sizeof(marisa_alpha::UInt32) * 23));
    107 
    108   trie.build(keys, &key_ids,
    109       1 | MARISA_ALPHA_WITHOUT_TAIL | MARISA_ALPHA_WEIGHT_ORDER);
    110 
    111   ASSERT(trie.num_tries() == 1);
    112   ASSERT(trie.num_keys() == 4);
    113   ASSERT(trie.num_nodes() == 11);
    114 
    115   ASSERT(key_ids.size() == 5);
    116   ASSERT(key_ids[0] == 3);
    117   ASSERT(key_ids[1] == 1);
    118   ASSERT(key_ids[2] == 2);
    119   ASSERT(key_ids[3] == 3);
    120   ASSERT(key_ids[4] == 0);
    121 
    122   for (std::size_t i = 0; i < keys.size(); ++i) {
    123     ASSERT(trie[keys[i]] == key_ids[i]);
    124     ASSERT(trie[key_ids[i]] == keys[i]);
    125   }
    126 
    127   ASSERT(trie["appl"] == trie.notfound());
    128   ASSERT(trie["applex"] == trie.notfound());
    129   ASSERT(trie.find_first("ap") == trie.notfound());
    130   ASSERT(trie.find_first("applex") == trie["app"]);
    131   ASSERT(trie.find_last("ap") == trie.notfound());
    132   ASSERT(trie.find_last("applex") == trie["apple"]);
    133 
    134   std::vector<marisa_alpha::UInt32> ids;
    135   ASSERT(trie.find("ap", &ids) == 0);
    136   ASSERT(trie.find("applex", &ids) == 2);
    137   ASSERT(ids.size() == 2);
    138   ASSERT(ids[0] == trie["app"]);
    139   ASSERT(ids[1] == trie["apple"]);
    140 
    141   std::vector<std::size_t> lengths;
    142   ASSERT(trie.find("Baddie", &ids, &lengths) == 1);
    143   ASSERT(ids.size() == 3);
    144   ASSERT(ids[2] == trie["Bad"]);
    145   ASSERT(lengths.size() == 1);
    146   ASSERT(lengths[0] == 3);
    147 
    148   ASSERT(trie.find_callback("anderson", FindCallback(&ids, &lengths)) == 1);
    149   ASSERT(ids.size() == 4);
    150   ASSERT(ids[3] == trie["and"]);
    151   ASSERT(lengths.size() == 2);
    152   ASSERT(lengths[1] == 3);
    153 
    154   ASSERT(trie.predict("") == 4);
    155   ASSERT(trie.predict("a") == 3);
    156   ASSERT(trie.predict("ap") == 2);
    157   ASSERT(trie.predict("app") == 2);
    158   ASSERT(trie.predict("appl") == 1);
    159   ASSERT(trie.predict("apple") == 1);
    160   ASSERT(trie.predict("appleX") == 0);
    161   ASSERT(trie.predict("X") == 0);
    162 
    163   ids.clear();
    164   ASSERT(trie.predict("a", &ids) == 3);
    165   ASSERT(ids.size() == 3);
    166   ASSERT(ids[0] == trie["app"]);
    167   ASSERT(ids[1] == trie["and"]);
    168   ASSERT(ids[2] == trie["apple"]);
    169 
    170   std::vector<std::string> strs;
    171   ASSERT(trie.predict("a", &ids, &strs) == 3);
    172   ASSERT(ids.size() == 6);
    173   ASSERT(ids[3] == trie["app"]);
    174   ASSERT(ids[4] == trie["apple"]);
    175   ASSERT(ids[5] == trie["and"]);
    176   ASSERT(strs[0] == "app");
    177   ASSERT(strs[1] == "apple");
    178   ASSERT(strs[2] == "and");
    179 
    180   TEST_END();
    181 }
    182 
    183 void TestPrefixTrie() {
    184   TEST_START();
    185 
    186   std::vector<std::string> keys;
    187   keys.push_back("after");
    188   keys.push_back("bar");
    189   keys.push_back("car");
    190   keys.push_back("caster");
    191 
    192   marisa_alpha::Trie trie;
    193   std::vector<marisa_alpha::UInt32> key_ids;
    194   trie.build(keys, &key_ids, 1 | MARISA_ALPHA_PREFIX_TRIE
    195       | MARISA_ALPHA_TEXT_TAIL | MARISA_ALPHA_LABEL_ORDER);
    196 
    197   ASSERT(trie.num_tries() == 1);
    198   ASSERT(trie.num_keys() == 4);
    199   ASSERT(trie.num_nodes() == 7);
    200 
    201   char key_buf[256];
    202   std::size_t key_length;
    203   for (std::size_t i = 0; i < keys.size(); ++i) {
    204     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    205 
    206     ASSERT(trie[keys[i]] == key_ids[i]);
    207     ASSERT(trie[key_ids[i]] == keys[i]);
    208     ASSERT(key_length == keys[i].length());
    209     ASSERT(keys[i] == key_buf);
    210   }
    211 
    212   key_length = trie.restore(key_ids[0], NULL, 0);
    213 
    214   ASSERT(key_length == keys[0].length());
    215   EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_ALPHA_PARAM_ERROR);
    216 
    217   key_length = trie.restore(key_ids[0], key_buf, 5);
    218 
    219   ASSERT(key_length == keys[0].length());
    220 
    221   key_length = trie.restore(key_ids[0], key_buf, 6);
    222 
    223   ASSERT(key_length == keys[0].length());
    224 
    225   trie.build(keys, &key_ids, 2 | MARISA_ALPHA_PREFIX_TRIE
    226       | MARISA_ALPHA_WITHOUT_TAIL | MARISA_ALPHA_WEIGHT_ORDER);
    227 
    228   ASSERT(trie.num_tries() == 2);
    229   ASSERT(trie.num_keys() == 4);
    230   ASSERT(trie.num_nodes() == 16);
    231 
    232   for (std::size_t i = 0; i < keys.size(); ++i) {
    233     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    234 
    235     ASSERT(trie[keys[i]] == key_ids[i]);
    236     ASSERT(trie[key_ids[i]] == keys[i]);
    237     ASSERT(key_length == keys[i].length());
    238     ASSERT(keys[i] == key_buf);
    239   }
    240 
    241   key_length = trie.restore(key_ids[0], NULL, 0);
    242 
    243   ASSERT(key_length == keys[0].length());
    244   EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_ALPHA_PARAM_ERROR);
    245 
    246   key_length = trie.restore(key_ids[0], key_buf, 5);
    247 
    248   ASSERT(key_length == keys[0].length());
    249 
    250   key_length = trie.restore(key_ids[0], key_buf, 6);
    251 
    252   ASSERT(key_length == keys[0].length());
    253 
    254   trie.build(keys, &key_ids, 2 | MARISA_ALPHA_PREFIX_TRIE
    255       | MARISA_ALPHA_TEXT_TAIL | MARISA_ALPHA_LABEL_ORDER);
    256 
    257   ASSERT(trie.num_tries() == 2);
    258   ASSERT(trie.num_keys() == 4);
    259   ASSERT(trie.num_nodes() == 14);
    260 
    261   for (std::size_t i = 0; i < keys.size(); ++i) {
    262     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    263 
    264     ASSERT(trie[keys[i]] == key_ids[i]);
    265     ASSERT(trie[key_ids[i]] == keys[i]);
    266     ASSERT(key_length == keys[i].length());
    267     ASSERT(keys[i] == key_buf);
    268   }
    269 
    270   trie.save("trie-test.dat");
    271   trie.clear();
    272   marisa_alpha::Mapper mapper;
    273   trie.mmap(&mapper, "trie-test.dat");
    274 
    275   ASSERT(mapper.is_open());
    276   ASSERT(trie.num_tries() == 2);
    277   ASSERT(trie.num_keys() == 4);
    278   ASSERT(trie.num_nodes() == 14);
    279 
    280   for (std::size_t i = 0; i < keys.size(); ++i) {
    281     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    282 
    283     ASSERT(trie[keys[i]] == key_ids[i]);
    284     ASSERT(trie[key_ids[i]] == keys[i]);
    285     ASSERT(key_length == keys[i].length());
    286     ASSERT(keys[i] == key_buf);
    287   }
    288 
    289   std::stringstream stream;
    290   trie.write(stream);
    291   trie.clear();
    292   trie.read(stream);
    293 
    294   ASSERT(trie.num_tries() == 2);
    295   ASSERT(trie.num_keys() == 4);
    296   ASSERT(trie.num_nodes() == 14);
    297 
    298   for (std::size_t i = 0; i < keys.size(); ++i) {
    299     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    300 
    301     ASSERT(trie[keys[i]] == key_ids[i]);
    302     ASSERT(trie[key_ids[i]] == keys[i]);
    303     ASSERT(key_length == keys[i].length());
    304     ASSERT(keys[i] == key_buf);
    305   }
    306 
    307   trie.build(keys, &key_ids, 3 | MARISA_ALPHA_PREFIX_TRIE
    308       | MARISA_ALPHA_WITHOUT_TAIL | MARISA_ALPHA_WEIGHT_ORDER);
    309 
    310   ASSERT(trie.num_tries() == 3);
    311   ASSERT(trie.num_keys() == 4);
    312   ASSERT(trie.num_nodes() == 19);
    313 
    314   for (std::size_t i = 0; i < keys.size(); ++i) {
    315     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    316 
    317     ASSERT(trie[keys[i]] == key_ids[i]);
    318     ASSERT(trie[key_ids[i]] == keys[i]);
    319     ASSERT(key_length == keys[i].length());
    320     ASSERT(keys[i] == key_buf);
    321   }
    322 
    323   ASSERT(trie["ca"] == trie.notfound());
    324   ASSERT(trie["card"] == trie.notfound());
    325 
    326   std::size_t length = 0;
    327   ASSERT(trie.find_first("ca") == trie.notfound());
    328   ASSERT(trie.find_first("car") == trie["car"]);
    329   ASSERT(trie.find_first("card", &length) == trie["car"]);
    330   ASSERT(length == 3);
    331 
    332   ASSERT(trie.find_last("afte") == trie.notfound());
    333   ASSERT(trie.find_last("after") == trie["after"]);
    334   ASSERT(trie.find_last("afternoon", &length) == trie["after"]);
    335   ASSERT(length == 5);
    336 
    337   {
    338     std::vector<marisa_alpha::UInt32> ids;
    339     std::vector<std::size_t> lengths;
    340     ASSERT(trie.find("card", &ids, &lengths) == 1);
    341     ASSERT(ids.size() == 1);
    342     ASSERT(ids[0] == trie["car"]);
    343     ASSERT(lengths.size() == 1);
    344     ASSERT(lengths[0] == 3);
    345 
    346     ASSERT(trie.predict("ca", &ids) == 2);
    347     ASSERT(ids.size() == 3);
    348     ASSERT(ids[1] == trie["car"]);
    349     ASSERT(ids[2] == trie["caster"]);
    350 
    351     ASSERT(trie.predict("ca", &ids, NULL, 1) == 1);
    352     ASSERT(ids.size() == 4);
    353     ASSERT(ids[3] == trie["car"]);
    354 
    355     std::vector<std::string> strs;
    356     ASSERT(trie.predict("ca", &ids, &strs, 1) == 1);
    357     ASSERT(ids.size() == 5);
    358     ASSERT(ids[4] == trie["car"]);
    359     ASSERT(strs.size() == 1);
    360     ASSERT(strs[0] == "car");
    361 
    362     ASSERT(trie.predict_callback("", PredictCallback(&ids, &strs)) == 4);
    363     ASSERT(ids.size() == 9);
    364     ASSERT(ids[5] == trie["car"]);
    365     ASSERT(ids[6] == trie["caster"]);
    366     ASSERT(ids[7] == trie["after"]);
    367     ASSERT(ids[8] == trie["bar"]);
    368     ASSERT(strs.size() == 5);
    369     ASSERT(strs[1] == "car");
    370     ASSERT(strs[2] == "caster");
    371     ASSERT(strs[3] == "after");
    372     ASSERT(strs[4] == "bar");
    373   }
    374 
    375   {
    376     marisa_alpha::UInt32 ids[10];
    377     std::size_t lengths[10];
    378     ASSERT(trie.find("card", ids, lengths, 10) == 1);
    379     ASSERT(ids[0] == trie["car"]);
    380     ASSERT(lengths[0] == 3);
    381 
    382     ASSERT(trie.predict("ca", ids, NULL, 10) == 2);
    383     ASSERT(ids[0] == trie["car"]);
    384     ASSERT(ids[1] == trie["caster"]);
    385 
    386     ASSERT(trie.predict("ca", ids, NULL, 1) == 1);
    387     ASSERT(ids[0] == trie["car"]);
    388 
    389     std::string strs[10];
    390     ASSERT(trie.predict("ca", ids, strs, 1) == 1);
    391     ASSERT(ids[0] == trie["car"]);
    392     ASSERT(strs[0] == "car");
    393 
    394     ASSERT(trie.predict("", ids, strs, 10) == 4);
    395     ASSERT(ids[0] == trie["car"]);
    396     ASSERT(ids[1] == trie["caster"]);
    397     ASSERT(ids[2] == trie["after"]);
    398     ASSERT(ids[3] == trie["bar"]);
    399     ASSERT(strs[0] == "car");
    400     ASSERT(strs[1] == "caster");
    401     ASSERT(strs[2] == "after");
    402     ASSERT(strs[3] == "bar");
    403   }
    404 
    405   std::string trie_data = stream.str();
    406   trie.map(trie_data.c_str(), trie_data.length());
    407 
    408   ASSERT(trie.num_tries() == 2);
    409   ASSERT(trie.num_keys() == 4);
    410   ASSERT(trie.num_nodes() == 14);
    411 
    412   for (std::size_t i = 0; i < keys.size(); ++i) {
    413     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    414 
    415     ASSERT(trie[keys[i]] == key_ids[i]);
    416     ASSERT(trie[key_ids[i]] == keys[i]);
    417     ASSERT(key_length == keys[i].length());
    418     ASSERT(keys[i] == key_buf);
    419   }
    420 
    421   TEST_END();
    422 }
    423 
    424 void TestPatriciaTrie() {
    425   TEST_START();
    426 
    427   std::vector<std::string> keys;
    428   keys.push_back("bach");
    429   keys.push_back("bet");
    430   keys.push_back("chat");
    431   keys.push_back("check");
    432   keys.push_back("check");
    433 
    434   marisa_alpha::Trie trie;
    435   std::vector<marisa_alpha::UInt32> key_ids;
    436   trie.build(keys, &key_ids, 1);
    437 
    438   ASSERT(trie.num_tries() == 1);
    439   ASSERT(trie.num_keys() == 4);
    440   ASSERT(trie.num_nodes() == 7);
    441 
    442   ASSERT(key_ids.size() == 5);
    443   ASSERT(key_ids[0] == 2);
    444   ASSERT(key_ids[1] == 3);
    445   ASSERT(key_ids[2] == 1);
    446   ASSERT(key_ids[3] == 0);
    447   ASSERT(key_ids[4] == 0);
    448 
    449   char key_buf[256];
    450   std::size_t key_length;
    451   for (std::size_t i = 0; i < keys.size(); ++i) {
    452     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    453 
    454     ASSERT(trie[keys[i]] == key_ids[i]);
    455     ASSERT(trie[key_ids[i]] == keys[i]);
    456     ASSERT(key_length == keys[i].length());
    457     ASSERT(keys[i] == key_buf);
    458   }
    459 
    460   trie.build(keys, &key_ids, 2 | MARISA_ALPHA_WITHOUT_TAIL);
    461 
    462   ASSERT(trie.num_tries() == 2);
    463   ASSERT(trie.num_keys() == 4);
    464   ASSERT(trie.num_nodes() == 17);
    465 
    466   for (std::size_t i = 0; i < keys.size(); ++i) {
    467     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    468 
    469     ASSERT(trie[keys[i]] == key_ids[i]);
    470     ASSERT(trie[key_ids[i]] == keys[i]);
    471     ASSERT(key_length == keys[i].length());
    472     ASSERT(keys[i] == key_buf);
    473   }
    474 
    475   trie.build(keys, &key_ids, 2);
    476 
    477   ASSERT(trie.num_tries() == 2);
    478   ASSERT(trie.num_keys() == 4);
    479   ASSERT(trie.num_nodes() == 14);
    480 
    481   for (std::size_t i = 0; i < keys.size(); ++i) {
    482     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    483 
    484     ASSERT(trie[keys[i]] == key_ids[i]);
    485     ASSERT(trie[key_ids[i]] == keys[i]);
    486     ASSERT(key_length == keys[i].length());
    487     ASSERT(keys[i] == key_buf);
    488   }
    489 
    490   trie.build(keys, &key_ids, 3 | MARISA_ALPHA_WITHOUT_TAIL);
    491 
    492   ASSERT(trie.num_tries() == 3);
    493   ASSERT(trie.num_keys() == 4);
    494   ASSERT(trie.num_nodes() == 20);
    495 
    496   for (std::size_t i = 0; i < keys.size(); ++i) {
    497     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    498 
    499     ASSERT(trie[keys[i]] == key_ids[i]);
    500     ASSERT(trie[key_ids[i]] == keys[i]);
    501     ASSERT(key_length == keys[i].length());
    502     ASSERT(keys[i] == key_buf);
    503   }
    504 
    505   std::stringstream stream;
    506   trie.write(stream);
    507   trie.clear();
    508   trie.read(stream);
    509 
    510   ASSERT(trie.num_tries() == 3);
    511   ASSERT(trie.num_keys() == 4);
    512   ASSERT(trie.num_nodes() == 20);
    513 
    514   for (std::size_t i = 0; i < keys.size(); ++i) {
    515     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
    516 
    517     ASSERT(trie[keys[i]] == key_ids[i]);
    518     ASSERT(trie[key_ids[i]] == keys[i]);
    519     ASSERT(key_length == keys[i].length());
    520     ASSERT(keys[i] == key_buf);
    521   }
    522 
    523   TEST_END();
    524 }
    525 
    526 void TestEmptyString() {
    527   TEST_START();
    528 
    529   std::vector<std::string> keys;
    530   keys.push_back("");
    531 
    532   marisa_alpha::Trie trie;
    533   std::vector<marisa_alpha::UInt32> key_ids;
    534   trie.build(keys, &key_ids);
    535 
    536   ASSERT(trie.num_tries() == 1);
    537   ASSERT(trie.num_keys() == 1);
    538   ASSERT(trie.num_nodes() == 1);
    539 
    540   ASSERT(key_ids.size() == 1);
    541   ASSERT(key_ids[0] == 0);
    542 
    543   ASSERT(trie[""] == 0);
    544   ASSERT(trie[(marisa_alpha::UInt32)0] == "");
    545 
    546   ASSERT(trie["x"] == trie.notfound());
    547   ASSERT(trie.find_first("") == 0);
    548   ASSERT(trie.find_first("x") == 0);
    549   ASSERT(trie.find_last("") == 0);
    550   ASSERT(trie.find_last("x") == 0);
    551 
    552   std::vector<marisa_alpha::UInt32> ids;
    553   ASSERT(trie.find("xyz", &ids) == 1);
    554   ASSERT(ids.size() == 1);
    555   ASSERT(ids[0] == trie[""]);
    556 
    557   std::vector<std::size_t> lengths;
    558   ASSERT(trie.find("xyz", &ids, &lengths) == 1);
    559   ASSERT(ids.size() == 2);
    560   ASSERT(ids[0] == trie[""]);
    561   ASSERT(ids[1] == trie[""]);
    562   ASSERT(lengths.size() == 1);
    563   ASSERT(lengths[0] == 0);
    564 
    565   ASSERT(trie.find_callback("xyz", FindCallback(&ids, &lengths)) == 1);
    566   ASSERT(ids.size() == 3);
    567   ASSERT(ids[2] == trie[""]);
    568   ASSERT(lengths.size() == 2);
    569   ASSERT(lengths[1] == 0);
    570 
    571   ASSERT(trie.predict("xyz", &ids) == 0);
    572 
    573   ASSERT(trie.predict("", &ids) == 1);
    574   ASSERT(ids.size() == 4);
    575   ASSERT(ids[3] == trie[""]);
    576 
    577   std::vector<std::string> strs;
    578   ASSERT(trie.predict("", &ids, &strs) == 1);
    579   ASSERT(ids.size() == 5);
    580   ASSERT(ids[4] == trie[""]);
    581   ASSERT(strs[0] == "");
    582 
    583   TEST_END();
    584 }
    585 
    586 void TestBinaryKey() {
    587   TEST_START();
    588 
    589   std::string binary_key = "NP";
    590   binary_key += '\0';
    591   binary_key += "Trie";
    592 
    593   std::vector<std::string> keys;
    594   keys.push_back(binary_key);
    595 
    596   marisa_alpha::Trie trie;
    597   std::vector<marisa_alpha::UInt32> key_ids;
    598   trie.build(keys, &key_ids, 1 | MARISA_ALPHA_WITHOUT_TAIL);
    599 
    600   ASSERT(trie.num_tries() == 1);
    601   ASSERT(trie.num_keys() == 1);
    602   ASSERT(trie.num_nodes() == 8);
    603   ASSERT(key_ids.size() == 1);
    604 
    605   char key_buf[256];
    606   std::size_t key_length;
    607   key_length = trie.restore(0, key_buf, sizeof(key_buf));
    608 
    609   ASSERT(trie[keys[0]] == key_ids[0]);
    610   ASSERT(trie[key_ids[0]] == keys[0]);
    611   ASSERT(std::string(key_buf, key_length) == keys[0]);
    612 
    613   trie.build(keys, &key_ids,
    614       1 | MARISA_ALPHA_PREFIX_TRIE | MARISA_ALPHA_BINARY_TAIL);
    615 
    616   ASSERT(trie.num_tries() == 1);
    617   ASSERT(trie.num_keys() == 1);
    618   ASSERT(trie.num_nodes() == 2);
    619   ASSERT(key_ids.size() == 1);
    620 
    621   key_length = trie.restore(0, key_buf, sizeof(key_buf));
    622 
    623   ASSERT(trie[keys[0]] == key_ids[0]);
    624   ASSERT(trie[key_ids[0]] == keys[0]);
    625   ASSERT(std::string(key_buf, key_length) == keys[0]);
    626 
    627   trie.build(keys, &key_ids,
    628       1 | MARISA_ALPHA_PREFIX_TRIE | MARISA_ALPHA_TEXT_TAIL);
    629 
    630   ASSERT(trie.num_tries() == 1);
    631   ASSERT(trie.num_keys() == 1);
    632   ASSERT(trie.num_nodes() == 2);
    633   ASSERT(key_ids.size() == 1);
    634 
    635   key_length = trie.restore(0, key_buf, sizeof(key_buf));
    636 
    637   ASSERT(trie[keys[0]] == key_ids[0]);
    638   ASSERT(trie[key_ids[0]] == keys[0]);
    639   ASSERT(std::string(key_buf, key_length) == keys[0]);
    640 
    641   std::vector<marisa_alpha::UInt32> ids;
    642   ASSERT(trie.predict_breadth_first("", &ids) == 1);
    643   ASSERT(ids.size() == 1);
    644   ASSERT(ids[0] == key_ids[0]);
    645 
    646   std::vector<std::string> strs;
    647   ASSERT(trie.predict_depth_first("NP", &ids, &strs) == 1);
    648   ASSERT(ids.size() == 2);
    649   ASSERT(ids[1] == key_ids[0]);
    650   ASSERT(strs[0] == keys[0]);
    651 
    652   TEST_END();
    653 }
    654 
    655 }  // namespace
    656 
    657 int main() {
    658   TestTrie();
    659   TestPrefixTrie();
    660   TestPatriciaTrie();
    661   TestEmptyString();
    662   TestBinaryKey();
    663 
    664   return 0;
    665 }
    666