From df125c0efbe841de0df11de9bf2d70db6ba62dfc Mon Sep 17 00:00:00 2001 From: c9s Date: Wed, 19 Jun 2024 17:35:38 +0800 Subject: [PATCH 1/2] batch: improve trade batch query --- pkg/exchange/batch/trade.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pkg/exchange/batch/trade.go b/pkg/exchange/batch/trade.go index 4fce26b65..5371adc7b 100644 --- a/pkg/exchange/batch/trade.go +++ b/pkg/exchange/batch/trade.go @@ -17,20 +17,22 @@ type TradeBatchQuery struct { types.ExchangeTradeHistoryService } -func (e TradeBatchQuery) Query(ctx context.Context, symbol string, options *types.TradeQueryOptions, opts ...Option) (c chan types.Trade, errC chan error) { +func (e TradeBatchQuery) Query( + ctx context.Context, symbol string, options *types.TradeQueryOptions, opts ...Option, +) (c chan types.Trade, errC chan error) { if options.EndTime == nil { now := time.Now() options.EndTime = &now } - startTime := *options.StartTime - endTime := *options.EndTime query := &AsyncTimeRangedBatchQuery{ Type: types.Trade{}, Q: func(startTime, endTime time.Time) (interface{}, error) { - options.StartTime = &startTime - options.EndTime = &endTime - return e.ExchangeTradeHistoryService.QueryTrades(ctx, symbol, options) + return e.ExchangeTradeHistoryService.QueryTrades(ctx, symbol, &types.TradeQueryOptions{ + StartTime: &startTime, + EndTime: &endTime, + LastTradeID: options.LastTradeID, + }) }, T: func(obj interface{}) time.Time { return time.Time(obj.(types.Trade).Time) @@ -40,9 +42,10 @@ func (e TradeBatchQuery) Query(ctx context.Context, symbol string, options *type if trade.ID > options.LastTradeID { options.LastTradeID = trade.ID } + return trade.Key().String() }, - JumpIfEmpty: 24 * time.Hour, + JumpIfEmpty: 23 * time.Hour, // exchange may not have trades in the last 24 hours } for _, opt := range opts { @@ -50,6 +53,6 @@ func (e TradeBatchQuery) Query(ctx context.Context, symbol string, options *type } c = make(chan types.Trade, 100) - errC = query.Query(ctx, c, startTime, endTime) + errC = query.Query(ctx, c, *options.StartTime, *options.EndTime) return c, errC } From 1dc1afc993974b0dda8fa9035871c80463a98926 Mon Sep 17 00:00:00 2001 From: c9s Date: Thu, 20 Jun 2024 16:54:05 +0800 Subject: [PATCH 2/2] batch: add TradeQueryOptionsMatcher for testing trade query options --- pkg/exchange/batch/trade.go | 3 +- pkg/exchange/batch/trade_test.go | 66 ++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/pkg/exchange/batch/trade.go b/pkg/exchange/batch/trade.go index 5371adc7b..7572cd8d2 100644 --- a/pkg/exchange/batch/trade.go +++ b/pkg/exchange/batch/trade.go @@ -31,6 +31,7 @@ func (e TradeBatchQuery) Query( return e.ExchangeTradeHistoryService.QueryTrades(ctx, symbol, &types.TradeQueryOptions{ StartTime: &startTime, EndTime: &endTime, + Limit: options.Limit, LastTradeID: options.LastTradeID, }) }, @@ -45,7 +46,7 @@ func (e TradeBatchQuery) Query( return trade.Key().String() }, - JumpIfEmpty: 23 * time.Hour, // exchange may not have trades in the last 24 hours + JumpIfEmpty: 24*time.Hour - 5*time.Minute, // exchange may not have trades in the last 24 hours } for _, opt := range opts { diff --git a/pkg/exchange/batch/trade_test.go b/pkg/exchange/batch/trade_test.go index 26a175bfa..87725e9e4 100644 --- a/pkg/exchange/batch/trade_test.go +++ b/pkg/exchange/batch/trade_test.go @@ -2,6 +2,7 @@ package batch import ( "context" + "fmt" "sync" "testing" "time" @@ -14,6 +15,49 @@ import ( "github.com/c9s/bbgo/pkg/types/mocks" ) +func matchTradeQueryOptions(expected *types.TradeQueryOptions) *TradeQueryOptionsMatcher { + return &TradeQueryOptionsMatcher{ + expected: expected, + } +} + +type TradeQueryOptionsMatcher struct { + expected *types.TradeQueryOptions +} + +func (m TradeQueryOptionsMatcher) Matches(arg interface{}) bool { + given, ok := arg.(*types.TradeQueryOptions) + if !ok { + return false + } + + if given.StartTime != nil && m.expected.StartTime != nil { + if !given.StartTime.Equal(*m.expected.StartTime) { + return false + } + } + + if given.EndTime != nil && m.expected.EndTime != nil { + if !given.EndTime.Equal(*m.expected.EndTime) { + return false + } + } + + if given.Limit != m.expected.Limit { + return false + } + + if given.LastTradeID != m.expected.LastTradeID { + return false + } + + return true +} + +func (m TradeQueryOptionsMatcher) String() string { + return fmt.Sprintf("%+v", m.expected) +} + func Test_TradeBatchQuery(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -55,7 +99,7 @@ func Test_TradeBatchQuery(t *testing.T) { mockExchange = mocks.NewMockExchangeTradeHistoryService(ctrl) ) - mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, matchTradeQueryOptions(expOptions)).DoAndReturn( func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { assert.Equal(t, startTime, *options.StartTime) assert.Equal(t, endTime, *options.EndTime) @@ -63,7 +107,13 @@ func Test_TradeBatchQuery(t *testing.T) { assert.Equal(t, expOptions.Limit, options.Limit) return queryTrades1, nil }).Times(1) - mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, matchTradeQueryOptions(&types.TradeQueryOptions{ + StartTime: timePtr(queryTrades1[0].Time.Time()), + EndTime: expOptions.EndTime, + LastTradeID: 1, + Limit: 50, + })).DoAndReturn( func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { assert.Equal(t, queryTrades1[0].Time.Time(), *options.StartTime) assert.Equal(t, endTime, *options.EndTime) @@ -71,7 +121,13 @@ func Test_TradeBatchQuery(t *testing.T) { assert.Equal(t, expOptions.Limit, options.Limit) return queryTrades2, nil }).Times(1) - mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, matchTradeQueryOptions(&types.TradeQueryOptions{ + StartTime: timePtr(queryTrades2[1].Time.Time()), + EndTime: expOptions.EndTime, + LastTradeID: queryTrades2[1].ID, + Limit: 50, + })).DoAndReturn( func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { assert.Equal(t, queryTrades2[1].Time.Time(), *options.StartTime) assert.Equal(t, endTime, *options.EndTime) @@ -138,3 +194,7 @@ func Test_TradeBatchQuery(t *testing.T) { assert.Equal(t, rcvCount, 0) }) } + +func timePtr(t time.Time) *time.Time { + return &t +}