mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-24 07:45:15 +00:00
107 lines
2.4 KiB
Go
107 lines
2.4 KiB
Go
package util
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// addAndCheck increments the integer pointed to by a and then reports whether
// the new value equals target. It returns nil on a match and a descriptive
// error otherwise. The increment happens in both cases, so each call (for
// example, each Retry attempt) advances the counter by exactly one.
func addAndCheck(a *int, target int) error {
	*a++
	if *a != target {
		// Error strings must not end with a newline or punctuation (ST1005).
		return fmt.Errorf("a is not %v. It is %v", target, *a)
	}
	return nil
}
|
|
|
|
func TestRetry(t *testing.T) {
|
|
type test struct {
|
|
input int
|
|
targetNum int
|
|
ans int
|
|
ansErr error
|
|
}
|
|
tests := []test{
|
|
{input: 0, targetNum: 3, ans: 3, ansErr: nil},
|
|
{input: 0, targetNum: 10, ans: 3, ansErr: errors.New("failed in retry")},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
errHandled := false
|
|
|
|
err := Retry(context.Background(), 3, 1*time.Second, func() error {
|
|
return addAndCheck(&tc.input, tc.targetNum)
|
|
}, func(e error) { errHandled = true })
|
|
|
|
assert.Equal(t, true, errHandled)
|
|
if tc.ansErr == nil {
|
|
assert.NoError(t, err)
|
|
} else {
|
|
assert.Contains(t, err.Error(), tc.ansErr.Error())
|
|
}
|
|
assert.Equal(t, tc.ans, tc.input)
|
|
}
|
|
}
|
|
|
|
func TestRetryWithPredicator(t *testing.T) {
|
|
type test struct {
|
|
count int
|
|
f func() error
|
|
errHandler func(error)
|
|
predicator RetryPredicator
|
|
ansCount int
|
|
ansErr error
|
|
}
|
|
knownErr := errors.New("Duplicate entry '1-389837488-1' for key 'UNI_Trade'")
|
|
unknownErr := errors.New("Some Error")
|
|
tests := []test{
|
|
{
|
|
predicator: func(err error) bool {
|
|
return !strings.Contains(err.Error(), "Duplicate entry")
|
|
},
|
|
f: func() error { return knownErr },
|
|
ansCount: 1,
|
|
ansErr: knownErr,
|
|
},
|
|
{
|
|
predicator: func(err error) bool {
|
|
return !strings.Contains(err.Error(), "Duplicate entry")
|
|
},
|
|
f: func() error { return unknownErr },
|
|
ansCount: 3,
|
|
ansErr: unknownErr,
|
|
},
|
|
}
|
|
attempts := 3
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
for _, tc := range tests {
|
|
err := Retry(ctx, attempts, 100*time.Millisecond, func() error {
|
|
tc.count++
|
|
return tc.f()
|
|
}, tc.errHandler, tc.predicator)
|
|
|
|
assert.Equal(t, tc.ansCount, tc.count)
|
|
assert.EqualError(t, errors.Cause(err), tc.ansErr.Error(), "should be equal")
|
|
}
|
|
}
|
|
|
|
func TestRetryCtxCancel(t *testing.T) {
|
|
result := int(0)
|
|
target := int(3)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
err := Retry(ctx, 5, 1*time.Second, func() error { return addAndCheck(&result, target) }, func(error) {})
|
|
assert.Error(t, err)
|
|
fmt.Println("Error:", err.Error())
|
|
assert.Equal(t, int(0), result)
|
|
}
|