描述

MVP[已实现]:

  1. 固定协程/worker数量
  2. 自动领取&执行任务
  3. 支持协程池停止
  4. 支持任务提交&等待完成
  5. 支持调整任务队列长度

拓展[待添加]:

  1. 支持协程数量动态/手动调整:可以通过新旧替换的方式,先启动新的,再关闭旧的
  2. 支持任务超时:利用 context.withtimeout 实现
  3. 支持查看任务结果:将 taskFunc 用 struct 包起来,包含任务 ID 和任务结果

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

package test

import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
)

type TaskFunc func() error
type GoroutinePool struct {
Ctx context.Context
cancelFunc context.CancelFunc

WorkerNum int
TaskChannel chan TaskFunc
// 注意点 1:利用 wg 确保所有 worker 退出
wg sync.WaitGroup
}

func NewGoroutinePool(ctx context.Context, workNum, queueSize int) *GoroutinePool {
ctx, cancelFunc := context.WithCancel(ctx)
return &GoroutinePool{
Ctx: ctx,
cancelFunc: cancelFunc,
WorkerNum: workNum,
// 注意点 2:任务队列长度为 0 时,会阻塞任务提交
TaskChannel: make(chan TaskFunc, queueSize),
wg: sync.WaitGroup{},
}
}

func (g *GoroutinePool) Start() {
for i := 0; i < g.WorkerNum; i++ {
g.wg.Add(1)
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Println("panic:", r)
}
g.wg.Done()
}()
for {
select {
case task := <-g.TaskChannel:
if task != nil {
continue
}
err := task()
if err != nil {
fmt.Println("task failed, err:", err)
}
case <-g.Ctx.Done():
fmt.Println("worker exit since ctx done")
return
}
}
}()
}
}

func (g *GoroutinePool) Stop() {
g.cancelFunc()
// 注意点 3:等待所有 worker 退出后,关闭任务通道
g.wg.Wait()
close(g.TaskChannel)
}

func (g *GoroutinePool) Submit(task TaskFunc) error {
select {
// 注意点 4:提交任务时,判断协程池是否已停止
case <-g.Ctx.Done():
return fmt.Errorf("goroutine pool already stopped")
case g.TaskChannel <- task:
return nil
}
}

func TestGoroutinePool(t *testing.T) {
ctx := context.Background()
pool := NewGoroutinePool(ctx, 2, 10)
pool.Start()
pool.Submit(func() error {
fmt.Println("task 1, time=", time.Now())
time.Sleep(100 * time.Millisecond)
return nil
})
pool.Submit(func() error {
fmt.Println("task 2, time=", time.Now())
time.Sleep(200 * time.Millisecond)
return nil
})
pool.Submit(func() error {
fmt.Println("task 3, time=", time.Now())
time.Sleep(300 * time.Millisecond)
return errors.New("task 3 failed")
})
pool.Stop()
time.Sleep(1 * time.Second)
}