mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-25 00:05:15 +00:00
improve config loading by adding unmarshal yaml method
This commit is contained in:
parent
cd666fdf9e
commit
19f259111d
|
@ -27,12 +27,10 @@ reportPnL:
|
|||
sessions:
|
||||
max:
|
||||
exchange: max
|
||||
keyVar: MAX_API_KEY
|
||||
secretVar: MAX_API_SECRET
|
||||
envVarPrefix: "MAX_API_"
|
||||
binance:
|
||||
exchange: binance
|
||||
keyVar: BINANCE_API_KEY
|
||||
secretVar: BINANCE_API_SECRET
|
||||
envVarPrefix: "BINANCE_API_"
|
||||
|
||||
exchangeStrategies:
|
||||
- on: binance
|
||||
|
|
|
@ -53,16 +53,12 @@ func TestEnvironment_Connect(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
xdb, err := sqlx.Connect("mysql", mysqlURL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
environment := NewEnvironment()
|
||||
environment.AddExchange("binance", exchange).
|
||||
Subscribe(types.KLineChannel,"BTCUSDT", types.SubscribeOptions{})
|
||||
Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{})
|
||||
|
||||
err = environment.Connect(ctx)
|
||||
err := environment.Connect(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ func compileRunFile(filepath string, config *config.Config) error {
|
|||
return ioutil.WriteFile(filepath, buf.Bytes(), 0644)
|
||||
}
|
||||
|
||||
func runConfig(ctx context.Context, config *config.Config) error {
|
||||
func runConfig(ctx context.Context, userConfig *config.Config) error {
|
||||
// configure notifiers
|
||||
slackToken := viper.GetString("slack-token")
|
||||
if len(slackToken) > 0 {
|
||||
|
@ -93,21 +93,22 @@ func runConfig(ctx context.Context, config *config.Config) error {
|
|||
trader := bbgo.NewTrader(environ)
|
||||
trader.AddNotifier(notifierSet)
|
||||
|
||||
for _, entry := range config.ExchangeStrategies {
|
||||
for _, entry := range userConfig.ExchangeStrategies {
|
||||
for _, mount := range entry.Mounts {
|
||||
log.Infof("attaching strategy %T on %s...", entry.Strategy, mount)
|
||||
trader.AttachStrategyOn(mount, entry.Strategy)
|
||||
}
|
||||
}
|
||||
|
||||
for _, strategy := range config.CrossExchangeStrategies {
|
||||
for _, strategy := range userConfig.CrossExchangeStrategies {
|
||||
log.Infof("attaching strategy %T", strategy)
|
||||
trader.AttachCrossExchangeStrategy(strategy)
|
||||
}
|
||||
|
||||
for _, report := range config.PnLReporters {
|
||||
for _, report := range userConfig.PnLReporters {
|
||||
if len(report.AverageCostBySymbols) > 0 {
|
||||
|
||||
log.Infof("setting up average cost pnl reporter on symbols: %v", report.AverageCostBySymbols)
|
||||
trader.ReportPnL(notifierSet).
|
||||
AverageCostBySymbols(report.AverageCostBySymbols...).
|
||||
Of(report.Of...).
|
||||
|
|
|
@ -40,6 +40,23 @@ func (s *StringSlice) decode(a interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *StringSlice) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
|
||||
var ss []string
|
||||
err = unmarshal(&ss)
|
||||
if err == nil {
|
||||
*s = ss
|
||||
return
|
||||
}
|
||||
|
||||
var as string
|
||||
err = unmarshal(&as)
|
||||
if err == nil {
|
||||
*s = append(*s, as)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *StringSlice) UnmarshalJSON(b []byte) error {
|
||||
var a interface{}
|
||||
var err = json.Unmarshal(b, &a)
|
||||
|
@ -51,18 +68,25 @@ func (s *StringSlice) UnmarshalJSON(b []byte) error {
|
|||
}
|
||||
|
||||
type PnLReporter struct {
|
||||
AverageCostBySymbols StringSlice `json:"averageCostBySymbols"`
|
||||
AverageCostBySymbols StringSlice `json:"averageCostBySymbols" yaml:"averageCostBySymbols"`
|
||||
Of StringSlice `json:"of" yaml:"of"`
|
||||
When StringSlice `json:"when" yaml:"when"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ExchangeName string `json:"exchange" yaml:"exchange"`
|
||||
EnvVarPrefix string `json:"envVarPrefix" yaml:"envVarPrefix"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Imports []string `json:"imports" yaml:"imports"`
|
||||
|
||||
Sessions map[string]Session `json:"sessions,omitempty" yaml:"sessions,omitempty"`
|
||||
|
||||
ExchangeStrategies []SingleExchangeStrategyConfig
|
||||
CrossExchangeStrategies []bbgo.CrossExchangeStrategy
|
||||
|
||||
PnLReporters []PnLReporter `json:"reportPnL" yaml:"reportPnL"`
|
||||
PnLReporters []PnLReporter `json:"reportPnL,omitempty" yaml:"reportPnL,omitempty"`
|
||||
}
|
||||
|
||||
type Stash map[string]interface{}
|
||||
|
@ -84,12 +108,12 @@ func Load(configFile string) (*Config, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
stash, err := loadStash(content)
|
||||
if err != nil {
|
||||
if err := yaml.Unmarshal(content, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := loadImports(&config, stash); err != nil {
|
||||
stash, err := loadStash(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -101,43 +125,9 @@ func Load(configFile string) (*Config, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := loadReportPnL(&config, stash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func loadImports(config *Config, stash Stash) error {
|
||||
importStash, ok := stash["imports"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
imports, err := reUnmarshal(importStash, &config.Imports)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.Imports = *imports.(*[]string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadReportPnL(config *Config, stash Stash) error {
|
||||
reporterStash, ok := stash["reportPnL"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
reporters, err := reUnmarshal(reporterStash, &config.PnLReporters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.PnLReporters = *(reporters.(*[]PnLReporter))
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
|
||||
exchangeStrategiesConf, ok := stash["crossExchangeStrategies"]
|
||||
if !ok {
|
||||
|
@ -148,7 +138,6 @@ func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
|
|||
return errors.New("no cross exchange strategy is registered")
|
||||
}
|
||||
|
||||
|
||||
configList, ok := exchangeStrategiesConf.([]interface{})
|
||||
if !ok {
|
||||
return errors.New("expecting list in crossExchangeStrategies")
|
||||
|
|
Loading…
Reference in New Issue
Block a user