1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/libs/net/netlink_client.h" 17 18 #include <linux/rtnetlink.h> 19 20 #include <gmock/gmock.h> 21 #include <gtest/gtest.h> 22 #include <glog/logging.h> 23 24 #include <iostream> 25 #include <memory> 26 27 using ::testing::ElementsAreArray; 28 using ::testing::MatchResultListener; 29 using ::testing::Return; 30 31 namespace cvd { 32 namespace { 33 extern "C" void klog_write(int /* level */, const char* /* format */, ...) {} 34 35 // Dump hex buffer to test log. 36 void Dump(MatchResultListener* result_listener, const char* title, 37 const uint8_t* data, size_t length) { 38 for (size_t item = 0; item < length;) { 39 *result_listener << title; 40 do { 41 result_listener->stream()->width(2); 42 result_listener->stream()->fill('0'); 43 *result_listener << std::hex << +data[item] << " "; 44 ++item; 45 } while (item & 0xf); 46 *result_listener << "\n"; 47 } 48 } 49 50 // Compare two memory areas byte by byte, print information about first 51 // difference. Dumps both bufferst to user log. 52 bool Compare(MatchResultListener* result_listener, 53 const uint8_t* exp, const uint8_t* act, size_t length) { 54 for (size_t index = 0; index < length; ++index) { 55 if (exp[index] != act[index]) { 56 *result_listener << "\nUnexpected data at offset " << index << "\n"; 57 Dump(result_listener, "Data Expected: ", exp, length); 58 Dump(result_listener, " Data Actual: ", act, length); 59 return false; 60 } 61 } 62 63 return true; 64 } 65 66 // Matcher validating Netlink Request data. 67 MATCHER_P2(RequestDataIs, data, length, "Matches expected request data") { 68 size_t offset = sizeof(nlmsghdr); 69 if (offset + length != arg.RequestLength()) { 70 *result_listener << "Unexpected request length: " 71 << arg.RequestLength() - offset << " vs " << length; 72 return false; 73 } 74 75 // Note: Request begins with header (nlmsghdr). Header is not covered by this 76 // call. 77 const uint8_t* exp_data = static_cast<const uint8_t*>( 78 static_cast<const void*>(data)); 79 const uint8_t* act_data = static_cast<const uint8_t*>(arg.RequestData()); 80 return Compare( 81 result_listener, exp_data, &act_data[offset], length); 82 } 83 84 MATCHER_P4(RequestHeaderIs, length, type, flags, seq, 85 "Matches request header") { 86 nlmsghdr* header = static_cast<nlmsghdr*>(arg.RequestData()); 87 if (arg.RequestLength() < sizeof(header)) { 88 *result_listener << "Malformed header: too short."; 89 return false; 90 } 91 92 if (header->nlmsg_len != length) { 93 *result_listener << "Invalid message length: " 94 << header->nlmsg_len << " vs " << length; 95 return false; 96 } 97 98 if (header->nlmsg_type != type) { 99 *result_listener << "Invalid header type: " 100 << header->nlmsg_type << " vs " << type; 101 return false; 102 } 103 104 if (header->nlmsg_flags != flags) { 105 *result_listener << "Invalid header flags: " 106 << header->nlmsg_flags << " vs " << flags; 107 return false; 108 } 109 110 if (header->nlmsg_seq != seq) { 111 *result_listener << "Invalid header sequence number: " 112 << header->nlmsg_seq << " vs " << seq; 113 return false; 114 } 115 116 return true; 117 } 118 } // namespace 119 120 class NetlinkClientTest : public ::testing::Test { 121 void SetUp() { 122 google::InstallFailureSignalHandler(); 123 } 124 protected: 125 std::unique_ptr<NetlinkClient> nl_client_; 126 }; 127 128 TEST_F(NetlinkClientTest, BasicStringNode) { 129 constexpr uint16_t kDummyTag = 0xfce2; 130 constexpr char kLongString[] = "long string"; 131 132 struct { 133 // 11 bytes of text + padding 0 + 4 bytes of header. 134 const uint16_t attr_length = 0x10; 135 const uint16_t attr_type = kDummyTag; 136 char text[sizeof(kLongString)]; // sizeof includes padding 0. 137 } expected; 138 139 memcpy(&expected.text, kLongString, sizeof(kLongString)); 140 141 cvd::NetlinkRequest request(RTM_SETLINK, 0); 142 request.AddString(kDummyTag, kLongString); 143 EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected))); 144 } 145 146 TEST_F(NetlinkClientTest, BasicIntNode) { 147 // Basic { Dummy: Value } test. 148 constexpr uint16_t kDummyTag = 0xfce2; 149 constexpr int32_t kValue = 0x1badd00d; 150 151 struct { 152 const uint16_t attr_length = 0x8; // 4 bytes of value + 4 bytes of header. 153 const uint16_t attr_type = kDummyTag; 154 const uint32_t attr_value = kValue; 155 } expected; 156 157 cvd::NetlinkRequest request(RTM_SETLINK, 0); 158 request.AddInt32(kDummyTag, kValue); 159 EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected))); 160 } 161 162 TEST_F(NetlinkClientTest, SingleList) { 163 // List: { Dummy: Value} 164 constexpr uint16_t kDummyTag = 0xfce2; 165 constexpr uint16_t kListTag = 0xcafe; 166 constexpr int32_t kValue = 0x1badd00d; 167 168 struct { 169 const uint16_t list_length = 0xc; 170 const uint16_t list_type = kListTag; 171 const uint16_t attr_length = 0x8; // 4 bytes of value + 4 bytes of header. 172 const uint16_t attr_type = kDummyTag; 173 const uint32_t attr_value = kValue; 174 } expected; 175 176 cvd::NetlinkRequest request(RTM_SETLINK, 0); 177 request.PushList(kListTag); 178 request.AddInt32(kDummyTag, kValue); 179 request.PopList(); 180 181 EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected))); 182 } 183 184 TEST_F(NetlinkClientTest, NestedList) { 185 // List1: { List2: { Dummy: Value}} 186 constexpr uint16_t kDummyTag = 0xfce2; 187 constexpr uint16_t kList1Tag = 0xcafe; 188 constexpr uint16_t kList2Tag = 0xfeed; 189 constexpr int32_t kValue = 0x1badd00d; 190 191 struct { 192 const uint16_t list1_length = 0x10; 193 const uint16_t list1_type = kList1Tag; 194 const uint16_t list2_length = 0xc; 195 const uint16_t list2_type = kList2Tag; 196 const uint16_t attr_length = 0x8; 197 const uint16_t attr_type = kDummyTag; 198 const uint32_t attr_value = kValue; 199 } expected; 200 201 cvd::NetlinkRequest request(RTM_SETLINK, 0); 202 request.PushList(kList1Tag); 203 request.PushList(kList2Tag); 204 request.AddInt32(kDummyTag, kValue); 205 request.PopList(); 206 request.PopList(); 207 208 EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected))); 209 } 210 211 TEST_F(NetlinkClientTest, ListSequence) { 212 // List1: { Dummy1: Value1}, List2: { Dummy2: Value2 } 213 constexpr uint16_t kDummy1Tag = 0xfce2; 214 constexpr uint16_t kDummy2Tag = 0xfd38; 215 constexpr uint16_t kList1Tag = 0xcafe; 216 constexpr uint16_t kList2Tag = 0xfeed; 217 constexpr int32_t kValue1 = 0x1badd00d; 218 constexpr int32_t kValue2 = 0xfee1; 219 220 struct { 221 const uint16_t list1_length = 0xc; 222 const uint16_t list1_type = kList1Tag; 223 const uint16_t attr1_length = 0x8; 224 const uint16_t attr1_type = kDummy1Tag; 225 const uint32_t attr1_value = kValue1; 226 const uint16_t list2_length = 0xc; 227 const uint16_t list2_type = kList2Tag; 228 const uint16_t attr2_length = 0x8; 229 const uint16_t attr2_type = kDummy2Tag; 230 const uint32_t attr2_value = kValue2; 231 } expected; 232 233 cvd::NetlinkRequest request(RTM_SETLINK, 0); 234 request.PushList(kList1Tag); 235 request.AddInt32(kDummy1Tag, kValue1); 236 request.PopList(); 237 request.PushList(kList2Tag); 238 request.AddInt32(kDummy2Tag, kValue2); 239 request.PopList(); 240 241 EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected))); 242 } 243 244 TEST_F(NetlinkClientTest, ComplexList) { 245 // List1: { List2: { Dummy1: Value1 }, Dummy2: Value2 } 246 constexpr uint16_t kDummy1Tag = 0xfce2; 247 constexpr uint16_t kDummy2Tag = 0xfd38; 248 constexpr uint16_t kList1Tag = 0xcafe; 249 constexpr uint16_t kList2Tag = 0xfeed; 250 constexpr int32_t kValue1 = 0x1badd00d; 251 constexpr int32_t kValue2 = 0xfee1; 252 253 struct { 254 const uint16_t list1_length = 0x18; 255 const uint16_t list1_type = kList1Tag; 256 const uint16_t list2_length = 0xc; // Note, this only covers until kValue1. 257 const uint16_t list2_type = kList2Tag; 258 const uint16_t attr1_length = 0x8; 259 const uint16_t attr1_type = kDummy1Tag; 260 const uint32_t attr1_value = kValue1; 261 const uint16_t attr2_length = 0x8; 262 const uint16_t attr2_type = kDummy2Tag; 263 const uint32_t attr2_value = kValue2; 264 } expected; 265 266 cvd::NetlinkRequest request(RTM_SETLINK, 0); 267 request.PushList(kList1Tag); 268 request.PushList(kList2Tag); 269 request.AddInt32(kDummy1Tag, kValue1); 270 request.PopList(); 271 request.AddInt32(kDummy2Tag, kValue2); 272 request.PopList(); 273 274 EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected))); 275 } 276 277 TEST_F(NetlinkClientTest, SimpleNetlinkCreateHeader) { 278 cvd::NetlinkRequest request(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL); 279 constexpr char kValue[] = "random string"; 280 request.AddString(0, kValue); // Have something to work with. 281 282 constexpr size_t kMsgLength = 283 sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue)); 284 int base_seq = request.SeqNo(); 285 286 EXPECT_THAT(request, RequestHeaderIs( 287 kMsgLength, 288 RTM_NEWLINK, 289 NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL | NLM_F_REQUEST, 290 base_seq)); 291 292 cvd::NetlinkRequest request2(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL); 293 request2.AddString(0, kValue); // Have something to work with. 294 EXPECT_THAT(request2, RequestHeaderIs( 295 kMsgLength, 296 RTM_NEWLINK, 297 NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL | NLM_F_REQUEST, 298 base_seq + 1)); 299 } 300 301 TEST_F(NetlinkClientTest, SimpleNetlinkUpdateHeader) { 302 cvd::NetlinkRequest request(RTM_SETLINK, 0); 303 constexpr char kValue[] = "random string"; 304 request.AddString(0, kValue); // Have something to work with. 305 306 constexpr size_t kMsgLength = 307 sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue)); 308 int base_seq = request.SeqNo(); 309 310 EXPECT_THAT(request, RequestHeaderIs( 311 kMsgLength, RTM_SETLINK, NLM_F_REQUEST | NLM_F_ACK, base_seq)); 312 313 cvd::NetlinkRequest request2(RTM_SETLINK, 0); 314 request2.AddString(0, kValue); // Have something to work with. 315 EXPECT_THAT(request2, RequestHeaderIs( 316 kMsgLength, RTM_SETLINK, NLM_F_REQUEST | NLM_F_ACK, base_seq + 1)); 317 } 318 319 } // namespace cvd 320