1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/browser/extensions/api/declarative/declarative_rule.h" 6 7 #include "base/bind.h" 8 #include "base/message_loop/message_loop.h" 9 #include "base/test/values_test_util.h" 10 #include "base/values.h" 11 #include "components/url_matcher/url_matcher_constants.h" 12 #include "extensions/common/extension_builder.h" 13 #include "testing/gmock/include/gmock/gmock.h" 14 #include "testing/gtest/include/gtest/gtest.h" 15 16 using base::test::ParseJson; 17 using url_matcher::URLMatcher; 18 using url_matcher::URLMatcherConditionFactory; 19 using url_matcher::URLMatcherConditionSet; 20 21 namespace extensions { 22 23 namespace { 24 25 template<typename T> 26 linked_ptr<T> ScopedToLinkedPtr(scoped_ptr<T> ptr) { 27 return linked_ptr<T>(ptr.release()); 28 } 29 30 scoped_ptr<base::DictionaryValue> SimpleManifest() { 31 return DictionaryBuilder() 32 .Set("name", "extension") 33 .Set("manifest_version", 2) 34 .Set("version", "1.0") 35 .Build(); 36 } 37 38 } // namespace 39 40 struct RecordingCondition { 41 typedef int MatchData; 42 43 URLMatcherConditionFactory* factory; 44 scoped_ptr<base::Value> value; 45 46 void GetURLMatcherConditionSets( 47 URLMatcherConditionSet::Vector* condition_sets) const { 48 // No condition sets. 49 } 50 51 static scoped_ptr<RecordingCondition> Create( 52 const Extension* extension, 53 URLMatcherConditionFactory* url_matcher_condition_factory, 54 const base::Value& condition, 55 std::string* error) { 56 const base::DictionaryValue* dict = NULL; 57 if (condition.GetAsDictionary(&dict) && dict->HasKey("bad_key")) { 58 *error = "Found error key"; 59 return scoped_ptr<RecordingCondition>(); 60 } 61 62 scoped_ptr<RecordingCondition> result(new RecordingCondition()); 63 result->factory = url_matcher_condition_factory; 64 result->value.reset(condition.DeepCopy()); 65 return result.Pass(); 66 } 67 }; 68 typedef DeclarativeConditionSet<RecordingCondition> RecordingConditionSet; 69 70 TEST(DeclarativeConditionTest, ErrorConditionSet) { 71 URLMatcher matcher; 72 RecordingConditionSet::AnyVector conditions; 73 conditions.push_back(ScopedToLinkedPtr(ParseJson("{\"key\": 1}"))); 74 conditions.push_back(ScopedToLinkedPtr(ParseJson("{\"bad_key\": 2}"))); 75 76 std::string error; 77 scoped_ptr<RecordingConditionSet> result = RecordingConditionSet::Create( 78 NULL, matcher.condition_factory(), conditions, &error); 79 EXPECT_EQ("Found error key", error); 80 ASSERT_FALSE(result); 81 } 82 83 TEST(DeclarativeConditionTest, CreateConditionSet) { 84 URLMatcher matcher; 85 RecordingConditionSet::AnyVector conditions; 86 conditions.push_back(ScopedToLinkedPtr(ParseJson("{\"key\": 1}"))); 87 conditions.push_back(ScopedToLinkedPtr(ParseJson("[\"val1\", 2]"))); 88 89 // Test insertion 90 std::string error; 91 scoped_ptr<RecordingConditionSet> result = RecordingConditionSet::Create( 92 NULL, matcher.condition_factory(), conditions, &error); 93 EXPECT_EQ("", error); 94 ASSERT_TRUE(result); 95 EXPECT_EQ(2u, result->conditions().size()); 96 97 EXPECT_EQ(matcher.condition_factory(), result->conditions()[0]->factory); 98 EXPECT_TRUE(ParseJson("{\"key\": 1}")->Equals( 99 result->conditions()[0]->value.get())); 100 } 101 102 struct FulfillableCondition { 103 struct MatchData { 104 int value; 105 const std::set<URLMatcherConditionSet::ID>& url_matches; 106 }; 107 108 scoped_refptr<URLMatcherConditionSet> condition_set; 109 int condition_set_id; 110 int max_value; 111 112 URLMatcherConditionSet::ID url_matcher_condition_set_id() const { 113 return condition_set_id; 114 } 115 116 scoped_refptr<URLMatcherConditionSet> url_matcher_condition_set() const { 117 return condition_set; 118 } 119 120 void GetURLMatcherConditionSets( 121 URLMatcherConditionSet::Vector* condition_sets) const { 122 if (condition_set.get()) 123 condition_sets->push_back(condition_set); 124 } 125 126 bool IsFulfilled(const MatchData& match_data) const { 127 if (condition_set_id != -1 && 128 !ContainsKey(match_data.url_matches, condition_set_id)) 129 return false; 130 return match_data.value <= max_value; 131 } 132 133 static scoped_ptr<FulfillableCondition> Create( 134 const Extension* extension, 135 URLMatcherConditionFactory* url_matcher_condition_factory, 136 const base::Value& condition, 137 std::string* error) { 138 scoped_ptr<FulfillableCondition> result(new FulfillableCondition()); 139 const base::DictionaryValue* dict; 140 if (!condition.GetAsDictionary(&dict)) { 141 *error = "Expected dict"; 142 return result.Pass(); 143 } 144 if (!dict->GetInteger("url_id", &result->condition_set_id)) 145 result->condition_set_id = -1; 146 if (!dict->GetInteger("max", &result->max_value)) 147 *error = "Expected integer at ['max']"; 148 if (result->condition_set_id != -1) { 149 result->condition_set = new URLMatcherConditionSet( 150 result->condition_set_id, 151 URLMatcherConditionSet::Conditions()); 152 } 153 return result.Pass(); 154 } 155 }; 156 157 TEST(DeclarativeConditionTest, FulfillConditionSet) { 158 typedef DeclarativeConditionSet<FulfillableCondition> FulfillableConditionSet; 159 FulfillableConditionSet::AnyVector conditions; 160 conditions.push_back(ScopedToLinkedPtr(ParseJson( 161 "{\"url_id\": 1, \"max\": 3}"))); 162 conditions.push_back(ScopedToLinkedPtr(ParseJson( 163 "{\"url_id\": 2, \"max\": 5}"))); 164 conditions.push_back(ScopedToLinkedPtr(ParseJson( 165 "{\"url_id\": 3, \"max\": 1}"))); 166 conditions.push_back(ScopedToLinkedPtr(ParseJson( 167 "{\"max\": -5}"))); // No url. 168 169 // Test insertion 170 std::string error; 171 scoped_ptr<FulfillableConditionSet> result = 172 FulfillableConditionSet::Create(NULL, NULL, conditions, &error); 173 ASSERT_EQ("", error); 174 ASSERT_TRUE(result); 175 EXPECT_EQ(4u, result->conditions().size()); 176 177 std::set<URLMatcherConditionSet::ID> url_matches; 178 FulfillableCondition::MatchData match_data = { 0, url_matches }; 179 EXPECT_FALSE(result->IsFulfilled(1, match_data)) 180 << "Testing an ID that's not in url_matches forwards to the Condition, " 181 << "which doesn't match."; 182 EXPECT_FALSE(result->IsFulfilled(-1, match_data)) 183 << "Testing the 'no ID' value tries to match the 4th condition, but " 184 << "its max is too low."; 185 match_data.value = -5; 186 EXPECT_TRUE(result->IsFulfilled(-1, match_data)) 187 << "Testing the 'no ID' value tries to match the 4th condition, and " 188 << "this value is low enough."; 189 190 url_matches.insert(1); 191 match_data.value = 3; 192 EXPECT_TRUE(result->IsFulfilled(1, match_data)) 193 << "Tests a condition with a url matcher, for a matching value."; 194 match_data.value = 4; 195 EXPECT_FALSE(result->IsFulfilled(1, match_data)) 196 << "Tests a condition with a url matcher, for a non-matching value " 197 << "that would match a different condition."; 198 url_matches.insert(2); 199 EXPECT_TRUE(result->IsFulfilled(2, match_data)) 200 << "Tests with 2 elements in the match set."; 201 202 // Check the condition sets: 203 URLMatcherConditionSet::Vector condition_sets; 204 result->GetURLMatcherConditionSets(&condition_sets); 205 ASSERT_EQ(3U, condition_sets.size()); 206 EXPECT_EQ(1, condition_sets[0]->id()); 207 EXPECT_EQ(2, condition_sets[1]->id()); 208 EXPECT_EQ(3, condition_sets[2]->id()); 209 } 210 211 // DeclarativeAction 212 213 class SummingAction : public base::RefCounted<SummingAction> { 214 public: 215 typedef int ApplyInfo; 216 217 SummingAction(int increment, int min_priority) 218 : increment_(increment), min_priority_(min_priority) {} 219 220 static scoped_refptr<const SummingAction> Create(const Extension* extension, 221 const base::Value& action, 222 std::string* error, 223 bool* bad_message) { 224 int increment = 0; 225 int min_priority = 0; 226 const base::DictionaryValue* dict = NULL; 227 EXPECT_TRUE(action.GetAsDictionary(&dict)); 228 if (dict->HasKey("error")) { 229 EXPECT_TRUE(dict->GetString("error", error)); 230 return scoped_refptr<const SummingAction>(NULL); 231 } 232 if (dict->HasKey("bad")) { 233 *bad_message = true; 234 return scoped_refptr<const SummingAction>(NULL); 235 } 236 237 EXPECT_TRUE(dict->GetInteger("value", &increment)); 238 dict->GetInteger("priority", &min_priority); 239 return scoped_refptr<const SummingAction>( 240 new SummingAction(increment, min_priority)); 241 } 242 243 void Apply(const std::string& extension_id, 244 const base::Time& install_time, 245 int* sum) const { 246 *sum += increment_; 247 } 248 249 int increment() const { return increment_; } 250 int minimum_priority() const { 251 return min_priority_; 252 } 253 254 private: 255 friend class base::RefCounted<SummingAction>; 256 virtual ~SummingAction() {} 257 258 int increment_; 259 int min_priority_; 260 }; 261 typedef DeclarativeActionSet<SummingAction> SummingActionSet; 262 263 TEST(DeclarativeActionTest, ErrorActionSet) { 264 SummingActionSet::AnyVector actions; 265 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"value\": 1}"))); 266 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"error\": \"the error\"}"))); 267 268 std::string error; 269 bool bad = false; 270 scoped_ptr<SummingActionSet> result = 271 SummingActionSet::Create(NULL, actions, &error, &bad); 272 EXPECT_EQ("the error", error); 273 EXPECT_FALSE(bad); 274 EXPECT_FALSE(result); 275 276 actions.clear(); 277 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"value\": 1}"))); 278 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"bad\": 3}"))); 279 result = SummingActionSet::Create(NULL, actions, &error, &bad); 280 EXPECT_EQ("", error); 281 EXPECT_TRUE(bad); 282 EXPECT_FALSE(result); 283 } 284 285 TEST(DeclarativeActionTest, ApplyActionSet) { 286 SummingActionSet::AnyVector actions; 287 actions.push_back(ScopedToLinkedPtr(ParseJson( 288 "{\"value\": 1," 289 " \"priority\": 5}"))); 290 actions.push_back(ScopedToLinkedPtr(ParseJson("{\"value\": 2}"))); 291 292 // Test insertion 293 std::string error; 294 bool bad = false; 295 scoped_ptr<SummingActionSet> result = 296 SummingActionSet::Create(NULL, actions, &error, &bad); 297 EXPECT_EQ("", error); 298 EXPECT_FALSE(bad); 299 ASSERT_TRUE(result); 300 EXPECT_EQ(2u, result->actions().size()); 301 302 int sum = 0; 303 result->Apply("ext_id", base::Time(), &sum); 304 EXPECT_EQ(3, sum); 305 EXPECT_EQ(5, result->GetMinimumPriority()); 306 } 307 308 TEST(DeclarativeRuleTest, Create) { 309 typedef DeclarativeRule<FulfillableCondition, SummingAction> Rule; 310 linked_ptr<Rule::JsonRule> json_rule(new Rule::JsonRule); 311 ASSERT_TRUE(Rule::JsonRule::Populate( 312 *ParseJson("{ \n" 313 " \"id\": \"rule1\", \n" 314 " \"conditions\": [ \n" 315 " {\"url_id\": 1, \"max\": 3}, \n" 316 " {\"url_id\": 2, \"max\": 5}, \n" 317 " ], \n" 318 " \"actions\": [ \n" 319 " { \n" 320 " \"value\": 2 \n" 321 " } \n" 322 " ], \n" 323 " \"priority\": 200 \n" 324 "}"), 325 json_rule.get())); 326 327 const char kExtensionId[] = "ext1"; 328 scoped_refptr<Extension> extension = ExtensionBuilder() 329 .SetManifest(SimpleManifest()) 330 .SetID(kExtensionId) 331 .Build(); 332 333 base::Time install_time = base::Time::Now(); 334 335 URLMatcher matcher; 336 std::string error; 337 scoped_ptr<Rule> rule(Rule::Create(matcher.condition_factory(), 338 extension.get(), 339 install_time, 340 json_rule, 341 Rule::ConsistencyChecker(), 342 &error)); 343 EXPECT_EQ("", error); 344 ASSERT_TRUE(rule.get()); 345 346 EXPECT_EQ(kExtensionId, rule->id().first); 347 EXPECT_EQ("rule1", rule->id().second); 348 349 EXPECT_EQ(200, rule->priority()); 350 351 const Rule::ConditionSet& condition_set = rule->conditions(); 352 const Rule::ConditionSet::Conditions& conditions = 353 condition_set.conditions(); 354 ASSERT_EQ(2u, conditions.size()); 355 EXPECT_EQ(3, conditions[0]->max_value); 356 EXPECT_EQ(5, conditions[1]->max_value); 357 358 const Rule::ActionSet& action_set = rule->actions(); 359 const Rule::ActionSet::Actions& actions = action_set.actions(); 360 ASSERT_EQ(1u, actions.size()); 361 EXPECT_EQ(2, actions[0]->increment()); 362 363 int sum = 0; 364 rule->Apply(&sum); 365 EXPECT_EQ(2, sum); 366 } 367 368 bool AtLeastOneCondition( 369 const DeclarativeConditionSet<FulfillableCondition>* conditions, 370 const DeclarativeActionSet<SummingAction>* actions, 371 std::string* error) { 372 if (conditions->conditions().empty()) { 373 *error = "No conditions"; 374 return false; 375 } 376 return true; 377 } 378 379 TEST(DeclarativeRuleTest, CheckConsistency) { 380 typedef DeclarativeRule<FulfillableCondition, SummingAction> Rule; 381 URLMatcher matcher; 382 std::string error; 383 linked_ptr<Rule::JsonRule> json_rule(new Rule::JsonRule); 384 const char kExtensionId[] = "ext1"; 385 scoped_refptr<Extension> extension = ExtensionBuilder() 386 .SetManifest(SimpleManifest()) 387 .SetID(kExtensionId) 388 .Build(); 389 390 ASSERT_TRUE(Rule::JsonRule::Populate( 391 *ParseJson("{ \n" 392 " \"id\": \"rule1\", \n" 393 " \"conditions\": [ \n" 394 " {\"url_id\": 1, \"max\": 3}, \n" 395 " {\"url_id\": 2, \"max\": 5}, \n" 396 " ], \n" 397 " \"actions\": [ \n" 398 " { \n" 399 " \"value\": 2 \n" 400 " } \n" 401 " ], \n" 402 " \"priority\": 200 \n" 403 "}"), 404 json_rule.get())); 405 scoped_ptr<Rule> rule(Rule::Create(matcher.condition_factory(), 406 extension.get(), 407 base::Time(), 408 json_rule, 409 base::Bind(AtLeastOneCondition), 410 &error)); 411 EXPECT_TRUE(rule); 412 EXPECT_EQ("", error); 413 414 ASSERT_TRUE(Rule::JsonRule::Populate( 415 *ParseJson("{ \n" 416 " \"id\": \"rule1\", \n" 417 " \"conditions\": [ \n" 418 " ], \n" 419 " \"actions\": [ \n" 420 " { \n" 421 " \"value\": 2 \n" 422 " } \n" 423 " ], \n" 424 " \"priority\": 200 \n" 425 "}"), 426 json_rule.get())); 427 rule = Rule::Create(matcher.condition_factory(), 428 extension.get(), 429 base::Time(), 430 json_rule, 431 base::Bind(AtLeastOneCondition), 432 &error); 433 EXPECT_FALSE(rule); 434 EXPECT_EQ("No conditions", error); 435 } 436 437 } // namespace extensions 438