Merge pull request #131 from c9s/feature/strategy-trade-marker

feature: trade marker
This commit is contained in:
Yo-An Lin 2021-02-17 09:46:37 +08:00 committed by GitHub
commit 116657479c
19 changed files with 440 additions and 159 deletions

View File

@ -20,4 +20,5 @@ before_script:
script: script:
- bash scripts/test-sqlite3-migrations.sh - bash scripts/test-sqlite3-migrations.sh
- bash scripts/test-mysql-migrations.sh
- go test -v ./pkg/... - go test -v ./pkg/...

View File

@ -283,6 +283,15 @@ Delete chart:
helm delete bbgo helm delete bbgo
``` ```
## Development
### Adding new migration
```sh
rockhopper --config rockhopper_sqlite.yaml create --type sql add_pnl_column
rockhopper --config rockhopper_mysql.yaml create --type sql add_pnl_column
```
## Support ## Support
### By contributing pull requests ### By contributing pull requests

View File

@ -0,0 +1,19 @@
-- +up
-- +begin
ALTER TABLE `trades` ADD COLUMN `pnl` DECIMAL NULL;
-- +end
-- +begin
ALTER TABLE `trades` ADD COLUMN `strategy` VARCHAR(32) NULL;
-- +end
-- +down
-- +begin
ALTER TABLE `trades` DROP COLUMN `pnl`;
-- +end
-- +begin
ALTER TABLE `trades` DROP COLUMN `strategy`;
-- +end

View File

@ -0,0 +1,18 @@
-- +up
-- +begin
ALTER TABLE `trades` ADD COLUMN `pnl` DECIMAL NULL;
-- +end
-- +begin
ALTER TABLE `trades` ADD COLUMN `strategy` TEXT;
-- +end
-- +down
-- +begin
ALTER TABLE `trades` RENAME COLUMN `pnl` TO `pnl_deleted`;
-- +end
-- +begin
ALTER TABLE `trades` RENAME COLUMN `strategy` TO `strategy_deleted`;
-- +end

View File

