From 19f259111dc0dead3479133b47c90ec9e53af1c6 Mon Sep 17 00:00:00 2001 From: c9s Date: Mon, 26 Oct 2020 15:31:13 +0800 Subject: [PATCH] improve config loading by adding unmarshal yaml method --- config/bbgo.yaml | 6 ++-- pkg/bbgo/strategy_test.go | 20 +++++------- pkg/cmd/run.go | 9 ++--- pkg/config/loader.go | 69 ++++++++++++++++----------------------- 4 files changed, 44 insertions(+), 60 deletions(-) diff --git a/config/bbgo.yaml b/config/bbgo.yaml index 400df0f73..ae7edb351 100644 --- a/config/bbgo.yaml +++ b/config/bbgo.yaml @@ -27,12 +27,10 @@ reportPnL: sessions: max: exchange: max - keyVar: MAX_API_KEY - secretVar: MAX_API_SECRET + envVarPrefix: "MAX_API_" binance: exchange: binance - keyVar: BINANCE_API_KEY - secretVar: BINANCE_API_SECRET + envVarPrefix: "BINANCE_API_" exchangeStrategies: - on: binance diff --git a/pkg/bbgo/strategy_test.go b/pkg/bbgo/strategy_test.go index fd3dc8197..50b4a5a58 100644 --- a/pkg/bbgo/strategy_test.go +++ b/pkg/bbgo/strategy_test.go @@ -25,13 +25,13 @@ func TestTradeService(t *testing.T) { service.NewTradeService(xdb) /* - stmt := mock.ExpectQuery(`SELECT \* FROM trades WHERE symbol = \? ORDER BY gid DESC LIMIT 1`) - stmt.WithArgs("BTCUSDT") - stmt.WillReturnRows(sqlmock.NewRows([]string{"gid", "id", "exchange", "symbol", "price", "quantity"})) + stmt := mock.ExpectQuery(`SELECT \* FROM trades WHERE symbol = \? ORDER BY gid DESC LIMIT 1`) + stmt.WithArgs("BTCUSDT") + stmt.WillReturnRows(sqlmock.NewRows([]string{"gid", "id", "exchange", "symbol", "price", "quantity"})) - stmt2 := mock.ExpectQuery(`INSERT INTO trades (id, exchange, symbol, price, quantity, quote_quantity, side, is_buyer, is_maker, fee, fee_currency, traded_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) - stmt2.WithArgs() + stmt2 := mock.ExpectQuery(`INSERT INTO trades (id, exchange, symbol, price, quantity, quote_quantity, side, is_buyer, is_maker, fee, fee_currency, traded_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + stmt2.WithArgs() */ } @@ -53,16 +53,12 @@ func TestEnvironment_Connect(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - xdb, err := sqlx.Connect("mysql", mysqlURL) - assert.NoError(t, err) - environment := NewEnvironment() environment.AddExchange("binance", exchange). - Subscribe(types.KLineChannel,"BTCUSDT", types.SubscribeOptions{}) + Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{}) - err = environment.Connect(ctx) + err := environment.Connect(ctx) assert.NoError(t, err) time.Sleep(5 * time.Second) } - diff --git a/pkg/cmd/run.go b/pkg/cmd/run.go index 6de4d3fcc..7b6bee140 100644 --- a/pkg/cmd/run.go +++ b/pkg/cmd/run.go @@ -66,7 +66,7 @@ func compileRunFile(filepath string, config *config.Config) error { return ioutil.WriteFile(filepath, buf.Bytes(), 0644) } -func runConfig(ctx context.Context, config *config.Config) error { +func runConfig(ctx context.Context, userConfig *config.Config) error { // configure notifiers slackToken := viper.GetString("slack-token") if len(slackToken) > 0 { @@ -93,21 +93,22 @@ func runConfig(ctx context.Context, config *config.Config) error { trader := bbgo.NewTrader(environ) trader.AddNotifier(notifierSet) - for _, entry := range config.ExchangeStrategies { + for _, entry := range userConfig.ExchangeStrategies { for _, mount := range entry.Mounts { log.Infof("attaching strategy %T on %s...", entry.Strategy, mount) trader.AttachStrategyOn(mount, entry.Strategy) } } - for _, strategy := range config.CrossExchangeStrategies { + for _, strategy := range userConfig.CrossExchangeStrategies { log.Infof("attaching strategy %T", strategy) trader.AttachCrossExchangeStrategy(strategy) } - for _, report := range config.PnLReporters { + for _, report := range userConfig.PnLReporters { if len(report.AverageCostBySymbols) > 0 { + log.Infof("setting up average cost pnl reporter on symbols: %v", report.AverageCostBySymbols) trader.ReportPnL(notifierSet). AverageCostBySymbols(report.AverageCostBySymbols...). Of(report.Of...). diff --git a/pkg/config/loader.go b/pkg/config/loader.go index f1fee61f8..414768ad0 100644 --- a/pkg/config/loader.go +++ b/pkg/config/loader.go @@ -40,6 +40,23 @@ func (s *StringSlice) decode(a interface{}) error { return nil } +func (s *StringSlice) UnmarshalYAML(unmarshal func(interface{}) error) (err error) { + var ss []string + err = unmarshal(&ss) + if err == nil { + *s = ss + return + } + + var as string + err = unmarshal(&as) + if err == nil { + *s = append(*s, as) + } + + return err +} + func (s *StringSlice) UnmarshalJSON(b []byte) error { var a interface{} var err = json.Unmarshal(b, &a) @@ -51,18 +68,25 @@ func (s *StringSlice) UnmarshalJSON(b []byte) error { } type PnLReporter struct { - AverageCostBySymbols StringSlice `json:"averageCostBySymbols"` + AverageCostBySymbols StringSlice `json:"averageCostBySymbols" yaml:"averageCostBySymbols"` Of StringSlice `json:"of" yaml:"of"` When StringSlice `json:"when" yaml:"when"` } +type Session struct { + ExchangeName string `json:"exchange" yaml:"exchange"` + EnvVarPrefix string `json:"envVarPrefix" yaml:"envVarPrefix"` +} + type Config struct { Imports []string `json:"imports" yaml:"imports"` + Sessions map[string]Session `json:"sessions,omitempty" yaml:"sessions,omitempty"` + ExchangeStrategies []SingleExchangeStrategyConfig CrossExchangeStrategies []bbgo.CrossExchangeStrategy - PnLReporters []PnLReporter `json:"reportPnL" yaml:"reportPnL"` + PnLReporters []PnLReporter `json:"reportPnL,omitempty" yaml:"reportPnL,omitempty"` } type Stash map[string]interface{} @@ -84,12 +108,12 @@ func Load(configFile string) (*Config, error) { return nil, err } - stash, err := loadStash(content) - if err != nil { + if err := yaml.Unmarshal(content, &config); err != nil { return nil, err } - if err := loadImports(&config, stash); err != nil { + stash, err := loadStash(content) + if err != nil { return nil, err } @@ -101,43 +125,9 @@ func Load(configFile string) (*Config, error) { return nil, err } - if err := loadReportPnL(&config, stash); err != nil { - return nil, err - } - return &config, nil } -func loadImports(config *Config, stash Stash) error { - importStash, ok := stash["imports"] - if !ok { - return nil - } - - imports, err := reUnmarshal(importStash, &config.Imports) - if err != nil { - return err - } - - config.Imports = *imports.(*[]string) - return nil -} - -func loadReportPnL(config *Config, stash Stash) error { - reporterStash, ok := stash["reportPnL"] - if !ok { - return nil - } - - reporters, err := reUnmarshal(reporterStash, &config.PnLReporters) - if err != nil { - return err - } - - config.PnLReporters = *(reporters.(*[]PnLReporter)) - return nil -} - func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) { exchangeStrategiesConf, ok := stash["crossExchangeStrategies"] if !ok { @@ -148,7 +138,6 @@ func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) { return errors.New("no cross exchange strategy is registered") } - configList, ok := exchangeStrategiesConf.([]interface{}) if !ok { return errors.New("expecting list in crossExchangeStrategies")