From 7f6a478d4fb2e1de385ce3de8533a9a82e127efc Mon Sep 17 00:00:00 2001 From: c9s Date: Sat, 10 Oct 2020 17:50:49 +0800 Subject: [PATCH] fix max trades query for pnl calulcation --- cmd/bbgo/main.go | 2 + cmd/bbgo/pnl.go | 6 -- cmd/pnl.go | 42 +++++++-- exchange/max/exchange.go | 68 ++++++++++++- exchange/max/maxapi/public.go | 168 +++++++++++++++++++++++++++++++++ exchange/max/maxapi/restapi.go | 50 +++++++++- exchange/max/maxapi/trade.go | 4 +- types/exchange.go | 36 ++++++- 8 files changed, 350 insertions(+), 26 deletions(-) delete mode 100644 cmd/bbgo/pnl.go diff --git a/cmd/bbgo/main.go b/cmd/bbgo/main.go index ee303b2bc..5b4a66841 100644 --- a/cmd/bbgo/main.go +++ b/cmd/bbgo/main.go @@ -2,6 +2,8 @@ package main import ( "github.com/c9s/bbgo/cmd" + + _ "github.com/go-sql-driver/mysql" ) func main() { diff --git a/cmd/bbgo/pnl.go b/cmd/bbgo/pnl.go deleted file mode 100644 index e46cbfd1f..000000000 --- a/cmd/bbgo/pnl.go +++ /dev/null @@ -1,6 +0,0 @@ -package main - -import ( - _ "github.com/go-sql-driver/mysql" -) - diff --git a/cmd/pnl.go b/cmd/pnl.go index dccbb93c1..73ad45478 100644 --- a/cmd/pnl.go +++ b/cmd/pnl.go @@ -13,17 +13,37 @@ import ( "github.com/c9s/bbgo/accounting" "github.com/c9s/bbgo/bbgo" - binance2 "github.com/c9s/bbgo/exchange/binance" + "github.com/c9s/bbgo/exchange/binance" + "github.com/c9s/bbgo/exchange/max" "github.com/c9s/bbgo/service" "github.com/c9s/bbgo/types" ) func init() { + PnLCmd.Flags().String("exchange", "", "target exchange") PnLCmd.Flags().String("symbol", "BTCUSDT", "trading symbol") PnLCmd.Flags().String("since", "", "pnl since time") RootCmd.AddCommand(PnLCmd) } +func newExchangeFromViper(n types.ExchangeName) types.Exchange { + switch n { + + case types.ExchangeBinance: + key := viper.GetString("binance-api-key") + secret := viper.GetString("binance-api-secret") + return binance.New(key, secret) + + case types.ExchangeMax: + key := viper.GetString("max-api-key") + secret := viper.GetString("max-api-secret") + return max.New(key, secret) + + } + + return nil +} + var PnLCmd = &cobra.Command{ Use: "pnl", Short: "pnl calculator", @@ -31,14 +51,22 @@ var PnLCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() + exchangeNameStr, err := cmd.Flags().GetString("exchange") + if err != nil { + return err + } + + exchangeName, err := types.ValidExchangeName(exchangeNameStr) + if err != nil { + return err + } + symbol, err := cmd.Flags().GetString("symbol") if err != nil { return err } - binanceKey := viper.GetString("binance-api-key") - binanceSecret := viper.GetString("binance-api-secret") - binanceExchange := binance2.New(binanceKey, binanceSecret) + exchange := newExchangeFromViper(exchangeName) mysqlURL := viper.GetString("mysql-url") mysqlURL = fmt.Sprintf("%s?parseTime=true", mysqlURL) @@ -69,12 +97,12 @@ var PnLCmd = &cobra.Command{ tradeSync := &service.TradeSync{Service: tradeService} logrus.Info("syncing trades...") - if err := tradeSync.Sync(ctx, binanceExchange, symbol, startTime); err != nil { + if err := tradeSync.Sync(ctx, exchange, symbol, startTime); err != nil { return err } var trades []types.Trade - tradingFeeCurrency := binanceExchange.PlatformFeeCurrency() + tradingFeeCurrency := exchange.PlatformFeeCurrency() if strings.HasPrefix(symbol, tradingFeeCurrency) { logrus.Infof("loading all trading fee currency related trades: %s", symbol) trades, err = tradeService.QueryForTradingFeeCurrency(symbol, tradingFeeCurrency) @@ -101,7 +129,7 @@ var PnLCmd = &cobra.Command{ logrus.Infof("found checkpoints: %+v", checkpoints) logrus.Infof("stock: %f", stockManager.Stocks.Quantity()) - currentPrice, err := binanceExchange.QueryAveragePrice(ctx, symbol) + currentPrice, err := exchange.QueryAveragePrice(ctx, symbol) calculator := &accounting.ProfitAndLossCalculator{ TradingFeeCurrency: tradingFeeCurrency, diff --git a/exchange/max/exchange.go b/exchange/max/exchange.go index bdc97c3e5..c58cdfa89 100644 --- a/exchange/max/exchange.go +++ b/exchange/max/exchange.go @@ -7,6 +7,9 @@ import ( "strings" "time" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + maxapi "github.com/c9s/bbgo/exchange/max/maxapi" "github.com/c9s/bbgo/types" "github.com/c9s/bbgo/util" @@ -101,7 +104,7 @@ func (e *Exchange) QueryAccountBalances(ctx context.Context) (types.BalanceMap, func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *types.TradeQueryOptions) (trades []types.Trade, err error) { req := e.client.TradeService.NewPrivateTradeRequest() - req.Market(symbol) + req.Market(toLocalSymbol(symbol)) if options.Limit > 0 { req.Limit(options.Limit) @@ -111,6 +114,9 @@ func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *type req.From(options.LastTradeID) } + // make it compatible with binance, we need the last trade id for the next page. + req.OrderBy("asc") + remoteTrades, err := req.Do(ctx) if err != nil { return nil, err @@ -119,15 +125,60 @@ func (e *Exchange) QueryTrades(ctx context.Context, symbol string, options *type for _, t := range remoteTrades { localTrade, err := convertRemoteTrade(t) if err != nil { - logger.WithError(err).Errorf("can not convert binance trade: %+v", t) + logger.WithError(err).Errorf("can not convert trade: %+v", t) continue } + + logger.Infof("T: id=%d % 4s %s P=%f Q=%f %s", localTrade.ID, localTrade.Symbol, localTrade.Side, localTrade.Price, localTrade.Quantity, localTrade.Time) + trades = append(trades, *localTrade) } return trades, nil } +func (e *Exchange) QueryKLines(ctx context.Context, symbol, interval string, options types.KLineQueryOptions) ([]types.KLine, error) { + var limit = 5000 + if options.Limit > 0 { + // default limit == 500 + limit = options.Limit + } + + if options.StartTime == nil { + return nil, errors.New("start time can not be empty") + } + + if options.EndTime != nil { + return nil, errors.New("end time is not supported") + } + + log.Infof("querying kline %s %s %v", symbol, interval, options) + + // avoid rate limit + time.Sleep(100 * time.Millisecond) + + localKlines, err := e.client.PublicService.KLines(toLocalSymbol(symbol), interval, *options.StartTime, limit) + if err != nil { + return nil, err + } + + var kLines []types.KLine + for _, k := range localKlines { + kLines = append(kLines, k.KLine()) + } + + return kLines, nil +} + +func (e *Exchange) QueryAveragePrice(ctx context.Context, symbol string) (float64, error) { + ticker, err := e.client.PublicService.Ticker(toLocalSymbol(symbol)) + if err != nil { + return 0, err + } + + return (util.MustParseFloat(ticker.Sell) + util.MustParseFloat(ticker.Buy)) / 2, nil +} + func toGlobalCurrency(currency string) string { return strings.ToUpper(currency) } @@ -136,6 +187,14 @@ func toLocalCurrency(currency string) string { return strings.ToLower(currency) } +func toLocalSymbol(symbol string) string { + return strings.ToLower(symbol) +} + +func toGlobalSymbol(symbol string) string { + return strings.ToLower(symbol) +} + func toLocalSideType(side types.SideType) string { return strings.ToLower(string(side)) } @@ -148,6 +207,9 @@ func toGlobalSideType(v string) string { case "ask": return "SELL" + case "self-trade": + return "SELF" + } return strings.ToUpper(v) @@ -195,7 +257,7 @@ func convertRemoteTrade(t maxapi.Trade) (*types.Trade, error) { return &types.Trade{ ID: int64(t.ID), Price: price, - Symbol: t.Market, + Symbol: toGlobalSymbol(t.Market), Exchange: "max", Quantity: quantity, Side: side, diff --git a/exchange/max/maxapi/public.go b/exchange/max/maxapi/public.go index 8c8c10a99..509535ef3 100644 --- a/exchange/max/maxapi/public.go +++ b/exchange/max/maxapi/public.go @@ -1,10 +1,17 @@ package max import ( + "fmt" + "io/ioutil" "net/url" + "strconv" + "strings" "time" + "github.com/pkg/errors" "github.com/valyala/fastjson" + + "github.com/c9s/bbgo/types" ) type PublicService struct { @@ -140,3 +147,164 @@ func mustParseTicker(v *fastjson.Value) Ticker { Low: string(v.GetStringBytes("low")), } } + +type Interval int64 + +func parseResolution(a string) (Interval, error) { + switch strings.ToLower(a) { + + case "1m": + return 1, nil + + case "5m": + return 5, nil + + case "15m": + return 15, nil + + case "30m": + return 30, nil + + case "1h": + return 60, nil + + case "3h": + return 60 * 3, nil + + case "6h": + return 60 * 6, nil + + case "12h": + return 60 * 12, nil + + case "1d": + return 60 * 24, nil + + case "3d": + return 60 * 24 * 3, nil + + case "1w": + return 60 * 24 * 7, nil + + } + + return 0, errors.New("incorrect resolution") +} + +type KLine struct { + Symbol string + Interval string + StartTime, EndTime time.Time + Open, High, Low, Close float64 + Volume float64 + Closed bool +} + +func (k KLine) KLine() types.KLine { + return types.KLine{ + Symbol: k.Symbol, + Interval: k.Interval, + StartTime: k.StartTime, + EndTime: k.EndTime, + Open: k.Open, + Close: k.Close, + High: k.High, + Low: k.Low, + Volume: k.Volume, + // QuoteVolume: util.MustParseFloat(k.QuoteAssetVolume), + // LastTradeID: 0, + // NumberOfTrades: k.TradeNum, + Closed: k.Closed, + } +} + +func (s *PublicService) KLines(symbol string, resolution string, start time.Time, limit int) ([]KLine, error) { + queries := url.Values{} + queries.Set("market", symbol) + + interval, err := parseResolution(resolution) + if err != nil { + return nil, err + } + queries.Set("period", strconv.Itoa(int(interval))) + + nilTime := time.Time{} + if start != nilTime { + queries.Set("timestamp", strconv.FormatInt(start.Unix(), 64)) + } + + if limit > 0 { + queries.Set("limit", strconv.Itoa(limit)) // default to 30, max limit = 10,000 + } + + req, err := s.client.newRequest("GET", fmt.Sprintf("%s/k", s.client.BaseURL), queries, nil) + if err != nil { + return nil, fmt.Errorf("request build error: %s", err.Error()) + } + + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %s", err.Error()) + } + + defer func() { + if err := resp.Body.Close(); err != nil { + logger.WithError(err).Error("failed to close resp body") + } + }() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return parseKLines(body, symbol, resolution, interval) +} + +func parseKLines(payload []byte, symbol, resolution string, interval Interval) (klines []KLine, err error) { + var parser fastjson.Parser + + v, err := parser.ParseBytes(payload) + if err != nil { + return nil, errors.Wrapf(err, "failed to parse payload: %s", payload) + } + + arr, err := v.Array() + if err != nil { + return nil, errors.Wrapf(err, "failed to get array: %s", payload) + } + + for _, x := range arr { + slice, err := x.Array() + if err != nil { + return nil, errors.Wrapf(err, "failed to get array: %s", payload) + } + + if len(slice) < 6 { + return nil, fmt.Errorf("unexpected length of ohlc elements: %s", payload) + } + + ts, err := slice[0].Int64() + if err != nil { + return nil, fmt.Errorf("failed to parse timestamp: %s", payload) + } + + startTime := time.Unix(ts, 0).UTC() + endTime := time.Unix(ts, 0).Add(time.Duration(interval)*time.Minute - time.Millisecond).UTC() + isClosed := time.Now().Before(endTime) + klines = append(klines, KLine{ + Symbol: symbol, + Interval: resolution, + StartTime: startTime, + EndTime: endTime, + Open: slice[1].GetFloat64(), + High: slice[2].GetFloat64(), + Low: slice[3].GetFloat64(), + Close: slice[4].GetFloat64(), + Volume: slice[5].GetFloat64(), + Closed: isClosed, + }) + } + + return klines, nil +} diff --git a/exchange/max/maxapi/restapi.go b/exchange/max/maxapi/restapi.go index 79b7f56de..15d2dbac5 100644 --- a/exchange/max/maxapi/restapi.go +++ b/exchange/max/maxapi/restapi.go @@ -12,6 +12,7 @@ import ( "math" "net/http" "net/url" + "reflect" "regexp" "strconv" "sync/atomic" @@ -141,6 +142,8 @@ func (c *RestClient) initNonce() { // 1 is for the request count mod 0.000 to 0.999 timeOffset = serverTimestamp - clientTime.Unix() - 1 + + logger.Infof("loaded max server timestamp: %d offset=%d", serverTimestamp, timeOffset) } func (c *RestClient) getNonce() int64 { @@ -166,6 +169,7 @@ func (c *RestClient) newRequest(method string, refURL string, params url.Values, return nil, err } + req.Header.Add("User-Agent", UserAgent) return req, nil } @@ -191,13 +195,16 @@ func (c *RestClient) newAuthenticatedRequest(m string, refURL string, data inter p, err = json.Marshal(payload) - case PrivateRequestParams: - d.Nonce = c.getNonce() - d.Path = c.BaseURL.ResolveReference(rel).Path - default: - return nil, errors.New("unsupported payload type") + params, err := getPrivateRequestParamsObject(data) + if err != nil { + return nil, errors.Wrapf(err, "unsupported payload type: %T", d) + } + params.Nonce = c.getNonce() + params.Path = c.BaseURL.ResolveReference(rel).Path + + p, err = json.Marshal(d) } if err != nil { @@ -228,6 +235,39 @@ func (c *RestClient) newAuthenticatedRequest(m string, refURL string, data inter return req, nil } +func getPrivateRequestParamsObject(v interface{}) (*PrivateRequestParams, error) { + vt := reflect.ValueOf(v) + + if vt.Kind() == reflect.Ptr { + vt = vt.Elem() + } + + + if vt.Kind() != reflect.Struct { + return nil, errors.New("reflect error: given object is not a struct" + vt.Kind().String()) + } + + if !vt.CanSet() { + return nil, errors.New("reflect error: can not set object") + } + + field := vt.FieldByName("PrivateRequestParams") + if !field.IsValid() { + return nil, errors.New("reflect error: field PrivateRequestParams not found") + } + + if field.IsNil() { + field.Set(reflect.ValueOf(&PrivateRequestParams{})) + } + + params, ok := field.Interface().(*PrivateRequestParams) + if !ok { + return nil, errors.New("reflect error: failed to cast value to *PrivateRequestParams") + } + + return params, nil +} + func signPayload(payload string, secret string) string { var sig = hmac.New(sha256.New, []byte(secret)) _, err := sig.Write([]byte(payload)) diff --git a/exchange/max/maxapi/trade.go b/exchange/max/maxapi/trade.go index 1116708a9..569f76bdf 100644 --- a/exchange/max/maxapi/trade.go +++ b/exchange/max/maxapi/trade.go @@ -138,7 +138,7 @@ type PrivateRequestParams struct { } type PrivateTradeRequestParams struct { - PrivateRequestParams + *PrivateRequestParams Market string `json:"market"` @@ -202,7 +202,7 @@ func (r *PrivateTradeRequest) OrderBy(orderBy string) *PrivateTradeRequest { } func (r *PrivateTradeRequest) Do(ctx context.Context) (trades []Trade, err error) { - req, err := r.client.newAuthenticatedRequest("GET", "v2/trades/my", r.params) + req, err := r.client.newAuthenticatedRequest("GET", "v2/trades/my", &r.params) if err != nil { return trades, err } diff --git a/types/exchange.go b/types/exchange.go index dd1ff618e..49d6b21ed 100644 --- a/types/exchange.go +++ b/types/exchange.go @@ -2,11 +2,31 @@ package types import ( "context" + "strings" "time" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) +type ExchangeName string + +const ( + ExchangeMax = ExchangeName("max") + ExchangeBinance = ExchangeName("binance") +) + +func ValidExchangeName(a string) (ExchangeName, error) { + switch strings.ToLower(a) { + case "max": + return ExchangeMax, nil + case "binance", "bn": + return ExchangeBinance, nil + } + + return "", errors.New("invalid exchange name") +} + type Exchange interface { PlatformFeeCurrency() string @@ -20,6 +40,8 @@ type Exchange interface { QueryTrades(ctx context.Context, symbol string, options *TradeQueryOptions) ([]Trade, error) + QueryAveragePrice(ctx context.Context, symbol string) (float64, error) + SubmitOrder(ctx context.Context, order *SubmitOrder) error } @@ -58,17 +80,17 @@ func (e ExchangeBatchProcessor) BatchQueryKLines(ctx context.Context, symbol, in return allKLines, err } - func (e ExchangeBatchProcessor) BatchQueryTrades(ctx context.Context, symbol string, options *TradeQueryOptions) (allTrades []Trade, err error) { + // last 7 days var startTime = time.Now().Add(-7 * 24 * time.Hour) if options.StartTime != nil { startTime = *options.StartTime } - log.Infof("querying %s trades from %s", symbol, startTime) - var lastTradeID = options.LastTradeID for { + log.Infof("querying %s trades from %s, limit=%d", symbol, startTime, options.Limit) + trades, err := e.QueryTrades(ctx, symbol, &TradeQueryOptions{ StartTime: &startTime, Limit: options.Limit, @@ -78,10 +100,18 @@ func (e ExchangeBatchProcessor) BatchQueryTrades(ctx context.Context, symbol str return allTrades, err } + if len(trades) == 0 { + break + } + if len(trades) == 1 && trades[0].ID == lastTradeID { break } + log.Infof("returned %d trades", len(trades)) + + startTime = trades[len(trades)-1].Time + for _, t := range trades { // ignore the first trade if last TradeID is given if t.ID == lastTradeID {