mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-25 16:25:16 +00:00
all: refactor depthmaker with connectivity
This commit is contained in:
parent
ccb617f30f
commit
c07661af57
|
@ -267,9 +267,11 @@ type Strategy struct {
|
|||
|
||||
lastSourcePrice fixedpoint.MutexValue
|
||||
|
||||
stopC, authedC chan struct{}
|
||||
stopC chan struct{}
|
||||
|
||||
logger logrus.FieldLogger
|
||||
|
||||
connectivityGroup *types.ConnectivityGroup
|
||||
}
|
||||
|
||||
func (s *Strategy) ID() string {
|
||||
|
@ -559,9 +561,14 @@ func (s *Strategy) CrossRun(
|
|||
|
||||
s.stopC = make(chan struct{})
|
||||
|
||||
s.authedC = make(chan struct{}, 5)
|
||||
bindAuthSignal(ctx, s.makerSession.UserDataStream, s.authedC)
|
||||
bindAuthSignal(ctx, s.hedgeSession.UserDataStream, s.authedC)
|
||||
makerConnectivity := types.NewConnectivity()
|
||||
makerConnectivity.Bind(s.makerSession.UserDataStream)
|
||||
|
||||
hedgerConnectivity := types.NewConnectivity()
|
||||
hedgerConnectivity.Bind(s.hedgeSession.UserDataStream)
|
||||
|
||||
connGroup := types.NewConnectivityGroup(makerConnectivity, hedgerConnectivity)
|
||||
s.connectivityGroup = connGroup
|
||||
|
||||
if s.RecoverTrade {
|
||||
go s.runTradeRecover(ctx)
|
||||
|
@ -569,16 +576,11 @@ func (s *Strategy) CrossRun(
|
|||
|
||||
go func() {
|
||||
log.Infof("waiting for user data stream to get authenticated")
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.authedC:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.authedC:
|
||||
case <-connGroup.AllAuthedC(ctx, time.Minute):
|
||||
}
|
||||
|
||||
log.Infof("user data stream authenticated, start placing orders...")
|
||||
|
@ -1165,14 +1167,3 @@ func min(a, b int) int {
|
|||
|
||||
return b
|
||||
}
|
||||
|
||||
func bindAuthSignal(ctx context.Context, stream types.Stream, c chan<- struct{}) {
|
||||
stream.OnAuth(func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case c <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -24,16 +24,27 @@ func (g *ConnectivityGroup) Add(con *Connectivity) {
|
|||
g.connections = append(g.connections, con)
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}) {
|
||||
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
authedConns := make([]bool, len(g.connections))
|
||||
allTimeout := time.After(allTimeoutDuration)
|
||||
for {
|
||||
for idx, con := range g.connections {
|
||||
timeout := time.After(3 * time.Second)
|
||||
// if the connection is not authed, mark it as false
|
||||
if !con.authed {
|
||||
// authedConns[idx] = false
|
||||
}
|
||||
|
||||
timeout := time.After(3 * time.Second)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-allTimeout:
|
||||
return
|
||||
|
||||
case <-timeout:
|
||||
continue
|
||||
|
||||
|
@ -49,9 +60,12 @@ func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{})
|
|||
}
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context) <-chan struct{} {
|
||||
// AllAuthedC returns a channel that will be closed when all connections are authenticated
|
||||
// the returned channel will be closed when all connections are authenticated
|
||||
// and the channel can only be used once (because we can't close a channel twice)
|
||||
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
|
||||
c := make(chan struct{})
|
||||
go g.waitAllAuthed(ctx, c)
|
||||
go g.waitAllAuthed(ctx, c, timeout)
|
||||
return c
|
||||
}
|
||||
|
||||
|
@ -105,15 +119,6 @@ func (c *Connectivity) handleDisconnect() {
|
|||
defer c.mu.Unlock()
|
||||
|
||||
c.connected = false
|
||||
|
||||
if c.connectedC != nil {
|
||||
close(c.connectedC)
|
||||
}
|
||||
|
||||
if c.authedC != nil {
|
||||
close(c.authedC)
|
||||
}
|
||||
|
||||
c.authedC = make(chan struct{})
|
||||
c.connectedC = make(chan struct{})
|
||||
close(c.disconnectedC)
|
||||
|
|
|
@ -4,38 +4,103 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnectivityGroup(t *testing.T) {
|
||||
func TestConnectivity(t *testing.T) {
|
||||
t.Run("general", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1.handleDisconnect()
|
||||
})
|
||||
|
||||
t.Run("reconnect", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1.handleDisconnect()
|
||||
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1.handleDisconnect()
|
||||
})
|
||||
|
||||
t.Run("no-auth reconnect", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.handleConnect()
|
||||
conn1.handleDisconnect()
|
||||
|
||||
conn1.handleConnect()
|
||||
conn1.handleDisconnect()
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectivityGroupAuthC(t *testing.T) {
|
||||
timeout := 100 * time.Millisecond
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := NewConnectivity()
|
||||
conn2 := NewConnectivity()
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
allAuthedC := group.AllAuthedC(ctx)
|
||||
allAuthedC := group.AllAuthedC(ctx, time.Second)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
time.Sleep(timeout)
|
||||
conn1.handleConnect()
|
||||
waitSigChan(t, conn1.ConnectedC(), time.Second)
|
||||
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
|
||||
conn1.handleAuth()
|
||||
waitSigChan(t, conn1.AuthedC(), time.Second)
|
||||
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
time.Sleep(timeout)
|
||||
conn2.handleConnect()
|
||||
waitSigChan(t, conn2.ConnectedC(), time.Second)
|
||||
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
|
||||
|
||||
conn2.handleAuth()
|
||||
waitSigChan(t, conn2.AuthedC(), time.Second)
|
||||
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
|
||||
|
||||
waitSigChan(t, allAuthedC, time.Second)
|
||||
assert.True(t, waitSigChan(allAuthedC, timeout))
|
||||
}
|
||||
|
||||
func waitSigChan(t *testing.T, c <-chan struct{}, timeoutDuration time.Duration) {
|
||||
func TestConnectivityGroupReconnect(t *testing.T) {
|
||||
timeout := 100 * time.Millisecond
|
||||
delay := timeout * 2
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := NewConnectivity()
|
||||
conn2 := NewConnectivity()
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
|
||||
time.Sleep(delay)
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1authC := conn1.authedC
|
||||
|
||||
time.Sleep(delay)
|
||||
conn2.handleConnect()
|
||||
conn2.handleAuth()
|
||||
|
||||
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated")
|
||||
|
||||
// this should re-allocate authedC
|
||||
conn1.handleDisconnect()
|
||||
assert.NotEqual(t, conn1authC, conn1.authedC)
|
||||
|
||||
assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed")
|
||||
|
||||
time.Sleep(delay)
|
||||
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again")
|
||||
}
|
||||
|
||||
func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
|
||||
select {
|
||||
case <-time.After(timeoutDuration):
|
||||
t.Log("timeout")
|
||||
t.Fail()
|
||||
return false
|
||||
|
||||
case <-c:
|
||||
t.Log("signal received")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user