mihomo/common/batch/batch.go

106 lines
1.7 KiB
Go
Raw Normal View History

package batch
import (
"context"
"sync"
)
2022-04-24 02:07:57 +08:00
// Option configures a Batch; New applies each option in order before
// returning the Batch.
type Option[T any] func(batch *Batch[T])
2022-04-24 02:07:57 +08:00
// Result is the recorded outcome of one task started with Batch.Go.
// Value is stored exactly as the task function returned it, even when Err
// is non-nil.
type Result[T any] struct {
	Value T     // value returned by the task function
	Err   error // error returned by the task function; nil on success
}
// Error describes the first task failure observed by a Batch: the key the
// task was started under and the error its function returned.
//
// *Error implements the error interface and supports errors.Is / errors.As
// via Unwrap. Note that Batch.Wait returns a *Error that is nil on success;
// compare that pointer with nil before storing it in an `error` variable,
// otherwise a nil *Error becomes a non-nil error interface.
type Error struct {
	Key string
	Err error
}

// Error implements the error interface. The message is "key: cause", or just
// the key when no underlying error is recorded.
func (e *Error) Error() string {
	if e.Err == nil {
		return e.Key
	}
	return e.Key + ": " + e.Err.Error()
}

// Unwrap returns the underlying task error so errors.Is and errors.As can
// see through the Batch wrapper.
func (e *Error) Unwrap() error {
	return e.Err
}
2022-04-24 02:07:57 +08:00
// WithConcurrencyNum caps how many task functions a Batch runs at the same
// time to n. Tasks submitted with Go beyond the limit still get their own
// goroutine, but it waits for a free slot before calling its function.
//
// A non-positive n means "no limit" and leaves the Batch unbounded. If the
// option is given more than once, the last one wins.
func WithConcurrencyNum[T any](n int) Option[T] {
	return func(b *Batch[T]) {
		if n <= 0 {
			// make(chan, n) panics for n < 0, and a zero-capacity semaphore
			// holding no tokens would block every task forever (Wait would
			// never return). Treat both as unlimited: Go skips the semaphore
			// when queue is nil.
			b.queue = nil
			return
		}
		// Pre-fill the buffered channel with n tokens. Go takes one token
		// before running a task and puts it back afterwards.
		q := make(chan struct{}, n)
		for i := 0; i < n; i++ {
			q <- struct{}{}
		}
		b.queue = q
	}
}
// Batch similar to errgroup, but can control the maximum number of concurrent
2022-04-24 02:07:57 +08:00
type Batch[T any] struct {
result map[string]Result[T]
queue chan struct{}
wg sync.WaitGroup
mux sync.Mutex
err *Error
once sync.Once
cancel func()
}
2022-04-24 02:07:57 +08:00
func (b *Batch[T]) Go(key string, fn func() (T, error)) {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if b.queue != nil {
<-b.queue
defer func() {
b.queue <- struct{}{}
}()
}
value, err := fn()
if err != nil {
b.once.Do(func() {
b.err = &Error{key, err}
if b.cancel != nil {
b.cancel()
}
})
}
2022-04-24 02:07:57 +08:00
ret := Result[T]{value, err}
b.mux.Lock()
defer b.mux.Unlock()
b.result[key] = ret
}()
}
2022-04-24 02:07:57 +08:00
// Wait blocks until every goroutine started by Go has returned. It then
// cancels the context created by New and returns the first recorded failure,
// or nil if no task failed.
func (b *Batch[T]) Wait() *Error {
	b.wg.Wait()

	if b.cancel == nil {
		return b.err
	}
	b.cancel()
	return b.err
}
2022-04-24 02:07:57 +08:00
// WaitAndGetResult combines Wait and Result. It blocks until all tasks have
// finished, then returns a copy of every task's outcome together with the
// first failure (nil when none failed).
func (b *Batch[T]) WaitAndGetResult() (results map[string]Result[T], firstErr *Error) {
	firstErr = b.Wait()
	results = b.Result()
	return results, firstErr
}
2022-04-24 02:07:57 +08:00
func (b *Batch[T]) Result() map[string]Result[T] {
b.mux.Lock()
defer b.mux.Unlock()
2022-04-24 02:07:57 +08:00
copyM := map[string]Result[T]{}
for k, v := range b.result {
2022-04-24 02:07:57 +08:00
copyM[k] = v
}
2022-04-24 02:07:57 +08:00
return copyM
}
2022-04-24 02:07:57 +08:00
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
2021-07-23 00:30:23 +08:00
ctx, cancel := context.WithCancel(ctx)
2022-04-24 02:07:57 +08:00
b := &Batch[T]{
result: map[string]Result[T]{},
}
for _, o := range opts {
o(b)
}
b.cancel = cancel
return b, ctx
}