all: refactor depthmaker with connectivity

This commit is contained in:
c9s 2024-09-27 13:24:03 +08:00
parent ccb617f30f
commit c07661af57
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
3 changed files with 108 additions and 47 deletions

View File

@ -267,9 +267,11 @@ type Strategy struct {
lastSourcePrice fixedpoint.MutexValue
stopC, authedC chan struct{}
stopC chan struct{}
logger logrus.FieldLogger
connectivityGroup *types.ConnectivityGroup
}
func (s *Strategy) ID() string {
@ -559,9 +561,14 @@ func (s *Strategy) CrossRun(
s.stopC = make(chan struct{})
s.authedC = make(chan struct{}, 5)
bindAuthSignal(ctx, s.makerSession.UserDataStream, s.authedC)
bindAuthSignal(ctx, s.hedgeSession.UserDataStream, s.authedC)
makerConnectivity := types.NewConnectivity()
makerConnectivity.Bind(s.makerSession.UserDataStream)
hedgerConnectivity := types.NewConnectivity()
hedgerConnectivity.Bind(s.hedgeSession.UserDataStream)
connGroup := types.NewConnectivityGroup(makerConnectivity, hedgerConnectivity)
s.connectivityGroup = connGroup
if s.RecoverTrade {
go s.runTradeRecover(ctx)
@ -569,16 +576,11 @@ func (s *Strategy) CrossRun(
go func() {
log.Infof("waiting for user data stream to get authenticated")
select {
case <-ctx.Done():
return
case <-s.authedC:
}
select {
case <-ctx.Done():
return
case <-s.authedC:
case <-connGroup.AllAuthedC(ctx, time.Minute):
}
log.Infof("user data stream authenticated, start placing orders...")
@ -1165,14 +1167,3 @@ func min(a, b int) int {
return b
}
func bindAuthSignal(ctx context.Context, stream types.Stream, c chan<- struct{}) {
stream.OnAuth(func() {
select {
case <-ctx.Done():
return
case c <- struct{}{}:
default:
}
})
}

View File

@ -24,16 +24,27 @@ func (g *ConnectivityGroup) Add(con *Connectivity) {
g.connections = append(g.connections, con)
}
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}) {
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) {
g.mu.Lock()
defer g.mu.Unlock()
authedConns := make([]bool, len(g.connections))
allTimeout := time.After(allTimeoutDuration)
for {
for idx, con := range g.connections {
timeout := time.After(3 * time.Second)
// if the connection is not authed, mark it as false
if !con.authed {
// authedConns[idx] = false
}
timeout := time.After(3 * time.Second)
select {
case <-ctx.Done():
return
case <-allTimeout:
return
case <-timeout:
continue
@ -49,9 +60,12 @@ func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{})
}
}
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context) <-chan struct{} {
// AllAuthedC returns a channel that will be closed when all connections are authenticated
// the returned channel will be closed when all connections are authenticated
// and the channel can only be used once (because we can't close a channel twice)
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
c := make(chan struct{})
go g.waitAllAuthed(ctx, c)
go g.waitAllAuthed(ctx, c, timeout)
return c
}
@ -105,15 +119,6 @@ func (c *Connectivity) handleDisconnect() {
defer c.mu.Unlock()
c.connected = false
if c.connectedC != nil {
close(c.connectedC)
}
if c.authedC != nil {
close(c.authedC)
}
c.authedC = make(chan struct{})
c.connectedC = make(chan struct{})
close(c.disconnectedC)

View File

@ -4,38 +4,103 @@ import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestConnectivityGroup(t *testing.T) {
func TestConnectivity(t *testing.T) {
t.Run("general", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
})
t.Run("reconnect", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
})
t.Run("no-auth reconnect", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleDisconnect()
conn1.handleConnect()
conn1.handleDisconnect()
})
}
func TestConnectivityGroupAuthC(t *testing.T) {
timeout := 100 * time.Millisecond
ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
allAuthedC := group.AllAuthedC(ctx)
allAuthedC := group.AllAuthedC(ctx, time.Second)
time.Sleep(100 * time.Millisecond)
time.Sleep(timeout)
conn1.handleConnect()
waitSigChan(t, conn1.ConnectedC(), time.Second)
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
conn1.handleAuth()
waitSigChan(t, conn1.AuthedC(), time.Second)
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
time.Sleep(100 * time.Millisecond)
time.Sleep(timeout)
conn2.handleConnect()
waitSigChan(t, conn2.ConnectedC(), time.Second)
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
conn2.handleAuth()
waitSigChan(t, conn2.AuthedC(), time.Second)
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
waitSigChan(t, allAuthedC, time.Second)
assert.True(t, waitSigChan(allAuthedC, timeout))
}
func waitSigChan(t *testing.T, c <-chan struct{}, timeoutDuration time.Duration) {
func TestConnectivityGroupReconnect(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2
ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
time.Sleep(delay)
conn1.handleConnect()
conn1.handleAuth()
conn1authC := conn1.authedC
time.Sleep(delay)
conn2.handleConnect()
conn2.handleAuth()
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated")
// this should re-allocate authedC
conn1.handleDisconnect()
assert.NotEqual(t, conn1authC, conn1.authedC)
assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed")
time.Sleep(delay)
conn1.handleConnect()
conn1.handleAuth()
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again")
}
func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
select {
case <-time.After(timeoutDuration):
t.Log("timeout")
t.Fail()
return false
case <-c:
t.Log("signal received")
return true
}
}