improve config loading by adding unmarshal yaml method

This commit is contained in:
c9s 2020-10-26 15:31:13 +08:00
parent cd666fdf9e
commit 19f259111d
4 changed files with 44 additions and 60 deletions

View File

@ -27,12 +27,10 @@ reportPnL:
sessions: sessions:
max: max:
exchange: max exchange: max
keyVar: MAX_API_KEY envVarPrefix: "MAX_API_"
secretVar: MAX_API_SECRET
binance: binance:
exchange: binance exchange: binance
keyVar: BINANCE_API_KEY envVarPrefix: "BINANCE_API_"
secretVar: BINANCE_API_SECRET
exchangeStrategies: exchangeStrategies:
- on: binance - on: binance

View File

@ -25,13 +25,13 @@ func TestTradeService(t *testing.T) {
service.NewTradeService(xdb) service.NewTradeService(xdb)
/* /*
stmt := mock.ExpectQuery(`SELECT \* FROM trades WHERE symbol = \? ORDER BY gid DESC LIMIT 1`) stmt := mock.ExpectQuery(`SELECT \* FROM trades WHERE symbol = \? ORDER BY gid DESC LIMIT 1`)
stmt.WithArgs("BTCUSDT") stmt.WithArgs("BTCUSDT")
stmt.WillReturnRows(sqlmock.NewRows([]string{"gid", "id", "exchange", "symbol", "price", "quantity"})) stmt.WillReturnRows(sqlmock.NewRows([]string{"gid", "id", "exchange", "symbol", "price", "quantity"}))
stmt2 := mock.ExpectQuery(`INSERT INTO trades (id, exchange, symbol, price, quantity, quote_quantity, side, is_buyer, is_maker, fee, fee_currency, traded_at) stmt2 := mock.ExpectQuery(`INSERT INTO trades (id, exchange, symbol, price, quantity, quote_quantity, side, is_buyer, is_maker, fee, fee_currency, traded_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
stmt2.WithArgs() stmt2.WithArgs()
*/ */
} }
@ -53,16 +53,12 @@ func TestEnvironment_Connect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
xdb, err := sqlx.Connect("mysql", mysqlURL)
assert.NoError(t, err)
environment := NewEnvironment() environment := NewEnvironment()
environment.AddExchange("binance", exchange). environment.AddExchange("binance", exchange).
Subscribe(types.KLineChannel,"BTCUSDT", types.SubscribeOptions{}) Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{})
err = environment.Connect(ctx) err := environment.Connect(ctx)
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
} }

View File

