1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "extensions/browser/api/declarative/declarative_rule.h" 6 7 #include "base/bind.h" 8 #include "base/message_loop/message_loop.h" 9 #include "base/test/values_test_util.h" 10 #include "base/values.h" 11 #include "components/url_matcher/url_matcher_constants.h" 12 #include "extensions/common/extension_builder.h" 13 #include "testing/gmock/include/gmock/gmock.h" 14 #include "testing/gtest/include/gtest/gtest.h" 15 16 using base::test::ParseJson; 17 using url_matcher::URLMatcher; 18 using url_matcher::URLMatcherConditionFactory; 19 using url_matcher::URLMatcherConditionSet; 20 21 namespace extensions { 22 23 namespace { 24 25 template<typename T> 26 linked_ptr<T> ScopedToLinkedPtr(scoped_ptr<T> ptr) { 27 return linked_ptr<T>(ptr.release()); 28 } 29 30 scoped_ptr<base::DictionaryValue> SimpleManifest() { 31 return DictionaryBuilder() 32 .Set("name", "extension") 33 .Set("manifest_version", 2) 34 .Set("version", "1.0") 35 .Build(); 36 } 37 38 } // namespace 39 40 struct RecordingCondition { 41 typedef int MatchData; 42 43 URLMatcherConditionFactory* factory; 44 scoped_ptr<base::Value> value; 45 46 void GetURLMatcherConditionSets( 47 URLMatcherConditionSet::Vector* condition_sets) const { 48 // No condition sets. 49 } 50 51 static scoped_ptr<RecordingCondition> Create( 52 const Extension* extension, 53 URLMatcherConditionFactory* url_matcher_condition_factory, 54 const base::Value& condition, 55 std::string* error) { 56 const base::DictionaryValue* dict = NULL; 57 if (condition.GetAsDictionary(&dict) && dict->HasKey("bad_key")) { 58 *error = "Found error key"; 59 return scoped_ptr<RecordingCondition>(); 60 } 61 62 scoped_ptr<RecordingCondition> result(new RecordingCondition()); 63 result->factory = url_matcher_condition_factory; 64 result->value.reset(condition.DeepCopy()); 65 return result.Pass(); 66 } 67 }; 68 typedef DeclarativeConditionSet<RecordingCondition> RecordingConditionSet; 69 70 TEST(DeclarativeConditionTest, ErrorConditionSet) { 71 URLMatcher matcher; 72 RecordingConditionSet::AnyVector conditions; 73 conditions.push_back(ScopedToLinkedPtr(ParseJson("{\"key\": 1}"))); 74 conditions.push_back(ScopedToLinkedPtr(ParseJson("{\"bad_key\": 2}"))); 75 76 std::string error; 77 scoped_ptr<RecordingConditionSet> result = RecordingConditionSet::Create( 78 NULL, matcher.condition_factory(), conditions, &error); 79 EXPECT_EQ("Found error key", error); 80 ASSERT_FALSE(result); 81 } 82 83 TEST(DeclarativeConditionTest, CreateConditionSet) { 84 URLMatcher matcher; 85 RecordingConditionSet::AnyVector conditions; 86 conditions.push_back(ScopedToLinkedPtr(ParseJson("{\"key\": 1}"))); 87 conditions.push_back(ScopedToLinkedPtr(ParseJson("[\"val1\", 2]"))); 88 89 // Test insertion 90 std::string error; 91 scoped_ptr<RecordingConditionSet> result = RecordingConditionSet::Create( 92 NULL, matcher.condition_factory(), conditions, &error); 93 EXPECT_EQ("", error); 94 ASSERT_TRUE(result); 95 EXPECT_EQ(2u, result->conditions().size()); 96 97 EXPECT_EQ(matcher.condition_factory(), result->conditions()[0]->factory); 98 EXPECT_TRUE(ParseJson("{\"key\": 1}")->Equals( 99 result->conditions()[0]->value.get())); 100 } 101 102 struct FulfillableCondition { 103 struct MatchData { 104 int value; 105 const std::set<URLMatcherConditionSet::ID>& url_matches; 106 }; 107 108 scoped_refptr<URLMatcherConditionSet> condition_set; 109 int condition_set_id; 110 int max_value; 111 112 URLMatcherConditionSet::ID url_matcher_condition_set_id() const { 113 return condition_set_id; 114 } 115 116 scoped_refptr<URLMatcherConditionSet> url_matcher_condition_set() const { 117 return condition_set; 118 } 119 120 void GetURLMatcherConditionSets( 121 URLMatcherConditionSet::Vector* condition_sets) const { 122 if (condition_set.get()) 123 condition_sets->push_back(condition_set); 124 } 125 126 bool IsFulfilled(const MatchData& match_data) const { 127 if (condition_set_id != -1 && 128 !ContainsKey(match_data.url_matches, condition_set_id)) 129 return false; 130 return match_data.value <= max_value; 131 } 132 133 static scoped_ptr<FulfillableCondition> Create( 134 const Extension* extension, 135 URLMatcherConditionFactory* url_matcher_condition_factory, 136 const base::Value& condition, 137 std::string* error) { 138 scoped_ptr<FulfillableCondition> result(new FulfillableCondition()); 139 const base::DictionaryValue* dict; 140 if (!condition.GetAsDictionary(&dict)) { 141 *error = "Expected dict"; 142 return result.Pass(); 143 } 144 if (!dict->GetInteger("url_id", &result->condition_set_id)) 145 result->condition_set_id = -1; 146 if (!dict->GetInteger("max", &result->max_value)) 147 *error = "Expected integer at ['max']"; 148 if (result->condition_set_id != -1) { 149 result->condition_set = new URLMatcherConditionSet( 150 result->condition_set_id, 151 URLMatcherConditionSet::Conditions()); 152 } 153 return result.Pass(); 154 } 155 }; 156 157 TEST(DeclarativeConditionTest, FulfillConditionSet) { 158 typedef DeclarativeConditionSet<FulfillableCondition> FulfillableConditionSet; 159 FulfillableConditionSet::AnyVector conditions; 160 conditions.push_back(ScopedToLinkedPtr(ParseJson( 161 "{\"url_id\": 1, \"max\": 3}"))); 162 conditions.push_back(ScopedToLinkedPtr(ParseJson( 163 "{\"url_id\": 2, \"max\": 5}"))); 164 conditions.push_back(ScopedToLinkedPtr(ParseJson( 165 "{\"url_id\": 3, \"max\": 1}"))); 166 conditions.push_back(ScopedToLinkedPtr(ParseJson( 167 "{\"max\": -5}"))); // No url. 168 169 // Test insertion 170 std::string error; 171 scoped_ptr<FulfillableConditionSet> result = 172 FulfillableConditionSet::Create(NULL, NULL, conditions, &error); 173 ASSERT_EQ("", error); 174 ASSERT_TRUE(result); 175 EXPECT_EQ(4u, result->conditions().size()); 176 177 std::set<URLMatcherConditionSet::ID> url_matches; 178 FulfillableCondition::MatchData match_data = { 0, url_matches }; 179 EXPECT_FALSE(result->IsFulfilled(1, match_data)) 180 << "Testing an ID that's not in url_matches forwards to the Condition, " 181 << "which doesn't match."; 182 EXPECT_FALSE(result->IsFulfilled(-1, match_data)) 183 << "Testing the 'no ID' value tries to match the 4th condition, but " 184 << "its max is too low."; 185 match_data.value = -5; 186 EXPECT_TRUE(result->IsFulfilled(-1, match_data)) 187 << "Testing the 'no ID' value tries to match the 4th condition, and " 188 << "this value is low enough."; 189 190 url_matches.insert(1); 191 match_data.value = 3; 192 EXPECT_TRUE(result->IsFulfilled(1, match_data)) 193 << "Tests a condition with a url matcher, for a matching value."; 194 match_data.value = 4; 195 EXPECT_FALSE(result->IsFulfilled(1, match_data)) 196 << "Tests a condition with a url matcher, for a non-matching value " 197 << "that would match a different condition."; 198 url_matches.insert(2); 199 EXPECT_TRUE(result->IsFulfilled(2, match_data)) 200 << "Tests with 2 elements in the match set."; 201 202 // Check the condition sets: 203 URLMatcherConditionSet::Vector condition_sets; 204 result->GetURLMatcherConditionSets(&condition_sets); 205 ASSERT_EQ(3U, condition_sets.size()); 206 EXPECT_EQ(1, condition_sets[0]->id()); 207 EXPECT_EQ(2, condition_sets[1]->id()); 208 EXPECT_EQ(3, condition_sets[2]->id()); 209 } 210 211 // DeclarativeAction 212 213 class SummingAction : public base::RefCounted<SummingAction> { 214 public: 215 typedef int ApplyInfo; 216 217 SummingAction(int increment, int min_priority) 218 : increment_(increment), min_priority_(min_priority) {} 219 220 static scoped_refptr<const SummingAction> Create( 221 content::BrowserContext* browser_context, 222 const Extension* extension, 223 const base::Value& action, 224 std::string* error, 225 bool* bad_message) { 226 int increment = 0; 227 int min_priority = 0; 228 const base::DictionaryValue* dict = NULL; 229 EXPECT_TRUE(action.GetAsDictionary(&dict)); 230 if (dict->HasKey("error")) { 231 EXPECT_TRUE(dict->GetString("error", error)); 232 return scoped_refptr<const SummingAction>(NULL); 233 } 234 if (dict->HasKey("bad")) { 235 *bad_message = true; 236 return scoped_refptr<const SummingAction>(NULL); 237 } 238 239 EXPECT_TRUE(dict->GetInteger("value", &increment)); 240 dict->GetInteger("priority", &min_priority); 241 return scoped_refptr<const SummingAction>( 242 new SummingAction(increment, min_priority)); 243 } 244 245 void Apply(const std::string& extension_id, 246 const base::Time& install_time, 247 int* sum) const { 248 *sum += increment_; 249 } 250 251 int increment() const { return increment_; } 252 int minimum_priority() const { 253 return min_priority_; 254 } 255 256 private: 257 friend class base::RefCounted<SummingAction>; 258 virtual ~SummingAction() {} 259 260 int increment_; 261 int min_priority_; 262 }; 263 typedef DeclarativeActionSet<SummingAction> SummingActionSet; 264 265 TEST(DeclarativeActionTest, ErrorActionSet) { 266 SummingActionSet::AnyVector actions; 267 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"value\": 1}"))); 268 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"error\": \"the error\"}"))); 269 270 std::string error; 271 bool bad = false; 272 scoped_ptr<SummingActionSet> result = 273 SummingActionSet::Create(NULL, NULL, actions, &error, &bad); 274 EXPECT_EQ("the error", error); 275 EXPECT_FALSE(bad); 276 EXPECT_FALSE(result); 277 278 actions.clear(); 279 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"value\": 1}"))); 280 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"bad\": 3}"))); 281 result = SummingActionSet::Create(NULL, NULL, actions, &error, &bad); 282 EXPECT_EQ("", error); 283 EXPECT_TRUE(bad); 284 EXPECT_FALSE(result); 285 } 286 287 TEST(DeclarativeActionTest, ApplyActionSet) { 288 SummingActionSet::AnyVector actions; 289 actions.push_back(ScopedToLinkedPtr(ParseJson( 290 "{\"value\": 1," 291 " \"priority\": 5}"))); 292 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"value\": 2}"))); 293 294 // Test insertion 295 std::string error; 296 bool bad = false; 297 scoped_ptr<SummingActionSet> result = 298 SummingActionSet::Create(NULL, NULL, actions, &error, &bad); 299 EXPECT_EQ("", error); 300 EXPECT_FALSE(bad); 301 ASSERT_TRUE(result); 302 EXPECT_EQ(2u, result->actions().size()); 303 304 int sum = 0; 305 result->Apply("ext_id", base::Time(), &sum); 306 EXPECT_EQ(3, sum); 307 EXPECT_EQ(5, result->GetMinimumPriority()); 308 } 309 310 TEST(DeclarativeRuleTest, Create) { 311 typedef DeclarativeRule<FulfillableCondition, SummingAction> Rule; 312 linked_ptr<Rule::JsonRule> json_rule(new Rule::JsonRule); 313 ASSERT_TRUE(Rule::JsonRule::Populate( 314 *ParseJson("{ \n" 315 " \"id\": \"rule1\", \n" 316 " \"conditions\": [ \n" 317 " {\"url_id\": 1, \"max\": 3}, \n" 318 " {\"url_id\": 2, \"max\": 5}, \n" 319 " ], \n" 320 " \"actions\": [ \n" 321 " { \n" 322 " \"value\": 2 \n" 323 " } \n" 324 " ], \n" 325 " \"priority\": 200 \n" 326 "}"), 327 json_rule.get())); 328 329 const char kExtensionId[] = "ext1"; 330 scoped_refptr<Extension> extension = ExtensionBuilder() 331 .SetManifest(SimpleManifest()) 332 .SetID(kExtensionId) 333 .Build(); 334 335 base::Time install_time = base::Time::Now(); 336 337 URLMatcher matcher; 338 std::string error; 339 scoped_ptr<Rule> rule(Rule::Create(matcher.condition_factory(), 340 NULL, 341 extension.get(), 342 install_time, 343 json_rule, 344 Rule::ConsistencyChecker(), 345 &error)); 346 EXPECT_EQ("", error); 347 ASSERT_TRUE(rule.get()); 348 349 EXPECT_EQ(kExtensionId, rule->id().first); 350 EXPECT_EQ("rule1", rule->id().second); 351 352 EXPECT_EQ(200, rule->priority()); 353 354 const Rule::ConditionSet& condition_set = rule->conditions(); 355 const Rule::ConditionSet::Conditions& conditions = 356 condition_set.conditions(); 357 ASSERT_EQ(2u, conditions.size()); 358 EXPECT_EQ(3, conditions[0]->max_value); 359 EXPECT_EQ(5, conditions[1]->max_value); 360 361 const Rule::ActionSet& action_set = rule->actions(); 362 const Rule::ActionSet::Actions& actions = action_set.actions(); 363 ASSERT_EQ(1u, actions.size()); 364 EXPECT_EQ(2, actions[0]->increment()); 365 366 int sum = 0; 367 rule->Apply(&sum); 368 EXPECT_EQ(2, sum); 369 } 370 371 bool AtLeastOneCondition( 372 const DeclarativeConditionSet<FulfillableCondition>* conditions, 373 const DeclarativeActionSet<SummingAction>* actions, 374 std::string* error) { 375 if (conditions->conditions().empty()) { 376 *error = "No conditions"; 377 return false; 378 } 379 return true; 380 } 381 382 TEST(DeclarativeRuleTest, CheckConsistency) { 383 typedef DeclarativeRule<FulfillableCondition, SummingAction> Rule; 384 URLMatcher matcher; 385 std::string error; 386 linked_ptr<Rule::JsonRule> json_rule(new Rule::JsonRule); 387 const char kExtensionId[] = "ext1"; 388 scoped_refptr<Extension> extension = ExtensionBuilder() 389 .SetManifest(SimpleManifest()) 390 .SetID(kExtensionId) 391 .Build(); 392 393 ASSERT_TRUE(Rule::JsonRule::Populate( 394 *ParseJson("{ \n" 395 " \"id\": \"rule1\", \n" 396 " \"conditions\": [ \n" 397 " {\"url_id\": 1, \"max\": 3}, \n" 398 " {\"url_id\": 2, \"max\": 5}, \n" 399 " ], \n" 400 " \"actions\": [ \n" 401 " { \n" 402 " \"value\": 2 \n" 403 " } \n" 404 " ], \n" 405 " \"priority\": 200 \n" 406 "}"), 407 json_rule.get())); 408 scoped_ptr<Rule> rule(Rule::Create(matcher.condition_factory(), 409 NULL, 410 extension.get(), 411 base::Time(), 412 json_rule, 413 base::Bind(AtLeastOneCondition), 414 &error)); 415 EXPECT_TRUE(rule); 416 EXPECT_EQ("", error); 417 418 ASSERT_TRUE(Rule::JsonRule::Populate( 419 *ParseJson("{ \n" 420 " \"id\": \"rule1\", \n" 421 " \"conditions\": [ \n" 422 " ], \n" 423 " \"actions\": [ \n" 424 " { \n" 425 " \"value\": 2 \n" 426 " } \n" 427 " ], \n" 428 " \"priority\": 200 \n" 429 "}"), 430 json_rule.get())); 431 rule = Rule::Create(matcher.condition_factory(), 432 NULL, 433 extension.get(), 434 base::Time(), 435 json_rule, 436 base::Bind(AtLeastOneCondition), 437 &error); 438 EXPECT_FALSE(rule); 439 EXPECT_EQ("No conditions", error); 440 } 441 442 } // namespace extensions 443