From c07661af57562df1eaa4b61a95c73c15728ff2ec Mon Sep 17 00:00:00 2001 From: c9s Date: Fri, 27 Sep 2024 13:24:03 +0800 Subject: [PATCH] all: refactor depthmaker with connectivity --- pkg/strategy/xdepthmaker/strategy.go | 33 ++++------ pkg/types/connectivity.go | 31 ++++++---- pkg/types/connectivity_test.go | 91 ++++++++++++++++++++++++---- 3 files changed, 108 insertions(+), 47 deletions(-) diff --git a/pkg/strategy/xdepthmaker/strategy.go b/pkg/strategy/xdepthmaker/strategy.go index 69e04b8ef..c568f7b3f 100644 --- a/pkg/strategy/xdepthmaker/strategy.go +++ b/pkg/strategy/xdepthmaker/strategy.go @@ -267,9 +267,11 @@ type Strategy struct { lastSourcePrice fixedpoint.MutexValue - stopC, authedC chan struct{} + stopC chan struct{} logger logrus.FieldLogger + + connectivityGroup *types.ConnectivityGroup } func (s *Strategy) ID() string { @@ -559,9 +561,14 @@ func (s *Strategy) CrossRun( s.stopC = make(chan struct{}) - s.authedC = make(chan struct{}, 5) - bindAuthSignal(ctx, s.makerSession.UserDataStream, s.authedC) - bindAuthSignal(ctx, s.hedgeSession.UserDataStream, s.authedC) + makerConnectivity := types.NewConnectivity() + makerConnectivity.Bind(s.makerSession.UserDataStream) + + hedgerConnectivity := types.NewConnectivity() + hedgerConnectivity.Bind(s.hedgeSession.UserDataStream) + + connGroup := types.NewConnectivityGroup(makerConnectivity, hedgerConnectivity) + s.connectivityGroup = connGroup if s.RecoverTrade { go s.runTradeRecover(ctx) @@ -569,16 +576,11 @@ func (s *Strategy) CrossRun( go func() { log.Infof("waiting for user data stream to get authenticated") - select { - case <-ctx.Done(): - return - case <-s.authedC: - } select { case <-ctx.Done(): return - case <-s.authedC: + case <-connGroup.AllAuthedC(ctx, time.Minute): } log.Infof("user data stream authenticated, start placing orders...") @@ -1165,14 +1167,3 @@ func min(a, b int) int { return b } - -func bindAuthSignal(ctx context.Context, stream types.Stream, c chan<- struct{}) { - stream.OnAuth(func() { - select { - case <-ctx.Done(): - return - case c <- struct{}{}: - default: - } - }) -} diff --git a/pkg/types/connectivity.go b/pkg/types/connectivity.go index 52c313dad..1745ba866 100644 --- a/pkg/types/connectivity.go +++ b/pkg/types/connectivity.go @@ -24,16 +24,27 @@ func (g *ConnectivityGroup) Add(con *Connectivity) { g.connections = append(g.connections, con) } -func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}) { +func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) { + g.mu.Lock() + defer g.mu.Unlock() + authedConns := make([]bool, len(g.connections)) + allTimeout := time.After(allTimeoutDuration) for { for idx, con := range g.connections { - timeout := time.After(3 * time.Second) + // if the connection is not authed, mark it as false + if !con.authed { + // authedConns[idx] = false + } + timeout := time.After(3 * time.Second) select { case <-ctx.Done(): return + case <-allTimeout: + return + case <-timeout: continue @@ -49,9 +60,12 @@ func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}) } } -func (g *ConnectivityGroup) AllAuthedC(ctx context.Context) <-chan struct{} { +// AllAuthedC returns a channel that will be closed when all connections are authenticated +// the returned channel will be closed when all connections are authenticated +// and the channel can only be used once (because we can't close a channel twice) +func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} { c := make(chan struct{}) - go g.waitAllAuthed(ctx, c) + go g.waitAllAuthed(ctx, c, timeout) return c } @@ -105,15 +119,6 @@ func (c *Connectivity) handleDisconnect() { defer c.mu.Unlock() c.connected = false - - if c.connectedC != nil { - close(c.connectedC) - } - - if c.authedC != nil { - close(c.authedC) - } - c.authedC = make(chan struct{}) c.connectedC = make(chan struct{}) close(c.disconnectedC) diff --git a/pkg/types/connectivity_test.go b/pkg/types/connectivity_test.go index 8734e416b..5d36ede37 100644 --- a/pkg/types/connectivity_test.go +++ b/pkg/types/connectivity_test.go @@ -4,38 +4,103 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/assert" ) -func TestConnectivityGroup(t *testing.T) { +func TestConnectivity(t *testing.T) { + t.Run("general", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.handleConnect() + conn1.handleAuth() + conn1.handleDisconnect() + }) + + t.Run("reconnect", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.handleConnect() + conn1.handleAuth() + conn1.handleDisconnect() + + conn1.handleConnect() + conn1.handleAuth() + conn1.handleDisconnect() + }) + + t.Run("no-auth reconnect", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.handleConnect() + conn1.handleDisconnect() + + conn1.handleConnect() + conn1.handleDisconnect() + }) +} + +func TestConnectivityGroupAuthC(t *testing.T) { + timeout := 100 * time.Millisecond + ctx := context.Background() conn1 := NewConnectivity() conn2 := NewConnectivity() group := NewConnectivityGroup(conn1, conn2) - allAuthedC := group.AllAuthedC(ctx) + allAuthedC := group.AllAuthedC(ctx, time.Second) - time.Sleep(100 * time.Millisecond) + time.Sleep(timeout) conn1.handleConnect() - waitSigChan(t, conn1.ConnectedC(), time.Second) + assert.True(t, waitSigChan(conn1.ConnectedC(), timeout)) conn1.handleAuth() - waitSigChan(t, conn1.AuthedC(), time.Second) + assert.True(t, waitSigChan(conn1.AuthedC(), timeout)) - time.Sleep(100 * time.Millisecond) + time.Sleep(timeout) conn2.handleConnect() - waitSigChan(t, conn2.ConnectedC(), time.Second) + assert.True(t, waitSigChan(conn2.ConnectedC(), timeout)) conn2.handleAuth() - waitSigChan(t, conn2.AuthedC(), time.Second) + assert.True(t, waitSigChan(conn2.AuthedC(), timeout)) - waitSigChan(t, allAuthedC, time.Second) + assert.True(t, waitSigChan(allAuthedC, timeout)) } -func waitSigChan(t *testing.T, c <-chan struct{}, timeoutDuration time.Duration) { +func TestConnectivityGroupReconnect(t *testing.T) { + timeout := 100 * time.Millisecond + delay := timeout * 2 + + ctx := context.Background() + conn1 := NewConnectivity() + conn2 := NewConnectivity() + group := NewConnectivityGroup(conn1, conn2) + + time.Sleep(delay) + conn1.handleConnect() + conn1.handleAuth() + conn1authC := conn1.authedC + + time.Sleep(delay) + conn2.handleConnect() + conn2.handleAuth() + + assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated") + + // this should re-allocate authedC + conn1.handleDisconnect() + assert.NotEqual(t, conn1authC, conn1.authedC) + + assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed") + + time.Sleep(delay) + + conn1.handleConnect() + conn1.handleAuth() + assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again") +} + +func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool { select { case <-time.After(timeoutDuration): - t.Log("timeout") - t.Fail() + return false case <-c: - t.Log("signal received") + return true } }