@ -48,9 +48,9 @@ type Environment struct {
PersistenceServiceFacade *PersistenceServiceFacade PersistenceServiceFacade *PersistenceServiceFacade
DatabaseService *service.DatabaseService DatabaseService *service.DatabaseService
OrderService *service.OrderService OrderService *service.OrderService
TradeService *service.TradeService TradeService *service.TradeService
TradeSync *service.SyncService TradeSync *service.SyncService
// startTime is the time of start point (which is used in the backtest) // startTime is the time of start point (which is used in the backtest)
startTime time.Time startTime time.Time
@ -61,7 +61,7 @@ type Environment struct {
func NewEnvironment() *Environment { func NewEnvironment() *Environment {
return &Environment{ return &Environment{
// default trade scan time // default trade scan time
tradeScanTime: time.Now().AddDate(0, 0, -7), // sync from 7 days ago tradeScanTime: time.Now().AddDate(0, -1, 0), // sync from 1 month ago
sessions: make(map[string]*ExchangeSession), sessions: make(map[string]*ExchangeSession),
startTime: time.Now(), startTime: time.Now(),
} }
@ -83,7 +83,7 @@ func (environ *Environment) ConfigureDatabase(ctx context.Context, driver string
return err return err
} }
if err := environ.DatabaseService.Upgrade(ctx) ; err != nil { if err := environ.DatabaseService.Upgrade(ctx); err != nil {
return err return err
} }

View File

@ -3,8 +3,6 @@ package bbgo
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/sirupsen/logrus"
) )
func isSymbolBasedStrategy(rs reflect.Value) (string, bool) { func isSymbolBasedStrategy(rs reflect.Value) (string, bool) {
@ -31,8 +29,6 @@ func injectField(rs reflect.Value, fieldName string, obj interface{}, pointerOnl
return nil return nil
} }
logrus.Infof("found %s in %s, injecting %T...", fieldName, rs.Type(), obj)
if !field.CanSet() { if !field.CanSet() {
return fmt.Errorf("field %s of %s can not be set", fieldName, rs.Type()) return fmt.Errorf("field %s of %s can not be set", fieldName, rs.Type())
} }

View File

@ -0,0 +1,30 @@
package bbgo
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/service"
)
func Test_injectField(t *testing.T) {
type TT struct {
TradeService *service.TradeService
}
// only pointer object can be set.
var tt = &TT{}
// get the value of the pointer, or it can not be set.
var rv = reflect.ValueOf(tt).Elem()
_, ret := hasField(rv, "TradeService")
assert.True(t, ret)
ts := &service.TradeService{}
err := injectField(rv, "TradeService", ts, true)
assert.NoError(t, err)
}

View File

@ -1,6 +1,8 @@
package bbgo package bbgo
import ( import (
"fmt"
"github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -15,6 +17,15 @@ type Position struct {
AverageCost fixedpoint.Value `json:"averageCost"` AverageCost fixedpoint.Value `json:"averageCost"`
} }
func (p Position) String() string {
return fmt.Sprintf("%s: average cost = %f, base = %f, quote = %f",
p.Symbol,
p.AverageCost.Float64(),
p.Base.Float64(),
p.Quote.Float64(),
)
}
func (p *Position) BindStream(stream types.Stream) { func (p *Position) BindStream(stream types.Stream) {
stream.OnTradeUpdate(func(trade types.Trade) { stream.OnTradeUpdate(func(trade types.Trade) {
if p.Symbol == trade.Symbol { if p.Symbol == trade.Symbol {

View File

@ -506,7 +506,7 @@ func (session *ExchangeSession) UpdatePrices(ctx context.Context) (err error) {
symbols := make([]string, len(balances)) symbols := make([]string, len(balances))
for _, b := range balances { for _, b := range balances {
symbols = append(symbols, b.Currency + "USDT") symbols = append(symbols, b.Currency+"USDT")
} }
tickers, err := session.Exchange.QueryTickers(ctx, symbols...) tickers, err := session.Exchange.QueryTickers(ctx, symbols...)

View File

@ -6,6 +6,7 @@ import (
"reflect" "reflect"
"sync" "sync"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
@ -147,6 +148,92 @@ func (trader *Trader) Subscribe() {
} }
} }
func (trader *Trader) RunSingleExchangeStrategy(ctx context.Context, strategy SingleExchangeStrategy, session *ExchangeSession, orderExecutor OrderExecutor) error {
rs := reflect.ValueOf(strategy)
// get the struct element
rs = rs.Elem()
if rs.Kind() != reflect.Struct {
return errors.New("strategy object is not a struct")
}
if err := trader.injectCommonServices(rs); err != nil {
return err
}
if err := injectField(rs, "OrderExecutor", orderExecutor, false); err != nil {
return errors.Wrapf(err, "failed to inject OrderExecutor on %T", strategy)
}
if symbol, ok := isSymbolBasedStrategy(rs); ok {
log.Debugf("found symbol based strategy from %s", rs.Type())
if _, ok := hasField(rs, "Market"); ok {
if market, ok := session.Market(symbol); ok {
// let's make the market object passed by pointer
if err := injectField(rs, "Market", &market, false); err != nil {
return errors.Wrapf(err, "failed to inject Market on %T", strategy)
}
}
}
// StandardIndicatorSet
if _, ok := hasField(rs, "StandardIndicatorSet"); ok {
if indicatorSet, ok := session.StandardIndicatorSet(symbol); ok {
if err := injectField(rs, "StandardIndicatorSet", indicatorSet, true); err != nil {
return errors.Wrapf(err, "failed to inject StandardIndicatorSet on %T", strategy)
}
}
}
if _, ok := hasField(rs, "MarketDataStore"); ok {
if store, ok := session.MarketDataStore(symbol); ok {
if err := injectField(rs, "MarketDataStore", store, true); err != nil {
return errors.Wrapf(err, "failed to inject MarketDataStore on %T", strategy)
}
}
}
}
return strategy.Run(ctx, orderExecutor, session)
}
func (trader *Trader) getSessionOrderExecutor(sessionName string) OrderExecutor {
var session = trader.environment.sessions[sessionName]
// default to base order executor
var orderExecutor OrderExecutor = session.orderExecutor
// Since the risk controls are loaded from the config file
if trader.riskControls != nil && trader.riskControls.SessionBasedRiskControl != nil {
if control, ok := trader.riskControls.SessionBasedRiskControl[sessionName]; ok {
control.SetBaseOrderExecutor(session.orderExecutor)
// pick the wrapped order executor
if control.OrderExecutor != nil {
return control.OrderExecutor
}
}
}
return orderExecutor
}
func (trader *Trader) RunAllSingleExchangeStrategy(ctx context.Context) error {
// load and run Session strategies
for sessionName, strategies := range trader.exchangeStrategies {
var session = trader.environment.sessions[sessionName]
var orderExecutor = trader.getSessionOrderExecutor(sessionName)
for _, strategy := range strategies {
if err := trader.RunSingleExchangeStrategy(ctx, strategy, session, orderExecutor); err != nil {
return err
}
}
}
return nil
}
func (trader *Trader) Run(ctx context.Context) error { func (trader *Trader) Run(ctx context.Context) error {
trader.Subscribe() trader.Subscribe()
@ -154,92 +241,8 @@ func (trader *Trader) Run(ctx context.Context) error {
return err return err
} }
// load and run Session strategies if err := trader.RunAllSingleExchangeStrategy(ctx); err != nil {
for sessionName, strategies := range trader.exchangeStrategies { return err
var session = trader.environment.sessions[sessionName]
// default to base order executor
var orderExecutor OrderExecutor = session.orderExecutor
// Since the risk controls are loaded from the config file
if riskControls := trader.riskControls; riskControls != nil {
if trader.riskControls.SessionBasedRiskControl != nil {
control, ok := trader.riskControls.SessionBasedRiskControl[sessionName]
if ok {
control.SetBaseOrderExecutor(session.orderExecutor)
// pick the order executor
if control.OrderExecutor != nil {
orderExecutor = control.OrderExecutor
}
}
}
}
for _, strategy := range strategies {
rs := reflect.ValueOf(strategy)
if rs.Elem().Kind() == reflect.Struct {
// get the struct element
rs = rs.Elem()
if err := injectField(rs, "Graceful", &trader.Graceful, true); err != nil {
log.WithError(err).Errorf("strategy Graceful injection failed")
return err
}
if err := injectField(rs, "Logger", &trader.logger, false); err != nil {
log.WithError(err).Errorf("strategy Logger injection failed")
return err
}
if err := injectField(rs, "Notifiability", &trader.environment.Notifiability, false); err != nil {
log.WithError(err).Errorf("strategy Notifiability injection failed")
return err
}
if err := injectField(rs, "OrderExecutor", orderExecutor, false); err != nil {
log.WithError(err).Errorf("strategy OrderExecutor injection failed")
return err
}
if symbol, ok := isSymbolBasedStrategy(rs); ok {
log.Infof("found symbol based strategy from %s", rs.Type())
if _, ok := hasField(rs, "Market"); ok {
if market, ok := session.Market(symbol); ok {
// let's make the market object passed by pointer
if err := injectField(rs, "Market", &market, false); err != nil {
log.WithError(err).Errorf("strategy %T Market injection failed", strategy)
return err
}
}
}
// StandardIndicatorSet
if _, ok := hasField(rs, "StandardIndicatorSet"); ok {
if indicatorSet, ok := session.StandardIndicatorSet(symbol); ok {
if err := injectField(rs, "StandardIndicatorSet", indicatorSet, true); err != nil {
log.WithError(err).Errorf("strategy %T StandardIndicatorSet injection failed", strategy)
return err
}
}
}
if _, ok := hasField(rs, "MarketDataStore"); ok {
if store, ok := session.MarketDataStore(symbol); ok {
if err := injectField(rs, "MarketDataStore", store, true); err != nil {
log.WithError(err).Errorf("strategy %T MarketDataStore injection failed", strategy)
return err
}
}
}
}
}
err := strategy.Run(ctx, orderExecutor, session)
if err != nil {
return err
}
}
} }
router := &ExchangeOrderExecutionRouter{ router := &ExchangeOrderExecutionRouter{
@ -249,52 +252,16 @@ func (trader *Trader) Run(ctx context.Context) error {
for _, strategy := range trader.crossExchangeStrategies { for _, strategy := range trader.crossExchangeStrategies {
rs := reflect.ValueOf(strategy) rs := reflect.ValueOf(strategy)
if rs.Elem().Kind() == reflect.Struct {
// get the struct element
rs = rs.Elem()
if field, ok := hasField(rs, "Persistence"); ok { // get the struct element from the struct pointer
if trader.environment.PersistenceServiceFacade == nil { rs = rs.Elem()
log.Warnf("strategy has Persistence field but persistence service is not defined")
} else {
log.Infof("found Persistence field, injecting...")
if field.IsNil() {
field.Set(reflect.ValueOf(&Persistence{
PersistenceSelector: &PersistenceSelector{
StoreID: "default",
Type: "memory",
},
Facade: trader.environment.PersistenceServiceFacade,
}))
} else {
elem := field.Elem()
if elem.Kind() != reflect.Struct {
return fmt.Errorf("the field Persistence is not a struct element")
}
if err := injectField(elem, "Facade", trader.environment.PersistenceServiceFacade, true); err != nil { if rs.Kind() != reflect.Struct {
log.WithError(err).Errorf("strategy Persistence injection failed") continue
return err }
}
}
}
}
if err := injectField(rs, "Graceful", &trader.Graceful, true); err != nil {
log.WithError(err).Errorf("strategy Graceful injection failed")
return err
}
if err := injectField(rs, "Logger", &trader.logger, false); err != nil {
log.WithError(err).Errorf("strategy Logger injection failed")
return err
}
if err := injectField(rs, "Notifiability", &trader.environment.Notifiability, false); err != nil {
log.WithError(err).Errorf("strategy Notifiability injection failed")
return err
}
if err := trader.injectCommonServices(rs); err != nil {
return err
} }
if err := strategy.CrossRun(ctx, router, trader.environment.sessions); err != nil { if err := strategy.CrossRun(ctx, router, trader.environment.sessions); err != nil {
@ -305,6 +272,53 @@ func (trader *Trader) Run(ctx context.Context) error {
return trader.environment.Connect(ctx) return trader.environment.Connect(ctx)
} }
func (trader *Trader) injectCommonServices(rs reflect.Value) error {
if err := injectField(rs, "Graceful", &trader.Graceful, true); err != nil {
return errors.Wrap(err, "failed to inject Graceful")
}
if err := injectField(rs, "Logger", &trader.logger, false); err != nil {
return errors.Wrap(err, "failed to inject Logger")
}
if err := injectField(rs, "Notifiability", &trader.environment.Notifiability, false); err != nil {
return errors.Wrap(err, "failed to inject Notifiability")
}
if trader.environment.TradeService != nil {
if err := injectField(rs, "TradeService", trader.environment.TradeService, true); err != nil {
return errors.Wrap(err, "failed to inject TradeService")
}
}
if field, ok := hasField(rs, "Persistence"); ok {
if trader.environment.PersistenceServiceFacade == nil {
log.Warnf("strategy has Persistence field but persistence service is not defined")
} else {
if field.IsNil() {
field.Set(reflect.ValueOf(&Persistence{
PersistenceSelector: &PersistenceSelector{
StoreID: "default",
Type: "memory",
},
Facade: trader.environment.PersistenceServiceFacade,
}))
} else {
elem := field.Elem()
if elem.Kind() != reflect.Struct {
return fmt.Errorf("field Persistence is not a struct element")
}
if err := injectField(elem, "Facade", trader.environment.PersistenceServiceFacade, true); err != nil {
return errors.Wrap(err, "failed to inject Persistence")
}
}
}
}
return nil
}
// ReportPnL configure and set the PnLReporter with the given notifier // ReportPnL configure and set the PnLReporter with the given notifier
func (trader *Trader) ReportPnL() *PnLReporterManager { func (trader *Trader) ReportPnL() *PnLReporterManager {
return NewPnLReporter(&trader.environment.Notifiability) return NewPnLReporter(&trader.environment.Notifiability)

View File

@ -37,7 +37,6 @@ func init() {
RunCmd.Flags().String("totp-account-name", "", "") RunCmd.Flags().String("totp-account-name", "", "")
RunCmd.Flags().Bool("enable-web-server", false, "enable web server") RunCmd.Flags().Bool("enable-web-server", false, "enable web server")
RunCmd.Flags().Bool("setup", false, "use setup mode") RunCmd.Flags().Bool("setup", false, "use setup mode")
RunCmd.Flags().String("since", "", "pnl since time")
RootCmd.AddCommand(RunCmd) RootCmd.AddCommand(RunCmd)
} }
@ -227,7 +226,9 @@ func ConfigureTrader(trader *bbgo.Trader, userConfig *bbgo.Config) error {
for _, entry := range userConfig.ExchangeStrategies { for _, entry := range userConfig.ExchangeStrategies {
for _, mount := range entry.Mounts { for _, mount := range entry.Mounts {
log.Infof("attaching strategy %T on %s...", entry.Strategy, mount) log.Infof("attaching strategy %T on %s...", entry.Strategy, mount)
trader.AttachStrategyOn(mount, entry.Strategy) if err := trader.AttachStrategyOn(mount, entry.Strategy) ; err != nil {
return err
}
} }
} }

View File

@ -498,6 +498,8 @@ func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *type
if options.Limit > 0 { if options.Limit > 0 {
req.Limit(options.Limit) req.Limit(options.Limit)
} else {
req.Limit(500)
} }
if options.LastTradeID > 0 { if options.LastTradeID > 0 {
@ -519,7 +521,7 @@ func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *type
continue continue
} }
logger.Infof("T: id=%d % 4s %s P=%f Q=%f %s", localTrade.ID, localTrade.Symbol, localTrade.Side, localTrade.Price, localTrade.Quantity, localTrade.Time) logger.Infof("T: %d %7s %4s P=%f Q=%f %s", localTrade.ID, localTrade.Symbol, localTrade.Side, localTrade.Price, localTrade.Quantity, localTrade.Time)
trades = append(trades, *localTrade) trades = append(trades, *localTrade)
} }

View File

@ -143,7 +143,7 @@ type PrivateTradeRequestParams struct {
Market string `json:"market"` Market string `json:"market"`
// Timestamp is the seconds elapsed since Unix epoch, set to return trades executed before the time only // Timestamp is the seconds elapsed since Unix epoch, set to return trades executed before the time only
Timestamp int `json:"timestamp,omitempty"` Timestamp int64 `json:"timestamp,omitempty"`
// From field is a trade id, set ot return trades created after the trade // From field is a trade id, set ot return trades created after the trade
From int64 `json:"from,omitempty"` From int64 `json:"from,omitempty"`
@ -176,6 +176,11 @@ func (r *PrivateTradeRequest) From(from int64) *PrivateTradeRequest {
return r return r
} }
func (r *PrivateTradeRequest) Timestamp(t int64) *PrivateTradeRequest {
r.params.Timestamp = t
return r
}
func (r *PrivateTradeRequest) To(to int64) *PrivateTradeRequest { func (r *PrivateTradeRequest) To(to int64) *PrivateTradeRequest {
r.params.To = to r.params.To = to
return r return r

View File

@ -1,6 +1,8 @@
package service package service
import ( import (
"context"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -12,6 +14,18 @@ import (
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
var ErrTradeNotFound = errors.New("trade not found")
type QueryTradesOptions struct {
Exchange types.ExchangeName
Symbol string
LastGID int64
// ASC or DESC
Ordering string
Limit int
}
type TradingVolume struct { type TradingVolume struct {
Year int `db:"year" json:"year"` Year int `db:"year" json:"year"`
Month int `db:"month" json:"month,omitempty"` Month int `db:"month" json:"month,omitempty"`
@ -80,7 +94,6 @@ func (s *TradeService) QueryTradingVolume(startTime time.Time, options TradingVo
return records, rows.Err() return records, rows.Err()
} }
func generateSqliteTradingVolumeSQL(options TradingVolumeQueryOptions) string { func generateSqliteTradingVolumeSQL(options TradingVolumeQueryOptions) string {
var sel []string var sel []string
var groupBys []string var groupBys []string
@ -127,7 +140,6 @@ func generateSqliteTradingVolumeSQL(options TradingVolumeQueryOptions) string {
return sql return sql
} }
func generateMysqlTradingVolumeQuerySQL(options TradingVolumeQueryOptions) string { func generateMysqlTradingVolumeQuerySQL(options TradingVolumeQueryOptions) string {
var sel []string var sel []string
var groupBys []string var groupBys []string
@ -137,8 +149,6 @@ func generateMysqlTradingVolumeQuerySQL(options TradingVolumeQueryOptions) strin
switch options.GroupByPeriod { switch options.GroupByPeriod {
case "month": case "month":
sel = append(sel, "YEAR(traded_at) AS year", "MONTH(traded_at) AS month") sel = append(sel, "YEAR(traded_at) AS year", "MONTH(traded_at) AS month")
groupBys = append([]string{"MONTH(traded_at)", "YEAR(traded_at)"}, groupBys...) groupBys = append([]string{"MONTH(traded_at)", "YEAR(traded_at)"}, groupBys...)
orderBys = append(orderBys, "year ASC", "month ASC") orderBys = append(orderBys, "year ASC", "month ASC")
@ -221,15 +231,6 @@ func (s *TradeService) QueryForTradingFeeCurrency(ex types.ExchangeName, symbol
return s.scanRows(rows) return s.scanRows(rows)
} }
type QueryTradesOptions struct {
Exchange types.ExchangeName
Symbol string
LastGID int64
// ASC or DESC
Ordering string
Limit int
}
func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error) { func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error) {
sql := queryTradesSQL(options) sql := queryTradesSQL(options)
@ -249,6 +250,69 @@ func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error)
return s.scanRows(rows) return s.scanRows(rows)
} }
func (s *TradeService) Load(ctx context.Context, id int64) (*types.Trade, error) {
var trade types.Trade
rows, err := s.DB.NamedQuery("SELECT * FROM trades WHERE id = :id", map[string]interface{}{
"id": id,
})
if err != nil {
return nil, err
}
defer rows.Close()
if rows.Next() {
err = rows.StructScan(&trade)
return &trade, err
}
return nil, errors.Wrapf(ErrTradeNotFound, "trade id:%d not found", id)
}
func (s *TradeService) MarkStrategyID(ctx context.Context, id int64, strategyID string) error {
result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `strategy` = :strategy WHERE `id` = :id", map[string]interface{}{
"id": id,
"strategy": strategyID,
})
if err != nil {
return err
}
cnt, err := result.RowsAffected()
if err != nil {
return err
}
if cnt == 0 {
return fmt.Errorf("trade id:%d not found", id)
}
return nil
}
func (s *TradeService) UpdatePnL(ctx context.Context, id int64, pnl float64) error {
result, err := s.DB.NamedExecContext(ctx, "UPDATE `trades` SET `pnl` = :pnl WHERE `id` = :id", map[string]interface{}{
"id": id,
"pnl": pnl,
})
if err != nil {
return err
}
cnt, err := result.RowsAffected()
if err != nil {
return err
}
if cnt == 0 {
return fmt.Errorf("trade id:%d not found", id)
}
return nil
}
func queryTradesSQL(options QueryTradesOptions) string { func queryTradesSQL(options QueryTradesOptions) string {
ordering := "ASC" ordering := "ASC"
switch v := strings.ToUpper(options.Ordering); v { switch v := strings.ToUpper(options.Ordering); v {
@ -283,7 +347,10 @@ func queryTradesSQL(options QueryTradesOptions) string {
sql += ` ORDER BY gid ` + ordering sql += ` ORDER BY gid ` + ordering
sql += ` LIMIT ` + strconv.Itoa(options.Limit) if options.Limit > 0 {
sql += ` LIMIT ` + strconv.Itoa(options.Limit)
}
return sql return sql
} }

View File

@ -1,11 +1,96 @@
package service package service
import ( import (
"context"
"testing" "testing"
"github.com/c9s/rockhopper"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/c9s/bbgo/pkg/types"
) )
func prepareDB(t *testing.T) (*rockhopper.DB, error) {
dialect, err := rockhopper.LoadDialect("sqlite3")
if !assert.NoError(t, err) {
return nil, err
}
assert.NotNil(t, dialect)
db, err := rockhopper.Open("sqlite3", dialect, ":memory:")
if !assert.NoError(t, err) {
return nil, err
}
assert.NotNil(t, db)
_, err = db.CurrentVersion()
if !assert.NoError(t, err) {
return nil, err
}
var loader rockhopper.SqlMigrationLoader
migrations, err := loader.Load("../../migrations/sqlite3")
if !assert.NoError(t, err) {
return nil, err
}
assert.NotEmpty(t, migrations)
ctx := context.Background()
err = rockhopper.Up(ctx, db, migrations, 0, 0)
assert.NoError(t, err)
return db, err
}
func Test_tradeService(t *testing.T) {
db, err := prepareDB(t)
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx := context.Background()
xdb := sqlx.NewDb(db.DB, "sqlite3")
service := &TradeService{DB: xdb}
err = service.Insert(types.Trade{
ID: 1,
OrderID: 1,
Exchange: "binance",
Price: 1000.0,
Quantity: 0.1,
QuoteQuantity: 1000.0 * 0.1,
Symbol: "BTCUSDT",
Side: "BUY",
IsBuyer: true,
})
assert.NoError(t, err)
err = service.MarkStrategyID(ctx, 1, "grid")
assert.NoError(t, err)
tradeRecord, err := service.Load(ctx, 1)
assert.NoError(t, err)
assert.NotNil(t, tradeRecord)
assert.True(t, tradeRecord.StrategyID.Valid)
assert.Equal(t, "grid", tradeRecord.StrategyID.String)
err = service.UpdatePnL(ctx, 1, 10.0)
assert.NoError(t, err)
tradeRecord, err = service.Load(ctx, 1)
assert.NoError(t, err)
assert.NotNil(t, tradeRecord)
assert.True(t, tradeRecord.PnL.Valid)
assert.Equal(t, 10.0, tradeRecord.PnL.Float64)
}
func Test_queryTradingVolumeSQL(t *testing.T) { func Test_queryTradingVolumeSQL(t *testing.T) {
t.Run("group by different period", func(t *testing.T) { t.Run("group by different period", func(t *testing.T) {
o := TradingVolumeQueryOptions{ o := TradingVolumeQueryOptions{
@ -52,7 +137,7 @@ func Test_queryTradesSQL(t *testing.T) {
Symbol: "btc", Symbol: "btc",
LastGID: 123, LastGID: 123,
Ordering: "DESC", Ordering: "DESC",
Limit: 500, Limit: 500,
})) }))
}) })
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/c9s/bbgo/pkg/bbgo" "github.com/c9s/bbgo/pkg/bbgo"
"github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/service"
"github.com/c9s/bbgo/pkg/types" "github.com/c9s/bbgo/pkg/types"
) )
@ -39,6 +40,8 @@ type Strategy struct {
// This field will be injected automatically since we defined the Symbol field. // This field will be injected automatically since we defined the Symbol field.
types.Market `json:"-" yaml:"-"` types.Market `json:"-" yaml:"-"`
TradeService *service.TradeService `json:"-" yaml:"-"`
// These fields will be filled from the config file (it translates YAML to JSON) // These fields will be filled from the config file (it translates YAML to JSON)
Symbol string `json:"symbol" yaml:"symbol"` Symbol string `json:"symbol" yaml:"symbol"`
@ -305,6 +308,7 @@ func (s *Strategy) Subscribe(session *bbgo.ExchangeSession) {
} }
func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error { func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, session *bbgo.ExchangeSession) error {
// do some basic validation
if s.GridNum == 0 { if s.GridNum == 0 {
s.GridNum = 10 s.GridNum = 10
} }
@ -313,6 +317,14 @@ func (s *Strategy) Run(ctx context.Context, orderExecutor bbgo.OrderExecutor, se
return fmt.Errorf("upper price (%f) should not be less than lower price (%f)", s.UpperPrice.Float64(), s.LowerPrice.Float64()) return fmt.Errorf("upper price (%f) should not be less than lower price (%f)", s.UpperPrice.Float64(), s.LowerPrice.Float64())
} }
position, ok := session.Position(s.Symbol)
if !ok {
return fmt.Errorf("position not found")
}
log.Infof("position: %+v", position)
instanceID := fmt.Sprintf("grid-%s-%d", s.Symbol, s.GridNum) instanceID := fmt.Sprintf("grid-%s-%d", s.Symbol, s.GridNum)
s.groupID = generateGroupID(instanceID) s.groupID = generateGroupID(instanceID)
log.Infof("using group id %d from fnv(%s)", s.groupID, instanceID) log.Infof("using group id %d from fnv(%s)", s.groupID, instanceID)

View File

@ -1,6 +1,7 @@
package types package types
import ( import (
"database/sql"
"fmt" "fmt"
"sync" "sync"
@ -58,6 +59,9 @@ type Trade struct {
IsMargin bool `json:"isMargin" db:"is_margin"` IsMargin bool `json:"isMargin" db:"is_margin"`
IsIsolated bool `json:"isIsolated" db:"is_isolated"` IsIsolated bool `json:"isIsolated" db:"is_isolated"`
StrategyID sql.NullString `json:"strategyID" db:"strategy"`
PnL sql.NullFloat64 `json:"pnl" db:"pnl"`
} }
func (trade Trade) PlainText() string { func (trade Trade) PlainText() string {

View File

@ -0,0 +1,4 @@
#!/bin/bash
set -e
rockhopper --config rockhopper_mysql.yaml up
rockhopper --config rockhopper_mysql.yaml down --to 1

View File

@ -1,2 +1,5 @@
#!/bin/bash #!/bin/bash
rm -fv bbgo.sqlite3 && rockhopper --config rockhopper_sqlite.yaml up && rockhopper --config rockhopper_sqlite.yaml down --to 1 set -e
rm -fv bbgo.sqlite3
rockhopper --config rockhopper_sqlite.yaml up
rockhopper --config rockhopper_sqlite.yaml down --to 1