mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 01:01:56 +00:00
use CommonCallback and pull PersistenceTTL out
This commit is contained in:
parent
21e87079b5
commit
d3bc37f45e
|
@ -61,7 +61,7 @@ func generateOpenPositionOrders(market types.Market, quoteInvestment, price, pri
|
|||
prices = append(prices, price)
|
||||
}
|
||||
|
||||
notional, orderNum := calculateNotionalAndNum(market, quoteInvestment, prices)
|
||||
notional, orderNum := calculateNotionalAndNumOrders(market, quoteInvestment, prices)
|
||||
if orderNum == 0 {
|
||||
return nil, fmt.Errorf("failed to calculate notional and num of open position orders, price: %s, quote investment: %s", price, quoteInvestment)
|
||||
}
|
||||
|
@ -87,9 +87,9 @@ func generateOpenPositionOrders(market types.Market, quoteInvestment, price, pri
|
|||
return submitOrders, nil
|
||||
}
|
||||
|
||||
// calculateNotionalAndNum calculates the notional and num of open position orders
|
||||
// calculateNotionalAndNumOrders calculates the notional and num of open position orders
|
||||
// DCA2 is notional-based, every order has the same notional
|
||||
func calculateNotionalAndNum(market types.Market, quoteInvestment fixedpoint.Value, prices []fixedpoint.Value) (fixedpoint.Value, int) {
|
||||
func calculateNotionalAndNumOrders(market types.Market, quoteInvestment fixedpoint.Value, prices []fixedpoint.Value) (fixedpoint.Value, int) {
|
||||
for num := len(prices); num > 0; num-- {
|
||||
notional := quoteInvestment.Div(fixedpoint.NewFromInt(int64(num)))
|
||||
if notional.Compare(market.MinNotional) < 0 {
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
package dca2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -11,6 +9,21 @@ import (
|
|||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
type PersistenceTTL struct {
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func (p *PersistenceTTL) SetTTL(ttl time.Duration) {
|
||||
if ttl.Nanoseconds() <= 0 {
|
||||
return
|
||||
}
|
||||
p.ttl = ttl
|
||||
}
|
||||
|
||||
func (p *PersistenceTTL) Expiration() time.Duration {
|
||||
return p.ttl
|
||||
}
|
||||
|
||||
type ProfitStats struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Market types.Market `json:"market,omitempty"`
|
||||
|
@ -19,13 +32,12 @@ type ProfitStats struct {
|
|||
Round int64 `json:"round,omitempty"`
|
||||
QuoteInvestment fixedpoint.Value `json:"quoteInvestment,omitempty"`
|
||||
|
||||
RoundProfit fixedpoint.Value `json:"roundProfit,omitempty"`
|
||||
RoundFee map[string]fixedpoint.Value `json:"roundFee,omitempty"`
|
||||
TotalProfit fixedpoint.Value `json:"totalProfit,omitempty"`
|
||||
TotalFee map[string]fixedpoint.Value `json:"totalFee,omitempty"`
|
||||
CurrentRoundProfit fixedpoint.Value `json:"currentRoundProfit,omitempty"`
|
||||
CurrentRoundFee map[string]fixedpoint.Value `json:"currentRoundFee,omitempty"`
|
||||
TotalProfit fixedpoint.Value `json:"totalProfit,omitempty"`
|
||||
TotalFee map[string]fixedpoint.Value `json:"totalFee,omitempty"`
|
||||
|
||||
// ttl is the ttl to keep in persistence
|
||||
ttl time.Duration
|
||||
PersistenceTTL
|
||||
}
|
||||
|
||||
func newProfitStats(market types.Market, quoteInvestment fixedpoint.Value) *ProfitStats {
|
||||
|
@ -34,31 +46,20 @@ func newProfitStats(market types.Market, quoteInvestment fixedpoint.Value) *Prof
|
|||
Market: market,
|
||||
Round: 0,
|
||||
QuoteInvestment: quoteInvestment,
|
||||
RoundFee: make(map[string]fixedpoint.Value),
|
||||
CurrentRoundFee: make(map[string]fixedpoint.Value),
|
||||
TotalFee: make(map[string]fixedpoint.Value),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProfitStats) SetTTL(ttl time.Duration) {
|
||||
if ttl.Nanoseconds() <= 0 {
|
||||
return
|
||||
}
|
||||
s.ttl = ttl
|
||||
}
|
||||
|
||||
func (s *ProfitStats) Expiration() time.Duration {
|
||||
return s.ttl
|
||||
}
|
||||
|
||||
func (s *ProfitStats) AddTrade(trade types.Trade) {
|
||||
if s.RoundFee == nil {
|
||||
s.RoundFee = make(map[string]fixedpoint.Value)
|
||||
if s.CurrentRoundFee == nil {
|
||||
s.CurrentRoundFee = make(map[string]fixedpoint.Value)
|
||||
}
|
||||
|
||||
if fee, ok := s.RoundFee[trade.FeeCurrency]; ok {
|
||||
s.RoundFee[trade.FeeCurrency] = fee.Add(trade.Fee)
|
||||
if fee, ok := s.CurrentRoundFee[trade.FeeCurrency]; ok {
|
||||
s.CurrentRoundFee[trade.FeeCurrency] = fee.Add(trade.Fee)
|
||||
} else {
|
||||
s.RoundFee[trade.FeeCurrency] = trade.Fee
|
||||
s.CurrentRoundFee[trade.FeeCurrency] = trade.Fee
|
||||
}
|
||||
|
||||
if s.TotalFee == nil {
|
||||
|
@ -76,63 +77,19 @@ func (s *ProfitStats) AddTrade(trade types.Trade) {
|
|||
quoteQuantity = quoteQuantity.Neg()
|
||||
}
|
||||
|
||||
s.RoundProfit = s.RoundProfit.Add(quoteQuantity)
|
||||
s.CurrentRoundProfit = s.CurrentRoundProfit.Add(quoteQuantity)
|
||||
s.TotalProfit = s.TotalProfit.Add(quoteQuantity)
|
||||
|
||||
if s.Market.QuoteCurrency == trade.FeeCurrency {
|
||||
s.RoundProfit.Sub(trade.Fee)
|
||||
s.CurrentRoundProfit.Sub(trade.Fee)
|
||||
s.TotalProfit.Sub(trade.Fee)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProfitStats) NewRound() {
|
||||
s.Round++
|
||||
s.RoundProfit = fixedpoint.Zero
|
||||
s.RoundFee = make(map[string]fixedpoint.Value)
|
||||
}
|
||||
|
||||
func (s *ProfitStats) CalculateProfitOfRound(ctx context.Context, exchange types.Exchange) error {
|
||||
historyService, ok := exchange.(types.ExchangeTradeHistoryService)
|
||||
if !ok {
|
||||
return fmt.Errorf("exchange %s doesn't support ExchangeTradeHistoryService", exchange.Name())
|
||||
}
|
||||
|
||||
queryService, ok := exchange.(types.ExchangeOrderQueryService)
|
||||
if !ok {
|
||||
return fmt.Errorf("exchange %s doesn't support ExchangeOrderQueryService", exchange.Name())
|
||||
}
|
||||
|
||||
// query the orders of this round
|
||||
orders, err := historyService.QueryClosedOrders(ctx, s.Symbol, time.Time{}, time.Time{}, s.FromOrderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// query the trades of this round
|
||||
for _, order := range orders {
|
||||
if order.ExecutedQuantity.Sign() == 0 {
|
||||
// skip no trade orders
|
||||
continue
|
||||
}
|
||||
|
||||
trades, err := queryService.QueryOrderTrades(ctx, types.OrderQuery{
|
||||
Symbol: order.Symbol,
|
||||
OrderID: strconv.FormatUint(order.OrderID, 10),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, trade := range trades {
|
||||
s.AddTrade(trade)
|
||||
}
|
||||
}
|
||||
|
||||
s.FromOrderID = s.FromOrderID + 1
|
||||
s.QuoteInvestment = s.QuoteInvestment.Add(s.RoundProfit)
|
||||
|
||||
return nil
|
||||
s.CurrentRoundProfit = fixedpoint.Zero
|
||||
s.CurrentRoundFee = make(map[string]fixedpoint.Value)
|
||||
}
|
||||
|
||||
func (s *ProfitStats) String() string {
|
||||
|
@ -141,9 +98,9 @@ func (s *ProfitStats) String() string {
|
|||
sb.WriteString(fmt.Sprintf("Round: %d\n", s.Round))
|
||||
sb.WriteString(fmt.Sprintf("From Order ID: %d\n", s.FromOrderID))
|
||||
sb.WriteString(fmt.Sprintf("Quote Investment: %s\n", s.QuoteInvestment))
|
||||
sb.WriteString(fmt.Sprintf("Round Profit: %s\n", s.RoundProfit))
|
||||
sb.WriteString(fmt.Sprintf("Current Round Profit: %s\n", s.CurrentRoundProfit))
|
||||
sb.WriteString(fmt.Sprintf("Total Profit: %s\n", s.TotalProfit))
|
||||
for currency, fee := range s.RoundFee {
|
||||
for currency, fee := range s.CurrentRoundFee {
|
||||
sb.WriteString(fmt.Sprintf("FEE (%s): %s\n", currency, fee))
|
||||
}
|
||||
sb.WriteString("[------------------ Profit Stats ------------------]\n")
|
||||
|
|
|
@ -57,7 +57,7 @@ func (s *Strategy) recover(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// recover profit stats
|
||||
recoverProfitStats(ctx, s.ProfitStats, s.Session.Exchange)
|
||||
recoverProfitStats(ctx, s)
|
||||
|
||||
// recover startTimeOfNextRound
|
||||
startTimeOfNextRound := recoverStartTimeOfNextRound(ctx, currentRound, s.CoolDownInterval)
|
||||
|
@ -189,12 +189,12 @@ func recoverPosition(ctx context.Context, position *types.Position, queryService
|
|||
return nil
|
||||
}
|
||||
|
||||
func recoverProfitStats(ctx context.Context, profitStats *ProfitStats, exchange types.Exchange) error {
|
||||
if profitStats == nil {
|
||||
func recoverProfitStats(ctx context.Context, strategy *Strategy) error {
|
||||
if strategy.ProfitStats == nil {
|
||||
return fmt.Errorf("profit stats is nil, please check it")
|
||||
}
|
||||
|
||||
profitStats.CalculateProfitOfRound(ctx, exchange)
|
||||
strategy.CalculateProfitOfCurrentRound(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -193,7 +193,7 @@ func (s *Strategy) runTakeProfitReady(ctx context.Context, next State) {
|
|||
s.logger.Info("[State] TakeProfitReady - start reseting position and calculate quote investment for next round")
|
||||
|
||||
// calculate profit stats
|
||||
s.ProfitStats.CalculateProfitOfRound(ctx, s.Session.Exchange)
|
||||
s.CalculateProfitOfCurrentRound(ctx)
|
||||
bbgo.Sync(ctx, s)
|
||||
|
||||
s.EmitProfit(s.ProfitStats)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -68,11 +69,9 @@ type Strategy struct {
|
|||
state State
|
||||
|
||||
// callbacks
|
||||
readyCallbacks []func()
|
||||
types.CommonCallback
|
||||
positionCallbacks []func(*types.Position)
|
||||
profitCallbacks []func(*ProfitStats)
|
||||
closedCallbacks []func()
|
||||
errorCallbacks []func(error)
|
||||
}
|
||||
|
||||
func (s *Strategy) ID() string {
|
||||
|
@ -278,3 +277,47 @@ func (s *Strategy) CleanUp(ctx context.Context) error {
|
|||
bbgo.Sync(ctx, s)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Strategy) CalculateProfitOfCurrentRound(ctx context.Context) error {
|
||||
historyService, ok := s.Session.Exchange.(types.ExchangeTradeHistoryService)
|
||||
if !ok {
|
||||
return fmt.Errorf("exchange %s doesn't support ExchangeTradeHistoryService", s.Session.Exchange.Name())
|
||||
}
|
||||
|
||||
queryService, ok := s.Session.Exchange.(types.ExchangeOrderQueryService)
|
||||
if !ok {
|
||||
return fmt.Errorf("exchange %s doesn't support ExchangeOrderQueryService", s.Session.Exchange.Name())
|
||||
}
|
||||
|
||||
// query the orders of this round
|
||||
orders, err := historyService.QueryClosedOrders(ctx, s.Symbol, time.Time{}, time.Time{}, s.ProfitStats.FromOrderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// query the trades of this round
|
||||
for _, order := range orders {
|
||||
if order.ExecutedQuantity.Sign() == 0 {
|
||||
// skip no trade orders
|
||||
continue
|
||||
}
|
||||
|
||||
trades, err := queryService.QueryOrderTrades(ctx, types.OrderQuery{
|
||||
Symbol: order.Symbol,
|
||||
OrderID: strconv.FormatUint(order.OrderID, 10),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, trade := range trades {
|
||||
s.ProfitStats.AddTrade(trade)
|
||||
}
|
||||
}
|
||||
|
||||
s.ProfitStats.FromOrderID = s.ProfitStats.FromOrderID + 1
|
||||
s.ProfitStats.QuoteInvestment = s.ProfitStats.QuoteInvestment.Add(s.ProfitStats.CurrentRoundProfit)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,16 +6,6 @@ import (
|
|||
"github.com/c9s/bbgo/pkg/types"
|
||||
)
|
||||
|
||||
func (s *Strategy) OnReady(cb func()) {
|
||||
s.readyCallbacks = append(s.readyCallbacks, cb)
|
||||
}
|
||||
|
||||
func (s *Strategy) EmitReady() {
|
||||
for _, cb := range s.readyCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Strategy) OnPosition(cb func(*types.Position)) {
|
||||
s.positionCallbacks = append(s.positionCallbacks, cb)
|
||||
}
|
||||
|
@ -35,23 +25,3 @@ func (s *Strategy) EmitProfit(profitStats *ProfitStats) {
|
|||
cb(profitStats)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Strategy) OnClosed(cb func()) {
|
||||
s.closedCallbacks = append(s.closedCallbacks, cb)
|
||||
}
|
||||
|
||||
func (s *Strategy) EmitClosed() {
|
||||
for _, cb := range s.closedCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Strategy) OnError(cb func(err error)) {
|
||||
s.errorCallbacks = append(s.errorCallbacks, cb)
|
||||
}
|
||||
|
||||
func (s *Strategy) EmitError(err error) {
|
||||
for _, cb := range s.errorCallbacks {
|
||||
cb(err)
|
||||
}
|
||||
}
|
||||
|
|
37
pkg/types/callbacks.go
Normal file
37
pkg/types/callbacks.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package types
|
||||
|
||||
type CommonCallback struct {
|
||||
readyCallbacks []func()
|
||||
closedCallbacks []func()
|
||||
errorCallbacks []func(error)
|
||||
}
|
||||
|
||||
func (c *CommonCallback) OnReady(cb func()) {
|
||||
c.readyCallbacks = append(c.readyCallbacks, cb)
|
||||
}
|
||||
|
||||
func (c *CommonCallback) EmitReady() {
|
||||
for _, cb := range c.readyCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CommonCallback) OnClosed(cb func()) {
|
||||
c.closedCallbacks = append(c.closedCallbacks, cb)
|
||||
}
|
||||
|
||||
func (c *CommonCallback) EmitClosed() {
|
||||
for _, cb := range c.closedCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CommonCallback) OnError(cb func(err error)) {
|
||||
c.errorCallbacks = append(c.errorCallbacks, cb)
|
||||
}
|
||||
|
||||
func (c *CommonCallback) EmitError(err error) {
|
||||
for _, cb := range c.errorCallbacks {
|
||||
cb(err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user