From 1dc1afc993974b0dda8fa9035871c80463a98926 Mon Sep 17 00:00:00 2001 From: c9s Date: Thu, 20 Jun 2024 16:54:05 +0800 Subject: [PATCH] batch: add TradeQueryOptionsMatcher for testing trade query options --- pkg/exchange/batch/trade.go | 3 +- pkg/exchange/batch/trade_test.go | 66 ++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/pkg/exchange/batch/trade.go b/pkg/exchange/batch/trade.go index 5371adc7b..7572cd8d2 100644 --- a/pkg/exchange/batch/trade.go +++ b/pkg/exchange/batch/trade.go @@ -31,6 +31,7 @@ func (e TradeBatchQuery) Query( return e.ExchangeTradeHistoryService.QueryTrades(ctx, symbol, &types.TradeQueryOptions{ StartTime: &startTime, EndTime: &endTime, + Limit: options.Limit, LastTradeID: options.LastTradeID, }) }, @@ -45,7 +46,7 @@ func (e TradeBatchQuery) Query( return trade.Key().String() }, - JumpIfEmpty: 23 * time.Hour, // exchange may not have trades in the last 24 hours + JumpIfEmpty: 24*time.Hour - 5*time.Minute, // exchange may not have trades in the last 24 hours } for _, opt := range opts { diff --git a/pkg/exchange/batch/trade_test.go b/pkg/exchange/batch/trade_test.go index 26a175bfa..87725e9e4 100644 --- a/pkg/exchange/batch/trade_test.go +++ b/pkg/exchange/batch/trade_test.go @@ -2,6 +2,7 @@ package batch import ( "context" + "fmt" "sync" "testing" "time" @@ -14,6 +15,49 @@ import ( "github.com/c9s/bbgo/pkg/types/mocks" ) +func matchTradeQueryOptions(expected *types.TradeQueryOptions) *TradeQueryOptionsMatcher { + return &TradeQueryOptionsMatcher{ + expected: expected, + } +} + +type TradeQueryOptionsMatcher struct { + expected *types.TradeQueryOptions +} + +func (m TradeQueryOptionsMatcher) Matches(arg interface{}) bool { + given, ok := arg.(*types.TradeQueryOptions) + if !ok { + return false + } + + if given.StartTime != nil && m.expected.StartTime != nil { + if !given.StartTime.Equal(*m.expected.StartTime) { + return false + } + } + + if given.EndTime != nil && m.expected.EndTime != nil { + if !given.EndTime.Equal(*m.expected.EndTime) { + return false + } + } + + if given.Limit != m.expected.Limit { + return false + } + + if given.LastTradeID != m.expected.LastTradeID { + return false + } + + return true +} + +func (m TradeQueryOptionsMatcher) String() string { + return fmt.Sprintf("%+v", m.expected) +} + func Test_TradeBatchQuery(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -55,7 +99,7 @@ func Test_TradeBatchQuery(t *testing.T) { mockExchange = mocks.NewMockExchangeTradeHistoryService(ctrl) ) - mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, matchTradeQueryOptions(expOptions)).DoAndReturn( func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { assert.Equal(t, startTime, *options.StartTime) assert.Equal(t, endTime, *options.EndTime) @@ -63,7 +107,13 @@ func Test_TradeBatchQuery(t *testing.T) { assert.Equal(t, expOptions.Limit, options.Limit) return queryTrades1, nil }).Times(1) - mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, matchTradeQueryOptions(&types.TradeQueryOptions{ + StartTime: timePtr(queryTrades1[0].Time.Time()), + EndTime: expOptions.EndTime, + LastTradeID: 1, + Limit: 50, + })).DoAndReturn( func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { assert.Equal(t, queryTrades1[0].Time.Time(), *options.StartTime) assert.Equal(t, endTime, *options.EndTime) @@ -71,7 +121,13 @@ func Test_TradeBatchQuery(t *testing.T) { assert.Equal(t, expOptions.Limit, options.Limit) return queryTrades2, nil }).Times(1) - mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, matchTradeQueryOptions(&types.TradeQueryOptions{ + StartTime: timePtr(queryTrades2[1].Time.Time()), + EndTime: expOptions.EndTime, + LastTradeID: queryTrades2[1].ID, + Limit: 50, + })).DoAndReturn( func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { assert.Equal(t, queryTrades2[1].Time.Time(), *options.StartTime) assert.Equal(t, endTime, *options.EndTime) @@ -138,3 +194,7 @@ func Test_TradeBatchQuery(t *testing.T) { assert.Equal(t, rcvCount, 0) }) } + +func timePtr(t time.Time) *time.Time { + return &t +}