Merge pull request #1754 from c9s/c9s/xdepthmaker/refactor
Some checks are pending
Go / build (1.21, 6.2) (push) Waiting to run
golang-lint / lint (push) Waiting to run

This commit is contained in:
c9s 2024-09-27 15:03:57 +08:00 committed by GitHub
commit c8823e977f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 408 additions and 111 deletions

View File

@ -267,9 +267,11 @@ type Strategy struct {
lastSourcePrice fixedpoint.MutexValue
stopC, authedC chan struct{}
stopC chan struct{}
logger logrus.FieldLogger
connectivityGroup *types.ConnectivityGroup
}
func (s *Strategy) ID() string {
@ -388,6 +390,116 @@ func (s *Strategy) Defaults() error {
return nil
}
func (s *Strategy) quoteWorker(ctx context.Context) {
updateTicker := time.NewTicker(util.MillisecondsJitter(s.UpdateInterval.Duration(), 200))
defer updateTicker.Stop()
fullReplenishTicker := time.NewTicker(util.MillisecondsJitter(s.FullReplenishInterval.Duration(), 200))
defer fullReplenishTicker.Stop()
// clean up the previous open orders
if err := s.cleanUpOpenOrders(ctx, s.makerSession); err != nil {
log.WithError(err).Errorf("error cleaning up open orders")
return
}
s.updateQuote(ctx, 0)
lastOrderReplenishTime := time.Now()
for {
select {
case <-ctx.Done():
return
case <-s.stopC:
log.Warnf("%s maker goroutine stopped, due to the stop signal", s.Symbol)
return
case <-fullReplenishTicker.C:
s.updateQuote(ctx, 0)
lastOrderReplenishTime = time.Now()
case sig, ok := <-s.sourceBook.C:
// when any book change event happened
if !ok {
return
}
if time.Since(lastOrderReplenishTime) < 10*time.Second {
continue
}
switch sig.Type {
case types.BookSignalSnapshot:
s.updateQuote(ctx, 0)
case types.BookSignalUpdate:
s.updateQuote(ctx, 5)
}
lastOrderReplenishTime = time.Now()
}
}
}
func (s *Strategy) hedgeWorker(ctx context.Context) {
ticker := time.NewTicker(util.MillisecondsJitter(s.HedgeInterval.Duration(), 200))
defer ticker.Stop()
for {
select {
case <-ctx.Done():
s.logger.Warnf("maker goroutine stopped, due to context canceled")
return
case <-s.stopC:
s.logger.Warnf("maker goroutine stopped, due to the stop signal")
return
case <-ticker.C:
// For positive position and positive covered position:
// uncover position = +5 - +3 (covered position) = 2
//
// For positive position and negative covered position:
// uncover position = +5 - (-3) (covered position) = 8
//
// meaning we bought 5 on MAX and sent buy order with 3 on binance
//
// For negative position:
// uncover position = -5 - -3 (covered position) = -2
s.HedgeOrderExecutor.TradeCollector().Process()
s.MakerOrderExecutor.TradeCollector().Process()
position := s.Position.GetBase()
coveredPosition := s.CoveredPosition.Get()
uncoverPosition := position.Sub(coveredPosition)
absPos := uncoverPosition.Abs()
if !s.hedgeMarket.IsDustQuantity(absPos, s.lastSourcePrice.Get()) {
log.Infof("%s base position %v coveredPosition: %v uncoverPosition: %v",
s.Symbol,
position,
coveredPosition,
uncoverPosition,
)
if !s.DisableHedge {
if err := s.Hedge(ctx, uncoverPosition.Neg()); err != nil {
//goland:noinspection GoDirectComparisonOfErrors
switch err {
case ErrZeroQuantity, ErrDustQuantity:
default:
s.logger.WithError(err).Errorf("unable to hedge position")
}
}
}
}
}
}
}
func (s *Strategy) CrossRun(
ctx context.Context, _ bbgo.OrderExecutionRouter,
sessions map[string]*bbgo.ExchangeSession,
@ -450,9 +562,14 @@ func (s *Strategy) CrossRun(
s.stopC = make(chan struct{})
s.authedC = make(chan struct{}, 5)
bindAuthSignal(ctx, s.makerSession.UserDataStream, s.authedC)
bindAuthSignal(ctx, s.hedgeSession.UserDataStream, s.authedC)
makerConnectivity := types.NewConnectivity()
makerConnectivity.Bind(s.makerSession.UserDataStream)
hedgerConnectivity := types.NewConnectivity()
hedgerConnectivity.Bind(s.hedgeSession.UserDataStream)
connGroup := types.NewConnectivityGroup(makerConnectivity, hedgerConnectivity)
s.connectivityGroup = connGroup
if s.RecoverTrade {
go s.runTradeRecover(ctx)
@ -460,110 +577,17 @@ func (s *Strategy) CrossRun(
go func() {
log.Infof("waiting for user data stream to get authenticated")
select {
case <-ctx.Done():
return
case <-s.authedC:
}
select {
case <-ctx.Done():
return
case <-s.authedC:
case <-connGroup.AllAuthedC(ctx, time.Minute):
}
log.Infof("user data stream authenticated, start placing orders...")
posTicker := time.NewTicker(util.MillisecondsJitter(s.HedgeInterval.Duration(), 200))
defer posTicker.Stop()
fullReplenishTicker := time.NewTicker(util.MillisecondsJitter(s.FullReplenishInterval.Duration(), 200))
defer fullReplenishTicker.Stop()
// clean up the previous open orders
if err := s.cleanUpOpenOrders(ctx, s.makerSession); err != nil {
log.WithError(err).Errorf("error cleaning up open orders")
}
s.updateQuote(ctx, 0)
lastOrderReplenishTime := time.Now()
for {
select {
case <-s.stopC:
log.Warnf("%s maker goroutine stopped, due to the stop signal", s.Symbol)
return
case <-ctx.Done():
log.Warnf("%s maker goroutine stopped, due to the cancelled context", s.Symbol)
return
case <-fullReplenishTicker.C:
s.updateQuote(ctx, 0)
lastOrderReplenishTime = time.Now()
case sig, ok := <-s.sourceBook.C:
// when any book change event happened
if !ok {
return
}
if time.Since(lastOrderReplenishTime) < 10*time.Second {
continue
}
switch sig.Type {
case types.BookSignalSnapshot:
s.updateQuote(ctx, 0)
case types.BookSignalUpdate:
s.updateQuote(ctx, 5)
}
lastOrderReplenishTime = time.Now()
case <-posTicker.C:
// For positive position and positive covered position:
// uncover position = +5 - +3 (covered position) = 2
//
// For positive position and negative covered position:
// uncover position = +5 - (-3) (covered position) = 8
//
// meaning we bought 5 on MAX and sent buy order with 3 on binance
//
// For negative position:
// uncover position = -5 - -3 (covered position) = -2
s.HedgeOrderExecutor.TradeCollector().Process()
s.MakerOrderExecutor.TradeCollector().Process()
position := s.Position.GetBase()
coveredPosition := s.CoveredPosition.Get()
uncoverPosition := position.Sub(coveredPosition)
absPos := uncoverPosition.Abs()
if absPos.Compare(s.hedgeMarket.MinQuantity) > 0 {
log.Infof("%s base position %v coveredPosition: %v uncoverPosition: %v",
s.Symbol,
position,
coveredPosition,
uncoverPosition,
)
if !s.DisableHedge {
if err := s.Hedge(ctx, uncoverPosition.Neg()); err != nil {
//goland:noinspection GoDirectComparisonOfErrors
switch err {
case ErrZeroQuantity, ErrDustQuantity:
default:
s.logger.WithError(err).Errorf("unable to hedge position")
}
}
}
}
}
}
go s.hedgeWorker(ctx)
go s.quoteWorker(ctx)
}()
bbgo.OnShutdown(ctx, func(ctx context.Context, wg *sync.WaitGroup) {
@ -1144,14 +1168,3 @@ func min(a, b int) int {
return b
}
func bindAuthSignal(ctx context.Context, stream types.Stream, c chan<- struct{}) {
stream.OnAuth(func() {
select {
case <-ctx.Done():
return
case c <- struct{}{}:
default:
}
})
}

173
pkg/types/connectivity.go Normal file
View File

@ -0,0 +1,173 @@
package types
import (
"context"
"sync"
"time"
)
type ConnectivityGroup struct {
connections []*Connectivity
mu sync.Mutex
}
func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
return &ConnectivityGroup{
connections: cons,
}
}
func (g *ConnectivityGroup) Add(con *Connectivity) {
g.mu.Lock()
defer g.mu.Unlock()
g.connections = append(g.connections, con)
}
func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
g.mu.Lock()
conns := g.connections
g.mu.Unlock()
for _, conn := range conns {
select {
case <-ctx.Done():
return false
case <-conn.connectedC:
continue
case <-conn.disconnectedC:
return true
}
}
return false
}
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) {
g.mu.Lock()
conns := g.connections
g.mu.Unlock()
authedConns := make([]bool, len(conns))
allTimeout := time.After(allTimeoutDuration)
for {
for idx, con := range conns {
// if the connection is not authed, mark it as false
if !con.authed {
// authedConns[idx] = false
}
timeout := time.After(3 * time.Second)
select {
case <-ctx.Done():
return
case <-allTimeout:
return
case <-timeout:
continue
case <-con.AuthedC():
authedConns[idx] = true
}
}
if allTrue(authedConns) {
close(c)
return
}
}
}
// AllAuthedC returns a channel that will be closed when all connections are authenticated
// the returned channel will be closed when all connections are authenticated
// and the channel can only be used once (because we can't close a channel twice)
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
c := make(chan struct{})
go g.waitAllAuthed(ctx, c, timeout)
return c
}
func allTrue(bools []bool) bool {
for _, b := range bools {
if !b {
return false
}
}
return true
}
type Connectivity struct {
authed bool
authedC chan struct{}
connected bool
connectedC chan struct{}
disconnectedC chan struct{}
mu sync.Mutex
}
func NewConnectivity() *Connectivity {
closedC := make(chan struct{})
close(closedC)
return &Connectivity{
authed: false,
authedC: make(chan struct{}),
connected: false,
connectedC: make(chan struct{}),
disconnectedC: closedC,
}
}
func (c *Connectivity) handleConnect() {
c.mu.Lock()
defer c.mu.Unlock()
c.connected = true
close(c.connectedC)
c.disconnectedC = make(chan struct{})
}
func (c *Connectivity) handleDisconnect() {
c.mu.Lock()
defer c.mu.Unlock()
c.connected = false
c.authedC = make(chan struct{})
c.connectedC = make(chan struct{})
close(c.disconnectedC)
}
func (c *Connectivity) handleAuth() {
c.mu.Lock()
defer c.mu.Unlock()
c.authed = true
close(c.authedC)
}
func (c *Connectivity) AuthedC() chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()
return c.authedC
}
func (c *Connectivity) ConnectedC() chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()
return c.connectedC
}
func (c *Connectivity) Bind(stream Stream) {
stream.OnConnect(c.handleConnect)
stream.OnDisconnect(c.handleDisconnect)
stream.OnAuth(c.handleAuth)
}

View File

@ -0,0 +1,111 @@
package types
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestConnectivity(t *testing.T) {
t.Run("general", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
})
t.Run("reconnect", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
conn1.handleConnect()
conn1.handleAuth()
conn1.handleDisconnect()
})
t.Run("no-auth reconnect", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.handleConnect()
conn1.handleDisconnect()
conn1.handleConnect()
conn1.handleDisconnect()
})
}
func TestConnectivityGroupAuthC(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2
ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
allAuthedC := group.AllAuthedC(ctx, time.Second)
time.Sleep(delay)
conn1.handleConnect()
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
conn1.handleAuth()
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
time.Sleep(delay)
conn2.handleConnect()
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
conn2.handleAuth()
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
assert.True(t, waitSigChan(allAuthedC, timeout))
}
func TestConnectivityGroupReconnect(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2
ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
time.Sleep(delay)
conn1.handleConnect()
conn1.handleAuth()
conn1authC := conn1.authedC
time.Sleep(delay)
conn2.handleConnect()
conn2.handleAuth()
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated")
assert.False(t, group.AnyDisconnected(ctx))
// this should re-allocate authedC
conn1.handleDisconnect()
assert.NotEqual(t, conn1authC, conn1.authedC)
assert.True(t, group.AnyDisconnected(ctx))
assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed")
time.Sleep(delay)
conn1.handleConnect()
conn1.handleAuth()
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again")
}
func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
select {
case <-time.After(timeoutDuration):
return false
case <-c:
return true
}
}