package test

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"testing"
	"time"
)
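
// TaskFunc is a unit of work executed by a pool worker.
type TaskFunc func() error

// GoroutinePool runs submitted tasks on a fixed number of worker goroutines.
type GoroutinePool struct {
	Ctx         context.Context
	cancelFunc  context.CancelFunc
	WorkerNum   int
	TaskChannel chan TaskFunc
	wg          sync.WaitGroup
}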

// NewGoroutinePool creates a pool with workNum workers and a task queue
// that buffers up to queueSize pending tasks.
func NewGoroutinePool(ctx context.Context, workNum, queueSize int) *GoroutinePool {
	ctx, cancelFunc := context.WithCancel(ctx)
	return &GoroutinePool{
		Ctx:         ctx,
		cancelFunc:  cancelFunc,
		WorkerNum:   workNum,
		TaskChannel: make(chan TaskFunc, queueSize),
		wg:          sync.WaitGroup{},
	}
}
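
// Start launches the worker goroutines. Each worker runs tasks from
// TaskChannel until the pool's context is cancelled.
func (g *GoroutinePool) Start() {
	for i := 0; i < g.WorkerNum; i++ {
		g.wg.Add(1)
		go func() {
			defer func() {
				// A panic in a task is recovered here, which also ends this worker.
				if r := recover(); r != nil {
					fmt.Println("panic:", r)
				}
				g.wg.Done()
			}()
			for {
				select {
				case task := <-g.TaskChannel:
					// Skip nil tasks; calling a nil func would panic.
					if task == nil {
						continue
					}
					if err := task(); err != nil {
						fmt.Println("task failed, err:", err)
					}
				case <-g.Ctx.Done():
					fmt.Println("worker exit since ctx done")
					return
				}
			}
		}()
	}
}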
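
// Stop cancels the pool's context and waits for all workers to exit.
// Tasks still waiting in the queue are not guaranteed to run.
//
// TaskChannel is deliberately left open: a concurrent Submit could otherwise
// panic by sending on a closed channel, and a second call to Stop would panic
// on the double close. The channel is garbage collected with the pool.
func (g *GoroutinePool) Stop() {
	g.cancelFunc()
	g.wg.Wait()
}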
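
// Submit enqueues a task. It blocks while the queue is full and returns an
// error if the pool has been stopped.
func (g *GoroutinePool) Submit(task TaskFunc) error {
	select {
	case <-g.Ctx.Done():
		return errors.New("goroutine pool already stopped")
	case g.TaskChannel <- task:
		return nil
	}
}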
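
func TestGoroutinePool(t *testing.T) {
	ctx := context.Background()
	pool := NewGoroutinePool(ctx, 2, 10)
	pool.Start()

	if err := pool.Submit(func() error {
		fmt.Println("task 1, time=", time.Now())
		time.Sleep(100 * time.Millisecond)
		return nil
	}); err != nil {
		t.Fatalf("submit task 1: %v", err)
	}
	if err := pool.Submit(func() error {
		fmt.Println("task 2, time=", time.Now())
		time.Sleep(200 * time.Millisecond)
		return nil
	}); err != nil {
		t.Fatalf("submit task 2: %v", err)
	}
	if err := pool.Submit(func() error {
		fmt.Println("task 3, time=", time.Now())
		time.Sleep(300 * time.Millisecond)
		return errors.New("task 3 failed")
	}); err != nil {
		t.Fatalf("submit task 3: %v", err)
	}

	// Give the workers time to drain the queue before stopping: Stop cancels
	// the context, and workers may exit without running tasks still queued.
	time.Sleep(1 * time.Second)
	pool.Stop()
}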