Home | History | Annotate | Download | only in lib
      1 #ifndef __UBOOT__
      2 #include <linux/kernel.h>
      3 #include <linux/module.h>
      4 #include <linux/slab.h>
      5 #else
      6 #include <linux/compat.h>
      7 #include <common.h>
      8 #include <malloc.h>
      9 #endif
     10 #include <linux/list.h>
     11 #include <linux/list_sort.h>
     12 
     13 #define MAX_LIST_LENGTH_BITS 20
     14 
     15 /*
     16  * Returns a list organized in an intermediate format suited
     17  * to chaining of merge() calls: null-terminated, no reserved or
     18  * sentinel head node, "prev" links not maintained.
     19  */
     20 static struct list_head *merge(void *priv,
     21 				int (*cmp)(void *priv, struct list_head *a,
     22 					struct list_head *b),
     23 				struct list_head *a, struct list_head *b)
     24 {
     25 	struct list_head head, *tail = &head;
     26 
     27 	while (a && b) {
     28 		/* if equal, take 'a' -- important for sort stability */
     29 		if ((*cmp)(priv, a, b) <= 0) {
     30 			tail->next = a;
     31 			a = a->next;
     32 		} else {
     33 			tail->next = b;
     34 			b = b->next;
     35 		}
     36 		tail = tail->next;
     37 	}
     38 	tail->next = a?:b;
     39 	return head.next;
     40 }
     41 
     42 /*
     43  * Combine final list merge with restoration of standard doubly-linked
     44  * list structure.  This approach duplicates code from merge(), but
     45  * runs faster than the tidier alternatives of either a separate final
     46  * prev-link restoration pass, or maintaining the prev links
     47  * throughout.
     48  */
     49 static void merge_and_restore_back_links(void *priv,
     50 				int (*cmp)(void *priv, struct list_head *a,
     51 					struct list_head *b),
     52 				struct list_head *head,
     53 				struct list_head *a, struct list_head *b)
     54 {
     55 	struct list_head *tail = head;
     56 
     57 	while (a && b) {
     58 		/* if equal, take 'a' -- important for sort stability */
     59 		if ((*cmp)(priv, a, b) <= 0) {
     60 			tail->next = a;
     61 			a->prev = tail;
     62 			a = a->next;
     63 		} else {
     64 			tail->next = b;
     65 			b->prev = tail;
     66 			b = b->next;
     67 		}
     68 		tail = tail->next;
     69 	}
     70 	tail->next = a ? : b;
     71 
     72 	do {
     73 		/*
     74 		 * In worst cases this loop may run many iterations.
     75 		 * Continue callbacks to the client even though no
     76 		 * element comparison is needed, so the client's cmp()
     77 		 * routine can invoke cond_resched() periodically.
     78 		 */
     79 		(*cmp)(priv, tail->next, tail->next);
     80 
     81 		tail->next->prev = tail;
     82 		tail = tail->next;
     83 	} while (tail->next);
     84 
     85 	tail->next = head;
     86 	head->prev = tail;
     87 }
     88 
     89 /**
     90  * list_sort - sort a list
     91  * @priv: private data, opaque to list_sort(), passed to @cmp
     92  * @head: the list to sort
     93  * @cmp: the elements comparison function
     94  *
     95  * This function implements "merge sort", which has O(nlog(n))
     96  * complexity.
     97  *
     98  * The comparison function @cmp must return a negative value if @a
     99  * should sort before @b, and a positive value if @a should sort after
    100  * @b. If @a and @b are equivalent, and their original relative
    101  * ordering is to be preserved, @cmp must return 0.
    102  */
    103 void list_sort(void *priv, struct list_head *head,
    104 		int (*cmp)(void *priv, struct list_head *a,
    105 			struct list_head *b))
    106 {
    107 	struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists
    108 						-- last slot is a sentinel */
    109 	int lev;  /* index into part[] */
    110 	int max_lev = 0;
    111 	struct list_head *list;
    112 
    113 	if (list_empty(head))
    114 		return;
    115 
    116 	memset(part, 0, sizeof(part));
    117 
    118 	head->prev->next = NULL;
    119 	list = head->next;
    120 
    121 	while (list) {
    122 		struct list_head *cur = list;
    123 		list = list->next;
    124 		cur->next = NULL;
    125 
    126 		for (lev = 0; part[lev]; lev++) {
    127 			cur = merge(priv, cmp, part[lev], cur);
    128 			part[lev] = NULL;
    129 		}
    130 		if (lev > max_lev) {
    131 			if (unlikely(lev >= ARRAY_SIZE(part)-1)) {
    132 				printk_once(KERN_DEBUG "list passed to"
    133 					" list_sort() too long for"
    134 					" efficiency\n");
    135 				lev--;
    136 			}
    137 			max_lev = lev;
    138 		}
    139 		part[lev] = cur;
    140 	}
    141 
    142 	for (lev = 0; lev < max_lev; lev++)
    143 		if (part[lev])
    144 			list = merge(priv, cmp, part[lev], list);
    145 
    146 	merge_and_restore_back_links(priv, cmp, head, part[max_lev], list);
    147 }
    148 EXPORT_SYMBOL(list_sort);
    149 
    150 #ifdef CONFIG_TEST_LIST_SORT
    151 
    152 #include <linux/random.h>
    153 
    154 /*
    155  * The pattern of set bits in the list length determines which cases
    156  * are hit in list_sort().
    157  */
    158 #define TEST_LIST_LEN (512+128+2) /* not including head */
    159 
    160 #define TEST_POISON1 0xDEADBEEF
    161 #define TEST_POISON2 0xA324354C
    162 
    163 struct debug_el {
    164 	unsigned int poison1;
    165 	struct list_head list;
    166 	unsigned int poison2;
    167 	int value;
    168 	unsigned serial;
    169 };
    170 
    171 /* Array, containing pointers to all elements in the test list */
    172 static struct debug_el **elts __initdata;
    173 
    174 static int __init check(struct debug_el *ela, struct debug_el *elb)
    175 {
    176 	if (ela->serial >= TEST_LIST_LEN) {
    177 		printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n",
    178 				ela->serial);
    179 		return -EINVAL;
    180 	}
    181 	if (elb->serial >= TEST_LIST_LEN) {
    182 		printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n",
    183 				elb->serial);
    184 		return -EINVAL;
    185 	}
    186 	if (elts[ela->serial] != ela || elts[elb->serial] != elb) {
    187 		printk(KERN_ERR "list_sort_test: error: phantom element\n");
    188 		return -EINVAL;
    189 	}
    190 	if (ela->poison1 != TEST_POISON1 || ela->poison2 != TEST_POISON2) {
    191 		printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n",
    192 				ela->poison1, ela->poison2);
    193 		return -EINVAL;
    194 	}
    195 	if (elb->poison1 != TEST_POISON1 || elb->poison2 != TEST_POISON2) {
    196 		printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n",
    197 				elb->poison1, elb->poison2);
    198 		return -EINVAL;
    199 	}
    200 	return 0;
    201 }
    202 
    203 static int __init cmp(void *priv, struct list_head *a, struct list_head *b)
    204 {
    205 	struct debug_el *ela, *elb;
    206 
    207 	ela = container_of(a, struct debug_el, list);
    208 	elb = container_of(b, struct debug_el, list);
    209 
    210 	check(ela, elb);
    211 	return ela->value - elb->value;
    212 }
    213 
    214 static int __init list_sort_test(void)
    215 {
    216 	int i, count = 1, err = -EINVAL;
    217 	struct debug_el *el;
    218 	struct list_head *cur, *tmp;
    219 	LIST_HEAD(head);
    220 
    221 	printk(KERN_DEBUG "list_sort_test: start testing list_sort()\n");
    222 
    223 	elts = kmalloc(sizeof(void *) * TEST_LIST_LEN, GFP_KERNEL);
    224 	if (!elts) {
    225 		printk(KERN_ERR "list_sort_test: error: cannot allocate "
    226 				"memory\n");
    227 		goto exit;
    228 	}
    229 
    230 	for (i = 0; i < TEST_LIST_LEN; i++) {
    231 		el = kmalloc(sizeof(*el), GFP_KERNEL);
    232 		if (!el) {
    233 			printk(KERN_ERR "list_sort_test: error: cannot "
    234 					"allocate memory\n");
    235 			goto exit;
    236 		}
    237 		 /* force some equivalencies */
    238 		el->value = prandom_u32() % (TEST_LIST_LEN / 3);
    239 		el->serial = i;
    240 		el->poison1 = TEST_POISON1;
    241 		el->poison2 = TEST_POISON2;
    242 		elts[i] = el;
    243 		list_add_tail(&el->list, &head);
    244 	}
    245 
    246 	list_sort(NULL, &head, cmp);
    247 
    248 	for (cur = head.next; cur->next != &head; cur = cur->next) {
    249 		struct debug_el *el1;
    250 		int cmp_result;
    251 
    252 		if (cur->next->prev != cur) {
    253 			printk(KERN_ERR "list_sort_test: error: list is "
    254 					"corrupted\n");
    255 			goto exit;
    256 		}
    257 
    258 		cmp_result = cmp(NULL, cur, cur->next);
    259 		if (cmp_result > 0) {
    260 			printk(KERN_ERR "list_sort_test: error: list is not "
    261 					"sorted\n");
    262 			goto exit;
    263 		}
    264 
    265 		el = container_of(cur, struct debug_el, list);
    266 		el1 = container_of(cur->next, struct debug_el, list);
    267 		if (cmp_result == 0 && el->serial >= el1->serial) {
    268 			printk(KERN_ERR "list_sort_test: error: order of "
    269 					"equivalent elements not preserved\n");
    270 			goto exit;
    271 		}
    272 
    273 		if (check(el, el1)) {
    274 			printk(KERN_ERR "list_sort_test: error: element check "
    275 					"failed\n");
    276 			goto exit;
    277 		}
    278 		count++;
    279 	}
    280 
    281 	if (count != TEST_LIST_LEN) {
    282 		printk(KERN_ERR "list_sort_test: error: bad list length %d",
    283 				count);
    284 		goto exit;
    285 	}
    286 
    287 	err = 0;
    288 exit:
    289 	kfree(elts);
    290 	list_for_each_safe(cur, tmp, &head) {
    291 		list_del(cur);
    292 		kfree(container_of(cur, struct debug_el, list));
    293 	}
    294 	return err;
    295 }
    296 module_init(list_sort_test);
    297 #endif /* CONFIG_TEST_LIST_SORT */
    298