Home | History | Annotate | Download | only in singleflight
      1 // Copyright 2013 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package singleflight
      6 
      7 import (
      8 	"errors"
      9 	"fmt"
     10 	"sync"
     11 	"sync/atomic"
     12 	"testing"
     13 	"time"
     14 )
     15 
     16 func TestDo(t *testing.T) {
     17 	var g Group
     18 	v, err, _ := g.Do("key", func() (interface{}, error) {
     19 		return "bar", nil
     20 	})
     21 	if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
     22 		t.Errorf("Do = %v; want %v", got, want)
     23 	}
     24 	if err != nil {
     25 		t.Errorf("Do error = %v", err)
     26 	}
     27 }
     28 
     29 func TestDoErr(t *testing.T) {
     30 	var g Group
     31 	someErr := errors.New("Some error")
     32 	v, err, _ := g.Do("key", func() (interface{}, error) {
     33 		return nil, someErr
     34 	})
     35 	if err != someErr {
     36 		t.Errorf("Do error = %v; want someErr %v", err, someErr)
     37 	}
     38 	if v != nil {
     39 		t.Errorf("unexpected non-nil value %#v", v)
     40 	}
     41 }
     42 
     43 func TestDoDupSuppress(t *testing.T) {
     44 	var g Group
     45 	var wg1, wg2 sync.WaitGroup
     46 	c := make(chan string, 1)
     47 	var calls int32
     48 	fn := func() (interface{}, error) {
     49 		if atomic.AddInt32(&calls, 1) == 1 {
     50 			// First invocation.
     51 			wg1.Done()
     52 		}
     53 		v := <-c
     54 		c <- v // pump; make available for any future calls
     55 
     56 		time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
     57 
     58 		return v, nil
     59 	}
     60 
     61 	const n = 10
     62 	wg1.Add(1)
     63 	for i := 0; i < n; i++ {
     64 		wg1.Add(1)
     65 		wg2.Add(1)
     66 		go func() {
     67 			defer wg2.Done()
     68 			wg1.Done()
     69 			v, err, _ := g.Do("key", fn)
     70 			if err != nil {
     71 				t.Errorf("Do error: %v", err)
     72 				return
     73 			}
     74 			if s, _ := v.(string); s != "bar" {
     75 				t.Errorf("Do = %T %v; want %q", v, v, "bar")
     76 			}
     77 		}()
     78 	}
     79 	wg1.Wait()
     80 	// At least one goroutine is in fn now and all of them have at
     81 	// least reached the line before the Do.
     82 	c <- "bar"
     83 	wg2.Wait()
     84 	if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
     85 		t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
     86 	}
     87 }
     88