@ -66,7 +66,7 @@ func compileRunFile(filepath string, config *config.Config) error {
return ioutil.WriteFile(filepath, buf.Bytes(), 0644) return ioutil.WriteFile(filepath, buf.Bytes(), 0644)
} }
func runConfig(ctx context.Context, config *config.Config) error { func runConfig(ctx context.Context, userConfig *config.Config) error {
// configure notifiers // configure notifiers
slackToken := viper.GetString("slack-token") slackToken := viper.GetString("slack-token")
if len(slackToken) > 0 { if len(slackToken) > 0 {
@ -93,21 +93,22 @@ func runConfig(ctx context.Context, config *config.Config) error {
trader := bbgo.NewTrader(environ) trader := bbgo.NewTrader(environ)
trader.AddNotifier(notifierSet) trader.AddNotifier(notifierSet)
for _, entry := range config.ExchangeStrategies { for _, entry := range userConfig.ExchangeStrategies {
for _, mount := range entry.Mounts { for _, mount := range entry.Mounts {
log.Infof("attaching strategy %T on %s...", entry.Strategy, mount) log.Infof("attaching strategy %T on %s...", entry.Strategy, mount)
trader.AttachStrategyOn(mount, entry.Strategy) trader.AttachStrategyOn(mount, entry.Strategy)
} }
} }
for _, strategy := range config.CrossExchangeStrategies { for _, strategy := range userConfig.CrossExchangeStrategies {
log.Infof("attaching strategy %T", strategy) log.Infof("attaching strategy %T", strategy)
trader.AttachCrossExchangeStrategy(strategy) trader.AttachCrossExchangeStrategy(strategy)
} }
for _, report := range config.PnLReporters { for _, report := range userConfig.PnLReporters {
if len(report.AverageCostBySymbols) > 0 { if len(report.AverageCostBySymbols) > 0 {
log.Infof("setting up average cost pnl reporter on symbols: %v", report.AverageCostBySymbols)
trader.ReportPnL(notifierSet). trader.ReportPnL(notifierSet).
AverageCostBySymbols(report.AverageCostBySymbols...). AverageCostBySymbols(report.AverageCostBySymbols...).
Of(report.Of...). Of(report.Of...).

View File

@ -40,6 +40,23 @@ func (s *StringSlice) decode(a interface{}) error {
return nil return nil
} }
func (s *StringSlice) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
var ss []string
err = unmarshal(&ss)
if err == nil {
*s = ss
return
}
var as string
err = unmarshal(&as)
if err == nil {
*s = append(*s, as)
}
return err
}
func (s *StringSlice) UnmarshalJSON(b []byte) error { func (s *StringSlice) UnmarshalJSON(b []byte) error {
var a interface{} var a interface{}
var err = json.Unmarshal(b, &a) var err = json.Unmarshal(b, &a)
@ -51,18 +68,25 @@ func (s *StringSlice) UnmarshalJSON(b []byte) error {
} }
type PnLReporter struct { type PnLReporter struct {
AverageCostBySymbols StringSlice `json:"averageCostBySymbols"` AverageCostBySymbols StringSlice `json:"averageCostBySymbols" yaml:"averageCostBySymbols"`
Of StringSlice `json:"of" yaml:"of"` Of StringSlice `json:"of" yaml:"of"`
When StringSlice `json:"when" yaml:"when"` When StringSlice `json:"when" yaml:"when"`
} }
type Session struct {
ExchangeName string `json:"exchange" yaml:"exchange"`
EnvVarPrefix string `json:"envVarPrefix" yaml:"envVarPrefix"`
}
type Config struct { type Config struct {
Imports []string `json:"imports" yaml:"imports"` Imports []string `json:"imports" yaml:"imports"`
Sessions map[string]Session `json:"sessions,omitempty" yaml:"sessions,omitempty"`
ExchangeStrategies []SingleExchangeStrategyConfig ExchangeStrategies []SingleExchangeStrategyConfig
CrossExchangeStrategies []bbgo.CrossExchangeStrategy CrossExchangeStrategies []bbgo.CrossExchangeStrategy
PnLReporters []PnLReporter `json:"reportPnL" yaml:"reportPnL"` PnLReporters []PnLReporter `json:"reportPnL,omitempty" yaml:"reportPnL,omitempty"`
} }
type Stash map[string]interface{} type Stash map[string]interface{}
@ -84,12 +108,12 @@ func Load(configFile string) (*Config, error) {
return nil, err return nil, err
} }
stash, err := loadStash(content) if err := yaml.Unmarshal(content, &config); err != nil {
if err != nil {
return nil, err return nil, err
} }
if err := loadImports(&config, stash); err != nil { stash, err := loadStash(content)
if err != nil {
return nil, err return nil, err
} }
@ -101,43 +125,9 @@ func Load(configFile string) (*Config, error) {
return nil, err return nil, err
} }
if err := loadReportPnL(&config, stash); err != nil {
return nil, err
}
return &config, nil return &config, nil
} }
func loadImports(config *Config, stash Stash) error {
importStash, ok := stash["imports"]
if !ok {
return nil
}
imports, err := reUnmarshal(importStash, &config.Imports)
if err != nil {
return err
}
config.Imports = *imports.(*[]string)
return nil
}
func loadReportPnL(config *Config, stash Stash) error {
reporterStash, ok := stash["reportPnL"]
if !ok {
return nil
}
reporters, err := reUnmarshal(reporterStash, &config.PnLReporters)
if err != nil {
return err
}
config.PnLReporters = *(reporters.(*[]PnLReporter))
return nil
}
func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) { func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
exchangeStrategiesConf, ok := stash["crossExchangeStrategies"] exchangeStrategiesConf, ok := stash["crossExchangeStrategies"]
if !ok { if !ok {
@ -148,7 +138,6 @@ func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
return errors.New("no cross exchange strategy is registered") return errors.New("no cross exchange strategy is registered")
} }
configList, ok := exchangeStrategiesConf.([]interface{}) configList, ok := exchangeStrategiesConf.([]interface{})
if !ok { if !ok {
return errors.New("expecting list in crossExchangeStrategies") return errors.New("expecting list in crossExchangeStrategies")