mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-21 22:43:52 +00:00
improve config loading by adding unmarshal yaml method
This commit is contained in:
parent
cd666fdf9e
commit
19f259111d
|
@ -27,12 +27,10 @@ reportPnL:
|
||||||
sessions:
|
sessions:
|
||||||
max:
|
max:
|
||||||
exchange: max
|
exchange: max
|
||||||
keyVar: MAX_API_KEY
|
envVarPrefix: "MAX_API_"
|
||||||
secretVar: MAX_API_SECRET
|
|
||||||
binance:
|
binance:
|
||||||
exchange: binance
|
exchange: binance
|
||||||
keyVar: BINANCE_API_KEY
|
envVarPrefix: "BINANCE_API_"
|
||||||
secretVar: BINANCE_API_SECRET
|
|
||||||
|
|
||||||
exchangeStrategies:
|
exchangeStrategies:
|
||||||
- on: binance
|
- on: binance
|
||||||
|
|
|
@ -25,13 +25,13 @@ func TestTradeService(t *testing.T) {
|
||||||
service.NewTradeService(xdb)
|
service.NewTradeService(xdb)
|
||||||
/*
|
/*
|
||||||
|
|
||||||
stmt := mock.ExpectQuery(`SELECT \* FROM trades WHERE symbol = \? ORDER BY gid DESC LIMIT 1`)
|
stmt := mock.ExpectQuery(`SELECT \* FROM trades WHERE symbol = \? ORDER BY gid DESC LIMIT 1`)
|
||||||
stmt.WithArgs("BTCUSDT")
|
stmt.WithArgs("BTCUSDT")
|
||||||
stmt.WillReturnRows(sqlmock.NewRows([]string{"gid", "id", "exchange", "symbol", "price", "quantity"}))
|
stmt.WillReturnRows(sqlmock.NewRows([]string{"gid", "id", "exchange", "symbol", "price", "quantity"}))
|
||||||
|
|
||||||
stmt2 := mock.ExpectQuery(`INSERT INTO trades (id, exchange, symbol, price, quantity, quote_quantity, side, is_buyer, is_maker, fee, fee_currency, traded_at)
|
stmt2 := mock.ExpectQuery(`INSERT INTO trades (id, exchange, symbol, price, quantity, quote_quantity, side, is_buyer, is_maker, fee, fee_currency, traded_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||||
stmt2.WithArgs()
|
stmt2.WithArgs()
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,16 +53,12 @@ func TestEnvironment_Connect(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
xdb, err := sqlx.Connect("mysql", mysqlURL)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
environment := NewEnvironment()
|
environment := NewEnvironment()
|
||||||
environment.AddExchange("binance", exchange).
|
environment.AddExchange("binance", exchange).
|
||||||
Subscribe(types.KLineChannel,"BTCUSDT", types.SubscribeOptions{})
|
Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{})
|
||||||
|
|
||||||
err = environment.Connect(ctx)
|
err := environment.Connect(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ func compileRunFile(filepath string, config *config.Config) error {
|
||||||
return ioutil.WriteFile(filepath, buf.Bytes(), 0644)
|
return ioutil.WriteFile(filepath, buf.Bytes(), 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runConfig(ctx context.Context, config *config.Config) error {
|
func runConfig(ctx context.Context, userConfig *config.Config) error {
|
||||||
// configure notifiers
|
// configure notifiers
|
||||||
slackToken := viper.GetString("slack-token")
|
slackToken := viper.GetString("slack-token")
|
||||||
if len(slackToken) > 0 {
|
if len(slackToken) > 0 {
|
||||||
|
@ -93,21 +93,22 @@ func runConfig(ctx context.Context, config *config.Config) error {
|
||||||
trader := bbgo.NewTrader(environ)
|
trader := bbgo.NewTrader(environ)
|
||||||
trader.AddNotifier(notifierSet)
|
trader.AddNotifier(notifierSet)
|
||||||
|
|
||||||
for _, entry := range config.ExchangeStrategies {
|
for _, entry := range userConfig.ExchangeStrategies {
|
||||||
for _, mount := range entry.Mounts {
|
for _, mount := range entry.Mounts {
|
||||||
log.Infof("attaching strategy %T on %s...", entry.Strategy, mount)
|
log.Infof("attaching strategy %T on %s...", entry.Strategy, mount)
|
||||||
trader.AttachStrategyOn(mount, entry.Strategy)
|
trader.AttachStrategyOn(mount, entry.Strategy)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, strategy := range config.CrossExchangeStrategies {
|
for _, strategy := range userConfig.CrossExchangeStrategies {
|
||||||
log.Infof("attaching strategy %T", strategy)
|
log.Infof("attaching strategy %T", strategy)
|
||||||
trader.AttachCrossExchangeStrategy(strategy)
|
trader.AttachCrossExchangeStrategy(strategy)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, report := range config.PnLReporters {
|
for _, report := range userConfig.PnLReporters {
|
||||||
if len(report.AverageCostBySymbols) > 0 {
|
if len(report.AverageCostBySymbols) > 0 {
|
||||||
|
|
||||||
|
log.Infof("setting up average cost pnl reporter on symbols: %v", report.AverageCostBySymbols)
|
||||||
trader.ReportPnL(notifierSet).
|
trader.ReportPnL(notifierSet).
|
||||||
AverageCostBySymbols(report.AverageCostBySymbols...).
|
AverageCostBySymbols(report.AverageCostBySymbols...).
|
||||||
Of(report.Of...).
|
Of(report.Of...).
|
||||||
|
|
|
@ -40,6 +40,23 @@ func (s *StringSlice) decode(a interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StringSlice) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
|
||||||
|
var ss []string
|
||||||
|
err = unmarshal(&ss)
|
||||||
|
if err == nil {
|
||||||
|
*s = ss
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var as string
|
||||||
|
err = unmarshal(&as)
|
||||||
|
if err == nil {
|
||||||
|
*s = append(*s, as)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StringSlice) UnmarshalJSON(b []byte) error {
|
func (s *StringSlice) UnmarshalJSON(b []byte) error {
|
||||||
var a interface{}
|
var a interface{}
|
||||||
var err = json.Unmarshal(b, &a)
|
var err = json.Unmarshal(b, &a)
|
||||||
|
@ -51,18 +68,25 @@ func (s *StringSlice) UnmarshalJSON(b []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type PnLReporter struct {
|
type PnLReporter struct {
|
||||||
AverageCostBySymbols StringSlice `json:"averageCostBySymbols"`
|
AverageCostBySymbols StringSlice `json:"averageCostBySymbols" yaml:"averageCostBySymbols"`
|
||||||
Of StringSlice `json:"of" yaml:"of"`
|
Of StringSlice `json:"of" yaml:"of"`
|
||||||
When StringSlice `json:"when" yaml:"when"`
|
When StringSlice `json:"when" yaml:"when"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ExchangeName string `json:"exchange" yaml:"exchange"`
|
||||||
|
EnvVarPrefix string `json:"envVarPrefix" yaml:"envVarPrefix"`
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Imports []string `json:"imports" yaml:"imports"`
|
Imports []string `json:"imports" yaml:"imports"`
|
||||||
|
|
||||||
|
Sessions map[string]Session `json:"sessions,omitempty" yaml:"sessions,omitempty"`
|
||||||
|
|
||||||
ExchangeStrategies []SingleExchangeStrategyConfig
|
ExchangeStrategies []SingleExchangeStrategyConfig
|
||||||
CrossExchangeStrategies []bbgo.CrossExchangeStrategy
|
CrossExchangeStrategies []bbgo.CrossExchangeStrategy
|
||||||
|
|
||||||
PnLReporters []PnLReporter `json:"reportPnL" yaml:"reportPnL"`
|
PnLReporters []PnLReporter `json:"reportPnL,omitempty" yaml:"reportPnL,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Stash map[string]interface{}
|
type Stash map[string]interface{}
|
||||||
|
@ -84,12 +108,12 @@ func Load(configFile string) (*Config, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
stash, err := loadStash(content)
|
if err := yaml.Unmarshal(content, &config); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := loadImports(&config, stash); err != nil {
|
stash, err := loadStash(content)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,43 +125,9 @@ func Load(configFile string) (*Config, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := loadReportPnL(&config, stash); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadImports(config *Config, stash Stash) error {
|
|
||||||
importStash, ok := stash["imports"]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
imports, err := reUnmarshal(importStash, &config.Imports)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Imports = *imports.(*[]string)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadReportPnL(config *Config, stash Stash) error {
|
|
||||||
reporterStash, ok := stash["reportPnL"]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
reporters, err := reUnmarshal(reporterStash, &config.PnLReporters)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.PnLReporters = *(reporters.(*[]PnLReporter))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
|
func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
|
||||||
exchangeStrategiesConf, ok := stash["crossExchangeStrategies"]
|
exchangeStrategiesConf, ok := stash["crossExchangeStrategies"]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -148,7 +138,6 @@ func loadCrossExchangeStrategies(config *Config, stash Stash) (err error) {
|
||||||
return errors.New("no cross exchange strategy is registered")
|
return errors.New("no cross exchange strategy is registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
configList, ok := exchangeStrategiesConf.([]interface{})
|
configList, ok := exchangeStrategiesConf.([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("expecting list in crossExchangeStrategies")
|
return errors.New("expecting list in crossExchangeStrategies")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user