From f277b191d23724a8edb312c7b1ddf2c58470c1db Mon Sep 17 00:00:00 2001 From: c9s Date: Thu, 8 Aug 2024 17:00:45 +0800 Subject: [PATCH] core: add ConverterManager --- pkg/core/converter.go | 43 +++++++++++++++++++++ pkg/core/tradecollector.go | 76 ++++++++++++++++++++++---------------- 2 files changed, 88 insertions(+), 31 deletions(-) create mode 100644 pkg/core/converter.go diff --git a/pkg/core/converter.go b/pkg/core/converter.go new file mode 100644 index 000000000..2a2f24d62 --- /dev/null +++ b/pkg/core/converter.go @@ -0,0 +1,43 @@ +package core + +import "github.com/c9s/bbgo/pkg/types" + +type Converter interface { + OrderConverter + TradeConverter +} + +// OrderConverter converts the order to another order +type OrderConverter interface { + ConvertOrder(order types.Order) (types.Order, error) +} + +// TradeConverter converts the trade to another trade +type TradeConverter interface { + ConvertTrade(trade types.Trade) (types.Trade, error) +} + +// SymbolConverter converts the symbol to another symbol +type SymbolConverter struct { + fromSymbol, toSymbol string +} + +func NewSymbolConverter(fromSymbol, toSymbol string) *SymbolConverter { + return &SymbolConverter{fromSymbol: fromSymbol, toSymbol: toSymbol} +} + +func (c *SymbolConverter) ConvertOrder(order types.Order) (types.Order, error) { + if order.Symbol == c.fromSymbol { + order.Symbol = c.toSymbol + } + + return order, nil +} + +func (c *SymbolConverter) ConvertTrade(trade types.Trade) (types.Trade, error) { + if trade.Symbol == c.fromSymbol { + trade.Symbol = c.toSymbol + } + + return trade, nil +} diff --git a/pkg/core/tradecollector.go b/pkg/core/tradecollector.go index ab2b54017..e9245b49b 100644 --- a/pkg/core/tradecollector.go +++ b/pkg/core/tradecollector.go @@ -12,12 +12,48 @@ import ( "github.com/c9s/bbgo/pkg/types" ) -type OrderConverter interface { - ConvertOrder(order types.Order) (types.Order, error) +type ConverterManager struct { + converters []Converter } -type TradeConverter interface { - ConvertTrade(trade types.Trade) (types.Trade, error) +func (c *ConverterManager) AddConverter(converter Converter) { + c.converters = append(c.converters, converter) +} + +func (c *ConverterManager) ConvertOrder(order types.Order) types.Order { + if len(c.converters) == 0 { + return order + } + + for _, converter := range c.converters { + convOrder, err := converter.ConvertOrder(order) + if err != nil { + logrus.WithError(err).Errorf("converter %+v error, order: %s", converter, order.String()) + continue + } + + order = convOrder + } + + return order +} + +func (c *ConverterManager) ConvertTrade(trade types.Trade) types.Trade { + if len(c.converters) == 0 { + return trade + } + + for _, converter := range c.converters { + convTrade, err := converter.ConvertTrade(trade) + if err != nil { + logrus.WithError(err).Errorf("converter %+v error, trade: %s", converter, trade.String()) + continue + } + + trade = convTrade + } + + return trade } //go:generate callbackgen -type TradeCollector @@ -33,14 +69,14 @@ type TradeCollector struct { mu sync.Mutex - tradeConverters []TradeConverter - recoverCallbacks []func(trade types.Trade) tradeCallbacks []func(trade types.Trade, profit, netProfit fixedpoint.Value) positionUpdateCallbacks []func(position *types.Position) profitCallbacks []func(trade types.Trade, profit *types.Profit) + + ConverterManager } func NewTradeCollector(symbol string, position *types.Position, orderStore *OrderStore) *TradeCollector { @@ -59,28 +95,6 @@ func NewTradeCollector(symbol string, position *types.Position, orderStore *Orde } } -func (c *TradeCollector) AddTradeConverter(converter TradeConverter) { - c.tradeConverters = append(c.tradeConverters, converter) -} - -func (c *TradeCollector) convertTrade(trade types.Trade) types.Trade { - if len(c.tradeConverters) == 0 { - return trade - } - - for _, converter := range c.tradeConverters { - convTrade, err := converter.ConvertTrade(trade) - if err != nil { - logrus.WithError(err).Errorf("trade %+v converter error, trade: %s", converter, trade.String()) - continue - } - - trade = convTrade - } - - return trade -} - // OrderStore returns the order store used by the trade collector func (c *TradeCollector) OrderStore() *OrderStore { return c.orderStore @@ -148,7 +162,7 @@ func (c *TradeCollector) Recover( } func (c *TradeCollector) RecoverTrade(td types.Trade) bool { - td = c.convertTrade(td) + td = c.ConvertTrade(td) logrus.Debugf("checking trade: %s", td.String()) if c.processTrade(td) { @@ -264,7 +278,7 @@ func (c *TradeCollector) processTrade(trade types.Trade) bool { // return true when the given trade is added // return false when the given trade is not added func (c *TradeCollector) ProcessTrade(trade types.Trade) bool { - return c.processTrade(c.convertTrade(trade)) + return c.processTrade(c.ConvertTrade(trade)) } // Run is a goroutine executed in the background @@ -283,7 +297,7 @@ func (c *TradeCollector) Run(ctx context.Context) { c.Process() case trade := <-c.tradeC: - c.processTrade(c.convertTrade(trade)) + c.processTrade(c.ConvertTrade(trade)) } }