fix max trades query for pnl calulcation

This commit is contained in:
c9s 2020-10-10 17:50:49 +08:00
parent ce73979713
commit 7f6a478d4f
8 changed files with 350 additions and 26 deletions

View File

@ -2,6 +2,8 @@ package main
import (
"github.com/c9s/bbgo/cmd"
_ "github.com/go-sql-driver/mysql"
)
func main() {

View File

@ -1,6 +0,0 @@
package main
import (
_ "github.com/go-sql-driver/mysql"
)

View File

@ -13,17 +13,37 @@ import (
"github.com/c9s/bbgo/accounting"
"github.com/c9s/bbgo/bbgo"
binance2 "github.com/c9s/bbgo/exchange/binance"
"github.com/c9s/bbgo/exchange/binance"
"github.com/c9s/bbgo/exchange/max"
"github.com/c9s/bbgo/service"
"github.com/c9s/bbgo/types"
)
func init() {
PnLCmd.Flags().String("exchange", "", "target exchange")
PnLCmd.Flags().String("symbol", "BTCUSDT", "trading symbol")
PnLCmd.Flags().String("since", "", "pnl since time")
RootCmd.AddCommand(PnLCmd)
}
func newExchangeFromViper(n types.ExchangeName) types.Exchange {
switch n {
case types.ExchangeBinance:
key := viper.GetString("binance-api-key")
secret := viper.GetString("binance-api-secret")
return binance.New(key, secret)
case types.ExchangeMax:
key := viper.GetString("max-api-key")
secret := viper.GetString("max-api-secret")
return max.New(key, secret)
}
return nil
}
var PnLCmd = &cobra.Command{
Use: "pnl",
Short: "pnl calculator",
@ -31,14 +51,22 @@ var PnLCmd = &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error {
ctx := context.Background()
exchangeNameStr, err := cmd.Flags().GetString("exchange")
if err != nil {
return err
}
exchangeName, err := types.ValidExchangeName(exchangeNameStr)
if err != nil {
return err
}
symbol, err := cmd.Flags().GetString("symbol")
if err != nil {
return err
}
binanceKey := viper.GetString("binance-api-key")
binanceSecret := viper.GetString("binance-api-secret")
binanceExchange := binance2.New(binanceKey, binanceSecret)
exchange := newExchangeFromViper(exchangeName)
mysqlURL := viper.GetString("mysql-url")
mysqlURL = fmt.Sprintf("%s?parseTime=true", mysqlURL)
@ -69,12 +97,12 @@ var PnLCmd = &cobra.Command{
tradeSync := &service.TradeSync{Service: tradeService}
logrus.Info("syncing trades...")
if err := tradeSync.Sync(ctx, binanceExchange, symbol, startTime); err != nil {
if err := tradeSync.Sync(ctx, exchange, symbol, startTime); err != nil {
return err
}
var trades []types.Trade
tradingFeeCurrency := binanceExchange.PlatformFeeCurrency()
tradingFeeCurrency := exchange.PlatformFeeCurrency()
if strings.HasPrefix(symbol, tradingFeeCurrency) {
logrus.Infof("loading all trading fee currency related trades: %s", symbol)
trades, err = tradeService.QueryForTradingFeeCurrency(symbol, tradingFeeCurrency)
@ -101,7 +129,7 @@ var PnLCmd = &cobra.Command{
logrus.Infof("found checkpoints: %+v", checkpoints)
logrus.Infof("stock: %f", stockManager.Stocks.Quantity())
currentPrice, err := binanceExchange.QueryAveragePrice(ctx, symbol)
currentPrice, err := exchange.QueryAveragePrice(ctx, symbol)
calculator := &accounting.ProfitAndLossCalculator{
TradingFeeCurrency: tradingFeeCurrency,

View File

@ -7,6 +7,9 @@ import (
"strings"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
maxapi "github.com/c9s/bbgo/exchange/max/maxapi"
"github.com/c9s/bbgo/types"
"github.com/c9s/bbgo/util"
@ -101,7 +104,7 @@ func (e *Exchange) QueryAccountBalances(ctx context.Context) (types.BalanceMap,
func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *types.TradeQueryOptions) (trades []types.Trade, err error) {
req := e.client.TradeService.NewPrivateTradeRequest()
req.Market(symbol)
req.Market(toLocalSymbol(symbol))
if options.Limit > 0 {
req.Limit(options.Limit)
@ -111,6 +114,9 @@ func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *type
req.From(options.LastTradeID)
}
// make it compatible with binance, we need the last trade id for the next page.
req.OrderBy("asc")
remoteTrades, err := req.Do(ctx)
if err != nil {
return nil, err
@ -119,15 +125,60 @@ func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *type
for _, t := range remoteTrades {
localTrade, err := convertRemoteTrade(t)
if err != nil {
logger.WithError(err).Errorf("can not convert binance trade: %+v", t)
logger.WithError(err).Errorf("can not convert trade: %+v", t)
continue
}
logger.Infof("T: id=%d % 4s %s P=%f Q=%f %s", localTrade.ID, localTrade.Symbol, localTrade.Side, localTrade.Price, localTrade.Quantity, localTrade.Time)
trades = append(trades, *localTrade)
}
return trades, nil
}
func (e *Exchange) QueryKLines(ctx context.Context, symbol, interval string, options types.KLineQueryOptions) ([]types.KLine, error) {
var limit = 5000
if options.Limit > 0 {
// default limit == 500
limit = options.Limit
}
if options.StartTime == nil {
return nil, errors.New("start time can not be empty")
}
if options.EndTime != nil {
return nil, errors.New("end time is not supported")
}
log.Infof("querying kline %s %s %v", symbol, interval, options)
// avoid rate limit
time.Sleep(100 * time.Millisecond)
localKlines, err := e.client.PublicService.KLines(toLocalSymbol(symbol), interval, *options.StartTime, limit)
if err != nil {
return nil, err
}
var kLines []types.KLine
for _, k := range localKlines {
kLines = append(kLines, k.KLine())
}
return kLines, nil
}
func (e *Exchange) QueryAveragePrice(ctx context.Context, symbol string) (float64, error) {
ticker, err := e.client.PublicService.Ticker(toLocalSymbol(symbol))
if err != nil {
return 0, err
}
return (util.MustParseFloat(ticker.Sell) + util.MustParseFloat(ticker.Buy)) / 2, nil
}
func toGlobalCurrency(currency string) string {
return strings.ToUpper(currency)
}
@ -136,6 +187,14 @@ func toLocalCurrency(currency string) string {
return strings.ToLower(currency)
}
func toLocalSymbol(symbol string) string {
return strings.ToLower(symbol)
}
func toGlobalSymbol(symbol string) string {
return strings.ToLower(symbol)
}
func toLocalSideType(side types.SideType) string {
return strings.ToLower(string(side))
}
@ -148,6 +207,9 @@ func toGlobalSideType(v string) string {
case "ask":
return "SELL"
case "self-trade":
return "SELF"
}
return strings.ToUpper(v)
@ -195,7 +257,7 @@ func convertRemoteTrade(t maxapi.Trade) (*types.Trade, error) {
return &types.Trade{
ID: int64(t.ID),
Price: price,
Symbol: t.Market,
Symbol: toGlobalSymbol(t.Market),
Exchange: "max",
Quantity: quantity,
Side: side,

View File

@ -1,10 +1,17 @@
package max
import (
"fmt"
"io/ioutil"
"net/url"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
"github.com/valyala/fastjson"
"github.com/c9s/bbgo/types"
)
type PublicService struct {
@ -140,3 +147,164 @@ func mustParseTicker(v *fastjson.Value) Ticker {
Low: string(v.GetStringBytes("low")),
}
}
type Interval int64
func parseResolution(a string) (Interval, error) {
switch strings.ToLower(a) {
case "1m":
return 1, nil
case "5m":
return 5, nil
case "15m":
return 15, nil
case "30m":
return 30, nil
case "1h":
return 60, nil
case "3h":
return 60 * 3, nil
case "6h":
return 60 * 6, nil
case "12h":
return 60 * 12, nil
case "1d":
return 60 * 24, nil
case "3d":
return 60 * 24 * 3, nil
case "1w":
return 60 * 24 * 7, nil
}
return 0, errors.New("incorrect resolution")
}
type KLine struct {
Symbol string
Interval string
StartTime, EndTime time.Time
Open, High, Low, Close float64
Volume float64
Closed bool
}
func (k KLine) KLine() types.KLine {
return types.KLine{
Symbol: k.Symbol,
Interval: k.Interval,
StartTime: k.StartTime,
EndTime: k.EndTime,
Open: k.Open,
Close: k.Close,
High: k.High,
Low: k.Low,
Volume: k.Volume,
// QuoteVolume: util.MustParseFloat(k.QuoteAssetVolume),
// LastTradeID: 0,
// NumberOfTrades: k.TradeNum,
Closed: k.Closed,
}
}
func (s *PublicService) KLines(symbol string, resolution string, start time.Time, limit int) ([]KLine, error) {
queries := url.Values{}
queries.Set("market", symbol)
interval, err := parseResolution(resolution)
if err != nil {
return nil, err
}
queries.Set("period", strconv.Itoa(int(interval)))
nilTime := time.Time{}
if start != nilTime {
queries.Set("timestamp", strconv.FormatInt(start.Unix(), 64))
}
if limit > 0 {
queries.Set("limit", strconv.Itoa(limit)) // default to 30, max limit = 10,000
}
req, err := s.client.newRequest("GET", fmt.Sprintf("%s/k", s.client.BaseURL), queries, nil)
if err != nil {
return nil, fmt.Errorf("request build error: %s", err.Error())
}
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %s", err.Error())
}
defer func() {
if err := resp.Body.Close(); err != nil {
logger.WithError(err).Error("failed to close resp body")
}
}()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return parseKLines(body, symbol, resolution, interval)
}
func parseKLines(payload []byte, symbol, resolution string, interval Interval) (klines []KLine, err error) {
var parser fastjson.Parser
v, err := parser.ParseBytes(payload)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse payload: %s", payload)
}
arr, err := v.Array()
if err != nil {
return nil, errors.Wrapf(err, "failed to get array: %s", payload)
}
for _, x := range arr {
slice, err := x.Array()
if err != nil {
return nil, errors.Wrapf(err, "failed to get array: %s", payload)
}
if len(slice) < 6 {
return nil, fmt.Errorf("unexpected length of ohlc elements: %s", payload)
}
ts, err := slice[0].Int64()
if err != nil {
return nil, fmt.Errorf("failed to parse timestamp: %s", payload)
}
startTime := time.Unix(ts, 0).UTC()
endTime := time.Unix(ts, 0).Add(time.Duration(interval)*time.Minute - time.Millisecond).UTC()
isClosed := time.Now().Before(endTime)
klines = append(klines, KLine{
Symbol: symbol,
Interval: resolution,
StartTime: startTime,
EndTime: endTime,
Open: slice[1].GetFloat64(),
High: slice[2].GetFloat64(),
Low: slice[3].GetFloat64(),
Close: slice[4].GetFloat64(),
Volume: slice[5].GetFloat64(),
Closed: isClosed,
})
}
return klines, nil
}

View File

@ -12,6 +12,7 @@ import (
"math"
"net/http"
"net/url"
"reflect"
"regexp"
"strconv"
"sync/atomic"
@ -141,6 +142,8 @@ func (c *RestClient) initNonce() {
// 1 is for the request count mod 0.000 to 0.999
timeOffset = serverTimestamp - clientTime.Unix() - 1
logger.Infof("loaded max server timestamp: %d offset=%d", serverTimestamp, timeOffset)
}
func (c *RestClient) getNonce() int64 {
@ -166,6 +169,7 @@ func (c *RestClient) newRequest(method string, refURL string, params url.Values,
return nil, err
}
req.Header.Add("User-Agent", UserAgent)
return req, nil
}
@ -191,13 +195,16 @@ func (c *RestClient) newAuthenticatedRequest(m string, refURL string, data inter
p, err = json.Marshal(payload)
case PrivateRequestParams:
d.Nonce = c.getNonce()
d.Path = c.BaseURL.ResolveReference(rel).Path
default:
return nil, errors.New("unsupported payload type")
params, err := getPrivateRequestParamsObject(data)
if err != nil {
return nil, errors.Wrapf(err, "unsupported payload type: %T", d)
}
params.Nonce = c.getNonce()
params.Path = c.BaseURL.ResolveReference(rel).Path
p, err = json.Marshal(d)
}
if err != nil {
@ -228,6 +235,39 @@ func (c *RestClient) newAuthenticatedRequest(m string, refURL string, data inter
return req, nil
}
func getPrivateRequestParamsObject(v interface{}) (*PrivateRequestParams, error) {
vt := reflect.ValueOf(v)
if vt.Kind() == reflect.Ptr {
vt = vt.Elem()
}
if vt.Kind() != reflect.Struct {
return nil, errors.New("reflect error: given object is not a struct" + vt.Kind().String())
}
if !vt.CanSet() {
return nil, errors.New("reflect error: can not set object")
}
field := vt.FieldByName("PrivateRequestParams")
if !field.IsValid() {
return nil, errors.New("reflect error: field PrivateRequestParams not found")
}
if field.IsNil() {
field.Set(reflect.ValueOf(&PrivateRequestParams{}))
}
params, ok := field.Interface().(*PrivateRequestParams)
if !ok {
return nil, errors.New("reflect error: failed to cast value to *PrivateRequestParams")
}
return params, nil
}
func signPayload(payload string, secret string) string {
var sig = hmac.New(sha256.New, []byte(secret))
_, err := sig.Write([]byte(payload))

View File

@ -138,7 +138,7 @@ type PrivateRequestParams struct {
}
type PrivateTradeRequestParams struct {
PrivateRequestParams
*PrivateRequestParams
Market string `json:"market"`
@ -202,7 +202,7 @@ func (r *PrivateTradeRequest) OrderBy(orderBy string) *PrivateTradeRequest {
}
func (r *PrivateTradeRequest) Do(ctx context.Context) (trades []Trade, err error) {
req, err := r.client.newAuthenticatedRequest("GET", "v2/trades/my", r.params)
req, err := r.client.newAuthenticatedRequest("GET", "v2/trades/my", &r.params)
if err != nil {
return trades, err
}

View File

@ -2,11 +2,31 @@ package types
import (
"context"
"strings"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
type ExchangeName string
const (
ExchangeMax = ExchangeName("max")
ExchangeBinance = ExchangeName("binance")
)
func ValidExchangeName(a string) (ExchangeName, error) {
switch strings.ToLower(a) {
case "max":
return ExchangeMax, nil
case "binance", "bn":
return ExchangeBinance, nil
}
return "", errors.New("invalid exchange name")
}
type Exchange interface {
PlatformFeeCurrency() string
@ -20,6 +40,8 @@ type Exchange interface {
QueryTrades(ctx context.Context, symbol string, options *TradeQueryOptions) ([]Trade, error)
QueryAveragePrice(ctx context.Context, symbol string) (float64, error)
SubmitOrder(ctx context.Context, order *SubmitOrder) error
}
@ -58,17 +80,17 @@ func (e ExchangeBatchProcessor) BatchQueryKLines(ctx context.Context, symbol, in
return allKLines, err
}
func (e ExchangeBatchProcessor) BatchQueryTrades(ctx context.Context, symbol string, options *TradeQueryOptions) (allTrades []Trade, err error) {
// last 7 days
var startTime = time.Now().Add(-7 * 24 * time.Hour)
if options.StartTime != nil {
startTime = *options.StartTime
}
log.Infof("querying %s trades from %s", symbol, startTime)
var lastTradeID = options.LastTradeID
for {
log.Infof("querying %s trades from %s, limit=%d", symbol, startTime, options.Limit)
trades, err := e.QueryTrades(ctx, symbol, &TradeQueryOptions{
StartTime: &startTime,
Limit: options.Limit,
@ -78,10 +100,18 @@ func (e ExchangeBatchProcessor) BatchQueryTrades(ctx context.Context, symbol str
return allTrades, err
}
if len(trades) == 0 {
break
}
if len(trades) == 1 && trades[0].ID == lastTradeID {
break
}
log.Infof("returned %d trades", len(trades))
startTime = trades[len(trades)-1].Time
for _, t := range trades {
// ignore the first trade if last TradeID is given
if t.ID == lastTradeID {