improve config loading by adding unmarshal yaml method

This commit is contained in:
c9s 2020-10-26 15:31:13 +08:00
parent cd666fdf9e
commit 19f259111d
4 changed files with 44 additions and 60 deletions

View File

@ -27,12 +27,10 @@ reportPnL:
sessions:
max:
exchange: max
keyVar: MAX_API_KEY
secretVar: MAX_API_SECRET
envVarPrefix: "MAX_API_"
binance:
exchange: binance
keyVar: BINANCE_API_KEY
secretVar: BINANCE_API_SECRET
envVarPrefix: "BINANCE_API_"
exchangeStrategies:
- on: binance

View File

@ -53,16 +53,12 @@ func TestEnvironment_Connect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
xdb, err := sqlx.Connect("mysql", mysqlURL)
assert.NoError(t, err)
environment := NewEnvironment()
environment.AddExchange("binance", exchange).
Subscribe(types.KLineChannel,"BTCUSDT", types.SubscribeOptions{})
Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{})
err = environment.Connect(ctx)
err := environment.Connect(ctx)
assert.NoError(t, err)
time.Sleep(5 * time.Second)
}

View File

@ -66,7 +66,7 @@ func compileRunFile(filepath string, config *config.Config) error {
return ioutil.WriteFile(filepath, buf.Bytes(), 0644)
}
func runConfig(ctx context.Context, config *config.Config) error {
func runConfig(ctx context.Context, userConfig *config.Config) error {
// configure notifiers
slackToken := viper.GetString("slack-token")
if len(slackToken) > 0 {
@ -93,21 +93,22 @@ func runConfig(ctx context.Context, config *config.Config) error {
trader := bbgo.NewTrader(environ)
trader.AddNotifier(notifierSet)
for _, entry := range config.ExchangeStrategies {
for _, entry := range userConfig.ExchangeStrategies {
for _, mount := range entry.Mounts {
log.Infof("attaching strategy %T on %s...", entry.Strategy, mount)
trader.AttachStrategyOn(mount, entry.Strategy)
}
}
for _, strategy := range config.CrossExchangeStrategies {
for _, strategy := range userConfig.CrossExchangeStrategies {
log.Infof("attaching strategy %T", strategy)
trader.AttachCrossExchangeStrategy(strategy)
}
for _, report := range config.PnLReporters {
for _, report := range userConfig.PnLReporters {
if len(report.AverageCostBySymbols) > 0 {
log.Infof("setting up average cost pnl reporter on symbols: %v", report.AverageCostBySymbols)
trader.ReportPnL(notifierSet).
AverageCostBySymbols(report.AverageCostBySymbols...).
Of(report.Of...).

View File

@ -40,6 +40,23 @@ func (s *StringSlice) decode(a interface{}) error {
return nil
}
func (s *StringSlice) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
var ss []string
err = unmarshal(&ss)
if err == nil {
*s = ss
return
}
var as string
err = unmarshal(&as)
if err == nil {
*s = append(*s, as)
}
return err
}
func (s *StringSlice) UnmarshalJSON(b []byte) error {
var a interface{}
var err = json.Unmarshal(b, &a)
@ -51,18 +68,25 @@ func (s *StringSlice) UnmarshalJSON(b []byte) error {
}
type PnLReporter struct {
AverageCostBySymbols StringSlice `json:"averageCostBySymbols"`
AverageCostBySymbols StringSlice `json:"averageCostBySymbols" yaml:"averageCostBySymbols"`
Of StringSlice `json:"of" yaml:"of"`
When StringSlice `json:"when" yaml:"when"`
}
type Session struct {
ExchangeName string `json:"exchange" yaml:"exchange"`
EnvVarPrefix string `json:"envVarPrefix" yaml:"envVarPrefix"`
}
type Config struct {
Imports []string `json:"imports" yaml:"imports"`
Sessions map[string]Session `json:"sessions,omitempty" yaml:"sessions,omitempty"`
ExchangeStrategies []SingleExchangeStrategyConfig
CrossExchangeStrategies []bbgo.CrossExchangeStrategy
PnLReporters []PnLReporter `json:"reportPnL" yaml:"reportPnL"`
PnLReporters []PnLReporter `json:"reportPnL,omitempty" yaml:"reportPnL,omitempty"`
}
type Stash map[string]interface{}
@ -84,12 +108,12 @@ func Load(configFile string) (*Config, error) {
return nil, err
}
stash, err := loadStash(content)
if err != nil {
if err := yaml.Unmarshal(content, &config); err != nil {
return nil, err
}
if err := loadImports(&config, stash); err != nil {
stash, err := loadStash(content)
if err != nil {
return nil, err
}
@ -101,43 +125,9 @@ func Load(configFile string) (*Config, error) {
return nil, err
}
if err := loadReportPnL(&config, stash); err != nil {
return nil, err
}
return &config, nil
}
func loadImports(config *Config, stash Stash) error {
importStash, ok := stash["imports"]
if !ok {
return nil
}
imports, err := reUnmarshal(importStash, &config.Imports)
if err != nil {
return err
}
config.Imports = *imports.(*[]string)
return nil
}
func loadReportPnL(config *Config, stash Stash) error {
reporterStash, ok := stash["reportPnL"]
if !ok {
return nil
}
reporters, err := reUnmarshal(reporterStash, &config.PnLReporters)
if err != nil {
return err
}
config.PnLReporters = *(reporters.(*[]PnLReporter))
return nil
}
func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
exchangeStrategiesConf, ok := stash["crossExchangeStrategies"]
if !ok {
@ -148,7 +138,6 @@ func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
return errors.New("no cross exchange strategy is registered")
}
configList, ok := exchangeStrategiesConf.([]interface{})
if !ok {
return errors.New("expecting list in crossExchangeStrategies")