diff --git a/examples/okex-book/main.go b/examples/okex-book/main.go index 77218d81c..e6e7341cc 100644 --- a/examples/okex-book/main.go +++ b/examples/okex-book/main.go @@ -44,10 +44,13 @@ var rootCmd = &cobra.Command{ return errors.New("empty key, secret or passphrase") } - client := okexapi.NewClient() + client, err := okexapi.NewClient() + if err != nil { + return errors.New("init client error: please check url") + } client.Auth(key, secret, passphrase) - instruments, err := client.PublicDataService.NewGetInstrumentsRequest(). + instruments, err := client.NewGetInstrumentsRequest(). InstrumentType("SPOT").Do(ctx) if err != nil { return err @@ -55,14 +58,14 @@ var rootCmd = &cobra.Command{ log.Infof("instruments: %+v", instruments) - fundingRate, err := client.PublicDataService.NewGetFundingRate().InstrumentID("BTC-USDT-SWAP").Do(ctx) + fundingRate, err := client.NewGetFundingRate().InstrumentID("BTC-USDT-SWAP").Do(ctx) if err != nil { return err } log.Infof("funding rate: %+v", fundingRate) log.Infof("ACCOUNT BALANCES:") - account, err := client.AccountBalances() + account, err := client.AccountBalances(ctx) if err != nil { return err } @@ -70,7 +73,7 @@ var rootCmd = &cobra.Command{ log.Infof("%+v", account) log.Infof("ASSET BALANCES:") - assetBalances, err := client.AssetBalances() + assetBalances, err := client.AssetBalances(ctx) if err != nil { return err } @@ -80,7 +83,7 @@ var rootCmd = &cobra.Command{ } log.Infof("ASSET CURRENCIES:") - currencies, err := client.AssetCurrencies() + currencies, err := client.AssetCurrencies(ctx) if err != nil { return err } @@ -90,7 +93,7 @@ var rootCmd = &cobra.Command{ } log.Infof("MARKET TICKERS:") - tickers, err := client.MarketTickers(okexapi.InstrumentTypeSpot) + tickers, err := client.MarketTickers(ctx, okexapi.InstrumentTypeSpot) if err != nil { return err } @@ -99,7 +102,7 @@ var rootCmd = &cobra.Command{ log.Infof("%T%+v", ticker, ticker) } - ticker, err := client.MarketTicker("ETH-USDT") + ticker, err := client.MarketTicker(ctx, "ETH-USDT") if err != nil { return err } @@ -107,7 +110,7 @@ var rootCmd = &cobra.Command{ log.Infof("%T%+v", ticker, ticker) log.Infof("PLACING ORDER:") - placeResponse, err := client.TradeService.NewPlaceOrderRequest(). + placeResponse, err := client.NewPlaceOrderRequest(). InstrumentID("LTC-USDT"). OrderType(okexapi.OrderTypeLimit). Side(okexapi.SideTypeBuy). @@ -122,7 +125,7 @@ var rootCmd = &cobra.Command{ time.Sleep(time.Second) log.Infof("getting order detail...") - orderDetail, err := client.TradeService.NewGetOrderDetailsRequest(). + orderDetail, err := client.NewGetOrderDetailsRequest(). InstrumentID("LTC-USDT"). OrderID(placeResponse.OrderID). Do(ctx) @@ -132,7 +135,7 @@ var rootCmd = &cobra.Command{ log.Infof("order detail: %+v", orderDetail) - cancelResponse, err := client.TradeService.NewCancelOrderRequest(). + cancelResponse, err := client.NewCancelOrderRequest(). InstrumentID("LTC-USDT"). OrderID(placeResponse.OrderID). Do(ctx) @@ -144,15 +147,15 @@ var rootCmd = &cobra.Command{ time.Sleep(time.Second) log.Infof("BATCH PLACE ORDER:") - batchPlaceReq := client.TradeService.NewBatchPlaceOrderRequest() - batchPlaceReq.Add(client.TradeService.NewPlaceOrderRequest(). + batchPlaceReq := client.NewBatchPlaceOrderRequest() + batchPlaceReq.Add(client.NewPlaceOrderRequest(). InstrumentID("LTC-USDT"). OrderType(okexapi.OrderTypeLimit). Side(okexapi.SideTypeBuy). Price("50.0"). Quantity("0.5")) - batchPlaceReq.Add(client.TradeService.NewPlaceOrderRequest(). + batchPlaceReq.Add(client.NewPlaceOrderRequest(). InstrumentID("LTC-USDT"). OrderType(okexapi.OrderTypeLimit). Side(okexapi.SideTypeBuy). @@ -168,7 +171,7 @@ var rootCmd = &cobra.Command{ time.Sleep(time.Second) log.Infof("getting pending orders...") - pendingOrders, err := client.TradeService.NewGetPendingOrderRequest().Do(ctx) + pendingOrders, err := client.NewGetPendingOrderRequest().Do(ctx) if err != nil { return err } @@ -176,9 +179,9 @@ var rootCmd = &cobra.Command{ log.Infof("pending order: %+v", pendingOrder) } - cancelReq := client.TradeService.NewBatchCancelOrderRequest() + cancelReq := client.NewBatchCancelOrderRequest() for _, resp := range batchPlaceResponse { - cancelReq.Add(client.TradeService.NewCancelOrderRequest(). + cancelReq.Add(client.NewCancelOrderRequest(). InstrumentID("LTC-USDT"). OrderID(resp.OrderID)) } diff --git a/pkg/exchange/factory.go b/pkg/exchange/factory.go index 00d6db7c9..c7f3371ca 100644 --- a/pkg/exchange/factory.go +++ b/pkg/exchange/factory.go @@ -37,7 +37,7 @@ func New(n types.ExchangeName, key, secret, passphrase string) (types.ExchangeMi return max.New(key, secret), nil case types.ExchangeOKEx: - return okex.New(key, secret, passphrase), nil + return okex.New(key, secret, passphrase) case types.ExchangeKucoin: return kucoin.New(key, secret, passphrase), nil diff --git a/pkg/exchange/okex/exchange.go b/pkg/exchange/okex/exchange.go index 700b95837..92e322116 100644 --- a/pkg/exchange/okex/exchange.go +++ b/pkg/exchange/okex/exchange.go @@ -34,8 +34,11 @@ type Exchange struct { client *okexapi.RestClient } -func New(key, secret, passphrase string) *Exchange { - client := okexapi.NewClient() +func New(key, secret, passphrase string) (*Exchange, error) { + client, err := okexapi.NewClient() + if err != nil { + return nil, err + } if len(key) > 0 && len(secret) > 0 { client.Auth(key, secret, passphrase) @@ -46,7 +49,7 @@ func New(key, secret, passphrase string) *Exchange { secret: secret, passphrase: passphrase, client: client, - } + }, nil } func (e *Exchange) Name() types.ExchangeName { @@ -54,7 +57,7 @@ func (e *Exchange) Name() types.ExchangeName { } func (e *Exchange) QueryMarkets(ctx context.Context) (types.MarketMap, error) { - instruments, err := e.client.PublicDataService.NewGetInstrumentsRequest(). + instruments, err := e.client.NewGetInstrumentsRequest(). InstrumentType(okexapi.InstrumentTypeSpot). Do(ctx) @@ -98,7 +101,7 @@ func (e *Exchange) QueryMarkets(ctx context.Context) (types.MarketMap, error) { func (e *Exchange) QueryTicker(ctx context.Context, symbol string) (*types.Ticker, error) { symbol = toLocalSymbol(symbol) - marketTicker, err := e.client.MarketTicker(symbol) + marketTicker, err := e.client.MarketTicker(ctx, symbol) if err != nil { return nil, err } @@ -107,7 +110,7 @@ func (e *Exchange) QueryTicker(ctx context.Context, symbol string) (*types.Ticke } func (e *Exchange) QueryTickers(ctx context.Context, symbols ...string) (map[string]types.Ticker, error) { - marketTickers, err := e.client.MarketTickers(okexapi.InstrumentTypeSpot) + marketTickers, err := e.client.MarketTickers(ctx, okexapi.InstrumentTypeSpot) if err != nil { return nil, err } @@ -138,7 +141,7 @@ func (e *Exchange) PlatformFeeCurrency() string { } func (e *Exchange) QueryAccount(ctx context.Context) (*types.Account, error) { - accountBalance, err := e.client.AccountBalances() + accountBalance, err := e.client.AccountBalances(ctx) if err != nil { return nil, err } @@ -153,7 +156,7 @@ func (e *Exchange) QueryAccount(ctx context.Context) (*types.Account, error) { } func (e *Exchange) QueryAccountBalances(ctx context.Context) (types.BalanceMap, error) { - accountBalances, err := e.client.AccountBalances() + accountBalances, err := e.client.AccountBalances(ctx) if err != nil { return nil, err } @@ -163,7 +166,7 @@ func (e *Exchange) QueryAccountBalances(ctx context.Context) (types.BalanceMap, } func (e *Exchange) SubmitOrder(ctx context.Context, order types.SubmitOrder) (*types.Order, error) { - orderReq := e.client.TradeService.NewPlaceOrderRequest() + orderReq := e.client.NewPlaceOrderRequest() orderType, err := toLocalOrderType(order.Type) if err != nil { @@ -257,7 +260,7 @@ func (e *Exchange) SubmitOrder(ctx context.Context, order types.SubmitOrder) (*t func (e *Exchange) QueryOpenOrders(ctx context.Context, symbol string) (orders []types.Order, err error) { instrumentID := toLocalSymbol(symbol) - req := e.client.TradeService.NewGetPendingOrderRequest().InstrumentType(okexapi.InstrumentTypeSpot).InstrumentID(instrumentID) + req := e.client.NewGetPendingOrderRequest().InstrumentType(okexapi.InstrumentTypeSpot).InstrumentID(instrumentID) orderDetails, err := req.Do(ctx) if err != nil { return orders, err @@ -278,7 +281,7 @@ func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) erro return ErrSymbolRequired } - req := e.client.TradeService.NewCancelOrderRequest() + req := e.client.NewCancelOrderRequest() req.InstrumentID(toLocalSymbol(order.Symbol)) req.OrderID(strconv.FormatUint(order.OrderID, 10)) if len(order.ClientOrderID) > 0 { @@ -287,7 +290,7 @@ func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) erro reqs = append(reqs, req) } - batchReq := e.client.TradeService.NewBatchCancelOrderRequest() + batchReq := e.client.NewBatchCancelOrderRequest() batchReq.Add(reqs...) _, err := batchReq.Do(ctx) return err @@ -304,7 +307,7 @@ func (e *Exchange) QueryKLines(ctx context.Context, symbol string, interval type intervalParam := toLocalInterval(interval.String()) - req := e.client.MarketDataService.NewCandlesticksRequest(toLocalSymbol(symbol)) + req := e.client.NewCandlesticksRequest(toLocalSymbol(symbol)) req.Bar(intervalParam) if options.StartTime != nil { @@ -349,7 +352,7 @@ func (e *Exchange) QueryOrder(ctx context.Context, q types.OrderQuery) (*types.O if len(q.OrderID) == 0 && len(q.ClientOrderID) == 0 { return nil, errors.New("okex.QueryOrder: OrderId or ClientOrderId is required parameter") } - req := e.client.TradeService.NewGetOrderDetailsRequest() + req := e.client.NewGetOrderDetailsRequest() req.InstrumentID(q.Symbol). OrderID(q.OrderID). ClientOrderID(q.ClientOrderID) diff --git a/pkg/exchange/okex/okexapi/client.go b/pkg/exchange/okex/okexapi/client.go index ccb81f70a..5d0533051 100644 --- a/pkg/exchange/okex/okexapi/client.go +++ b/pkg/exchange/okex/okexapi/client.go @@ -2,6 +2,7 @@ package okexapi import ( "bytes" + "context" "crypto/hmac" "crypto/sha256" "encoding/base64" @@ -14,7 +15,7 @@ import ( "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/types" - "github.com/c9s/bbgo/pkg/util" + "github.com/c9s/requestgen" "github.com/pkg/errors" ) @@ -60,34 +61,26 @@ const ( ) type RestClient struct { - BaseURL *url.URL - - client *http.Client + requestgen.BaseAPIClient Key, Secret, Passphrase string - - TradeService *TradeService - PublicDataService *PublicDataService - MarketDataService *MarketDataService } -func NewClient() *RestClient { +func NewClient() (*RestClient, error) { u, err := url.Parse(RestBaseURL) if err != nil { - panic(err) + return nil, err } client := &RestClient{ - BaseURL: u, - client: &http.Client{ - Timeout: defaultHTTPTimeout, + BaseAPIClient: requestgen.BaseAPIClient{ + BaseURL: u, + HttpClient: &http.Client{ + Timeout: defaultHTTPTimeout, + }, }, } - - client.TradeService = &TradeService{client: client} - client.PublicDataService = &PublicDataService{client: client} - client.MarketDataService = &MarketDataService{client: client} - return client + return client, nil } func (c *RestClient) Auth(key, secret, passphrase string) { @@ -97,44 +90,8 @@ func (c *RestClient) Auth(key, secret, passphrase string) { c.Passphrase = passphrase } -// NewRequest create new API request. Relative url can be provided in refURL. -func (c *RestClient) newRequest(method, refURL string, params url.Values, body []byte) (*http.Request, error) { - rel, err := url.Parse(refURL) - if err != nil { - return nil, err - } - - if params != nil { - rel.RawQuery = params.Encode() - } - - pathURL := c.BaseURL.ResolveReference(rel) - return http.NewRequest(method, pathURL.String(), bytes.NewReader(body)) -} - -// sendRequest sends the request to the API server and handle the response -func (c *RestClient) sendRequest(req *http.Request) (*util.Response, error) { - resp, err := c.client.Do(req) - if err != nil { - return nil, err - } - - // newResponse reads the response body and return a new Response object - response, err := util.NewResponse(resp) - if err != nil { - return response, err - } - - // Check error, if there is an error, return the ErrorResponse struct type - if response.IsError() { - return response, errors.New(string(response.Body)) - } - - return response, nil -} - -// newAuthenticatedRequest creates new http request for authenticated routes. -func (c *RestClient) newAuthenticatedRequest(method, refURL string, params url.Values, payload interface{}) (*http.Request, error) { +// NewAuthenticatedRequest creates new http request for authenticated routes. +func (c *RestClient) NewAuthenticatedRequest(ctx context.Context, method, refURL string, params url.Values, payload interface{}) (*http.Request, error) { if len(c.Key) == 0 { return nil, errors.New("empty api key") } @@ -215,13 +172,13 @@ type Account struct { Details []BalanceDetail `json:"details"` } -func (c *RestClient) AccountBalances() (*Account, error) { - req, err := c.newAuthenticatedRequest("GET", "/api/v5/account/balance", nil, nil) +func (c *RestClient) AccountBalances(ctx context.Context) (*Account, error) { + req, err := c.NewAuthenticatedRequest(ctx, "GET", "/api/v5/account/balance", nil, nil) if err != nil { return nil, err } - response, err := c.sendRequest(req) + response, err := c.SendRequest(req) if err != nil { return nil, err } @@ -252,13 +209,13 @@ type AssetBalance struct { type AssetBalanceList []AssetBalance -func (c *RestClient) AssetBalances() (AssetBalanceList, error) { - req, err := c.newAuthenticatedRequest("GET", "/api/v5/asset/balances", nil, nil) +func (c *RestClient) AssetBalances(ctx context.Context) (AssetBalanceList, error) { + req, err := c.NewAuthenticatedRequest(ctx, "GET", "/api/v5/asset/balances", nil, nil) if err != nil { return nil, err } - response, err := c.sendRequest(req) + response, err := c.SendRequest(req) if err != nil { return nil, err } @@ -287,13 +244,13 @@ type AssetCurrency struct { MinWithdrawalThreshold fixedpoint.Value `json:"minWd"` } -func (c *RestClient) AssetCurrencies() ([]AssetCurrency, error) { - req, err := c.newAuthenticatedRequest("GET", "/api/v5/asset/currencies", nil, nil) +func (c *RestClient) AssetCurrencies(ctx context.Context) ([]AssetCurrency, error) { + req, err := c.NewAuthenticatedRequest(ctx, "GET", "/api/v5/asset/currencies", nil, nil) if err != nil { return nil, err } - response, err := c.sendRequest(req) + response, err := c.SendRequest(req) if err != nil { return nil, err } @@ -337,17 +294,17 @@ type MarketTicker struct { Timestamp types.MillisecondTimestamp `json:"ts"` } -func (c *RestClient) MarketTicker(instId string) (*MarketTicker, error) { +func (c *RestClient) MarketTicker(ctx context.Context, instId string) (*MarketTicker, error) { // SPOT, SWAP, FUTURES, OPTION var params = url.Values{} params.Add("instId", instId) - req, err := c.newRequest("GET", "/api/v5/market/ticker", params, nil) + req, err := c.NewRequest(ctx, "GET", "/api/v5/market/ticker", params, nil) if err != nil { return nil, err } - response, err := c.sendRequest(req) + response, err := c.SendRequest(req) if err != nil { return nil, err } @@ -368,17 +325,17 @@ func (c *RestClient) MarketTicker(instId string) (*MarketTicker, error) { return &tickerResponse.Data[0], nil } -func (c *RestClient) MarketTickers(instType InstrumentType) ([]MarketTicker, error) { +func (c *RestClient) MarketTickers(ctx context.Context, instType InstrumentType) ([]MarketTicker, error) { // SPOT, SWAP, FUTURES, OPTION var params = url.Values{} params.Add("instType", string(instType)) - req, err := c.newRequest("GET", "/api/v5/market/tickers", params, nil) + req, err := c.NewRequest(ctx, "GET", "/api/v5/market/tickers", params, nil) if err != nil { return nil, err } - response, err := c.sendRequest(req) + response, err := c.SendRequest(req) if err != nil { return nil, err } @@ -405,3 +362,9 @@ func Sign(payload string, secret string) string { return base64.StdEncoding.EncodeToString(sig.Sum(nil)) // return hex.EncodeToString(sig.Sum(nil)) } + +type APIResponse struct { + Code string `json:"code"` + Message string `json:"msg"` + Data json.RawMessage `json:"data"` +} diff --git a/pkg/exchange/okex/okexapi/client_test.go b/pkg/exchange/okex/okexapi/client_test.go index ec6841642..8169450a9 100644 --- a/pkg/exchange/okex/okexapi/client_test.go +++ b/pkg/exchange/okex/okexapi/client_test.go @@ -22,17 +22,17 @@ func getTestClientOrSkip(t *testing.T) *RestClient { return nil } - client := NewClient() + client, err := NewClient() + assert.NoError(t, err) client.Auth(key, secret, passphrase) return client } func TestClient_GetInstrumentsRequest(t *testing.T) { - client := NewClient() + client, err := NewClient() + assert.NoError(t, err) ctx := context.Background() - - srv := &PublicDataService{client: client} - req := srv.NewGetInstrumentsRequest() + req := client.NewGetInstrumentsRequest() instruments, err := req. InstrumentType(InstrumentTypeSpot). @@ -43,10 +43,10 @@ func TestClient_GetInstrumentsRequest(t *testing.T) { } func TestClient_GetFundingRateRequest(t *testing.T) { - client := NewClient() + client, err := NewClient() + assert.NoError(t, err) ctx := context.Background() - srv := &PublicDataService{client: client} - req := srv.NewGetFundingRate() + req := client.NewGetFundingRate() instrument, err := req. InstrumentID("BTC-USDT-SWAP"). @@ -59,8 +59,7 @@ func TestClient_GetFundingRateRequest(t *testing.T) { func TestClient_PlaceOrderRequest(t *testing.T) { client := getTestClientOrSkip(t) ctx := context.Background() - srv := &TradeService{client: client} - req := srv.NewPlaceOrderRequest() + req := client.NewPlaceOrderRequest() order, err := req. InstrumentID("BTC-USDT"). @@ -78,8 +77,7 @@ func TestClient_PlaceOrderRequest(t *testing.T) { func TestClient_GetPendingOrderRequest(t *testing.T) { client := getTestClientOrSkip(t) ctx := context.Background() - srv := &TradeService{client: client} - req := srv.NewGetPendingOrderRequest() + req := client.NewGetPendingOrderRequest() odr_type := []string{string(OrderTypeLimit), string(OrderTypeIOC)} pending_order, err := req. @@ -94,8 +92,7 @@ func TestClient_GetPendingOrderRequest(t *testing.T) { func TestClient_GetOrderDetailsRequest(t *testing.T) { client := getTestClientOrSkip(t) ctx := context.Background() - srv := &TradeService{client: client} - req := srv.NewGetOrderDetailsRequest() + req := client.NewGetOrderDetailsRequest() orderDetail, err := req. InstrumentID("BTC-USDT"). diff --git a/pkg/exchange/okex/okexapi/market.go b/pkg/exchange/okex/okexapi/market.go index b9b46c43f..9c5b7d0db 100644 --- a/pkg/exchange/okex/okexapi/market.go +++ b/pkg/exchange/okex/okexapi/market.go @@ -2,6 +2,7 @@ package okexapi import ( "context" + "encoding/json" "fmt" "net/url" "strconv" @@ -82,29 +83,29 @@ func (r *CandlesticksRequest) Do(ctx context.Context) ([]Candle, error) { params.Add("limit", strconv.Itoa(*r.limit)) } - req, err := r.client.newRequest("GET", "/api/v5/market/candles", params, nil) + req, err := r.client.NewRequest(ctx, "GET", "/api/v5/market/candles", params, nil) if err != nil { return nil, err } - resp, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } type candleEntry [7]string - var candlesResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []candleEntry `json:"data"` - } - if err := resp.DecodeJSON(&candlesResponse); err != nil { + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err + } + var data []candleEntry + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } var candles []Candle - for _, entry := range candlesResponse.Data { + for _, entry := range data { timestamp, err := strconv.ParseInt(entry[0], 10, 64) if err != nil { return candles, err @@ -177,27 +178,26 @@ func (r *MarketTickersRequest) Do(ctx context.Context) ([]MarketTicker, error) { var params = url.Values{} params.Add("instType", string(r.instType)) - req, err := r.client.newRequest("GET", "/api/v5/market/tickers", params, nil) + req, err := r.client.NewRequest(ctx, "GET", "/api/v5/market/tickers", params, nil) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var tickerResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []MarketTicker `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - - if err := response.DecodeJSON(&tickerResponse); err != nil { + var data []MarketTicker + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - return tickerResponse.Data, nil + return data, nil } type MarketTickerRequest struct { @@ -216,53 +216,49 @@ func (r *MarketTickerRequest) Do(ctx context.Context) (*MarketTicker, error) { var params = url.Values{} params.Add("instId", r.instId) - req, err := r.client.newRequest("GET", "/api/v5/market/ticker", params, nil) + req, err := r.client.NewRequest(ctx, "GET", "/api/v5/market/ticker", params, nil) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var tickerResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []MarketTicker `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&tickerResponse); err != nil { + var data []MarketTicker + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - if len(tickerResponse.Data) == 0 { + if len(data) == 0 { return nil, fmt.Errorf("ticker of %s not found", r.instId) } - return &tickerResponse.Data[0], nil + return &data[0], nil } -type MarketDataService struct { - client *RestClient -} - -func (c *MarketDataService) NewMarketTickerRequest(instId string) *MarketTickerRequest { +func (c *RestClient) NewMarketTickerRequest(instId string) *MarketTickerRequest { return &MarketTickerRequest{ - client: c.client, + client: c, instId: instId, } } -func (c *MarketDataService) NewMarketTickersRequest(instType string) *MarketTickersRequest { +func (c *RestClient) NewMarketTickersRequest(instType string) *MarketTickersRequest { return &MarketTickersRequest{ - client: c.client, + client: c, instType: instType, } } -func (c *MarketDataService) NewCandlesticksRequest(instId string) *CandlesticksRequest { +func (c *RestClient) NewCandlesticksRequest(instId string) *CandlesticksRequest { return &CandlesticksRequest{ - client: c.client, + client: c, instId: instId, } } diff --git a/pkg/exchange/okex/okexapi/public.go b/pkg/exchange/okex/okexapi/public.go index b877fe637..af4b45496 100644 --- a/pkg/exchange/okex/okexapi/public.go +++ b/pkg/exchange/okex/okexapi/public.go @@ -2,6 +2,7 @@ package okexapi import ( "context" + "encoding/json" "net/url" "github.com/c9s/bbgo/pkg/fixedpoint" @@ -9,19 +10,15 @@ import ( "github.com/pkg/errors" ) -type PublicDataService struct { - client *RestClient -} - -func (s *PublicDataService) NewGetInstrumentsRequest() *GetInstrumentsRequest { +func (s *RestClient) NewGetInstrumentsRequest() *GetInstrumentsRequest { return &GetInstrumentsRequest{ - client: s.client, + client: s, } } -func (s *PublicDataService) NewGetFundingRate() *GetFundingRateRequest { +func (s *RestClient) NewGetFundingRate() *GetFundingRateRequest { return &GetFundingRateRequest{ - client: s.client, + client: s, } } @@ -49,30 +46,30 @@ func (r *GetFundingRateRequest) Do(ctx context.Context) (*FundingRate, error) { var params = url.Values{} params.Add("instId", string(r.instId)) - req, err := r.client.newRequest("GET", "/api/v5/public/funding-rate", params, nil) + req, err := r.client.NewRequest(ctx, "GET", "/api/v5/public/funding-rate", params, nil) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var apiResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []FundingRate `json:"data"` - } + var apiResponse APIResponse if err := response.DecodeJSON(&apiResponse); err != nil { return nil, err } + var data []FundingRate + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { + return nil, err + } - if len(apiResponse.Data) == 0 { + if len(data) == 0 { return nil, errors.New("empty funding rate data") } - return &apiResponse.Data[0], nil + return &data[0], nil } type Instrument struct { @@ -123,24 +120,24 @@ func (r *GetInstrumentsRequest) Do(ctx context.Context) ([]Instrument, error) { params.Add("instId", *r.instId) } - req, err := r.client.newRequest("GET", "/api/v5/public/instruments", params, nil) + req, err := r.client.NewRequest(ctx, "GET", "/api/v5/public/instruments", params, nil) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var apiResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []Instrument `json:"data"` - } + var apiResponse APIResponse if err := response.DecodeJSON(&apiResponse); err != nil { return nil, err } + var data []Instrument + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { + return nil, err + } - return apiResponse.Data, nil + return data, nil } diff --git a/pkg/exchange/okex/okexapi/trade.go b/pkg/exchange/okex/okexapi/trade.go index 8a814057b..4c0ecf26a 100644 --- a/pkg/exchange/okex/okexapi/trade.go +++ b/pkg/exchange/okex/okexapi/trade.go @@ -2,6 +2,7 @@ package okexapi import ( "context" + "encoding/json" "net/url" "strings" @@ -10,10 +11,6 @@ import ( "github.com/pkg/errors" ) -type TradeService struct { - client *RestClient -} - type OrderResponse struct { OrderID string `json:"ordId"` ClientOrderID string `json:"clOrdId"` @@ -22,45 +19,45 @@ type OrderResponse struct { Message string `json:"sMsg"` } -func (c *TradeService) NewPlaceOrderRequest() *PlaceOrderRequest { +func (c *RestClient) NewPlaceOrderRequest() *PlaceOrderRequest { return &PlaceOrderRequest{ - client: c.client, + client: c, } } -func (c *TradeService) NewBatchPlaceOrderRequest() *BatchPlaceOrderRequest { +func (c *RestClient) NewBatchPlaceOrderRequest() *BatchPlaceOrderRequest { return &BatchPlaceOrderRequest{ - client: c.client, + client: c, } } -func (c *TradeService) NewCancelOrderRequest() *CancelOrderRequest { +func (c *RestClient) NewCancelOrderRequest() *CancelOrderRequest { return &CancelOrderRequest{ - client: c.client, + client: c, } } -func (c *TradeService) NewBatchCancelOrderRequest() *BatchCancelOrderRequest { +func (c *RestClient) NewBatchCancelOrderRequest() *BatchCancelOrderRequest { return &BatchCancelOrderRequest{ - client: c.client, + client: c, } } -func (c *TradeService) NewGetOrderDetailsRequest() *GetOrderDetailsRequest { +func (c *RestClient) NewGetOrderDetailsRequest() *GetOrderDetailsRequest { return &GetOrderDetailsRequest{ - client: c.client, + client: c, } } -func (c *TradeService) NewGetPendingOrderRequest() *GetPendingOrderRequest { +func (c *RestClient) NewGetPendingOrderRequest() *GetPendingOrderRequest { return &GetPendingOrderRequest{ - client: c.client, + client: c, } } -func (c *TradeService) NewGetTransactionDetailsRequest() *GetTransactionDetailsRequest { +func (c *RestClient) NewGetTransactionDetailsRequest() *GetTransactionDetailsRequest { return &GetTransactionDetailsRequest{ - client: c.client, + client: c, } } @@ -99,30 +96,30 @@ func (r *PlaceOrderRequest) Parameters() map[string]interface{} { func (r *PlaceOrderRequest) Do(ctx context.Context) (*OrderResponse, error) { payload := r.Parameters() - req, err := r.client.newAuthenticatedRequest("POST", "/api/v5/trade/order", nil, payload) + req, err := r.client.NewAuthenticatedRequest(ctx, "POST", "/api/v5/trade/order", nil, payload) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var orderResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []OrderResponse `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&orderResponse); err != nil { + var data []OrderResponse + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - if len(orderResponse.Data) == 0 { + if len(data) == 0 { return nil, errors.New("order create error") } - return &orderResponse.Data[0], nil + return &data[0], nil } //go:generate requestgen -type CancelOrderRequest @@ -149,26 +146,26 @@ func (r *CancelOrderRequest) Do(ctx context.Context) ([]OrderResponse, error) { return nil, errors.New("either orderID or clientOrderID is required for canceling order") } - req, err := r.client.newAuthenticatedRequest("POST", "/api/v5/trade/cancel-order", nil, payload) + req, err := r.client.NewAuthenticatedRequest(ctx, "POST", "/api/v5/trade/cancel-order", nil, payload) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var orderResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []OrderResponse `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&orderResponse); err != nil { + var data []OrderResponse + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - return orderResponse.Data, nil + return data, nil } type BatchCancelOrderRequest struct { @@ -190,26 +187,26 @@ func (r *BatchCancelOrderRequest) Do(ctx context.Context) ([]OrderResponse, erro parameterList = append(parameterList, params) } - req, err := r.client.newAuthenticatedRequest("POST", "/api/v5/trade/cancel-batch-orders", nil, parameterList) + req, err := r.client.NewAuthenticatedRequest(ctx, "POST", "/api/v5/trade/cancel-batch-orders", nil, parameterList) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var orderResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []OrderResponse `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&orderResponse); err != nil { + var data []OrderResponse + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - return orderResponse.Data, nil + return data, nil } type BatchPlaceOrderRequest struct { @@ -231,26 +228,26 @@ func (r *BatchPlaceOrderRequest) Do(ctx context.Context) ([]OrderResponse, error parameterList = append(parameterList, params) } - req, err := r.client.newAuthenticatedRequest("POST", "/api/v5/trade/batch-orders", nil, parameterList) + req, err := r.client.NewAuthenticatedRequest(ctx, "POST", "/api/v5/trade/batch-orders", nil, parameterList) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var orderResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []OrderResponse `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&orderResponse); err != nil { + var data []OrderResponse + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - return orderResponse.Data, nil + return data, nil } type OrderDetails struct { @@ -343,30 +340,30 @@ func (r *GetOrderDetailsRequest) QueryParameters() url.Values { func (r *GetOrderDetailsRequest) Do(ctx context.Context) (*OrderDetails, error) { params := r.QueryParameters() - req, err := r.client.newAuthenticatedRequest("GET", "/api/v5/trade/order", params, nil) + req, err := r.client.NewAuthenticatedRequest(ctx, "GET", "/api/v5/trade/order", params, nil) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var orderResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []OrderDetails `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&orderResponse); err != nil { + var data []OrderDetails + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - if len(orderResponse.Data) == 0 { + if len(data) == 0 { return nil, errors.New("get order details error") } - return &orderResponse.Data[0], nil + return &data[0], nil } type GetPendingOrderRequest struct { @@ -430,26 +427,26 @@ func (r *GetPendingOrderRequest) Parameters() map[string]interface{} { func (r *GetPendingOrderRequest) Do(ctx context.Context) ([]OrderDetails, error) { payload := r.Parameters() - req, err := r.client.newAuthenticatedRequest("GET", "/api/v5/trade/orders-pending", nil, payload) + req, err := r.client.NewAuthenticatedRequest(ctx, "GET", "/api/v5/trade/orders-pending", nil, payload) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var orderResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []OrderDetails `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&orderResponse); err != nil { + var data []OrderDetails + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - return orderResponse.Data, nil + return data, nil } type GetTransactionDetailsRequest struct { @@ -497,24 +494,24 @@ func (r *GetTransactionDetailsRequest) Parameters() map[string]interface{} { func (r *GetTransactionDetailsRequest) Do(ctx context.Context) ([]OrderDetails, error) { payload := r.Parameters() - req, err := r.client.newAuthenticatedRequest("GET", "/api/v5/trade/fills", nil, payload) + req, err := r.client.NewAuthenticatedRequest(ctx, "GET", "/api/v5/trade/fills", nil, payload) if err != nil { return nil, err } - response, err := r.client.sendRequest(req) + response, err := r.client.SendRequest(req) if err != nil { return nil, err } - var orderResponse struct { - Code string `json:"code"` - Message string `json:"msg"` - Data []OrderDetails `json:"data"` + var apiResponse APIResponse + if err := response.DecodeJSON(&apiResponse); err != nil { + return nil, err } - if err := response.DecodeJSON(&orderResponse); err != nil { + var data []OrderDetails + if err := json.Unmarshal(apiResponse.Data, &data); err != nil { return nil, err } - return orderResponse.Data, nil + return data, nil } diff --git a/pkg/exchange/okex/query_order_test.go b/pkg/exchange/okex/query_order_test.go index 3c32da40c..025e045ba 100644 --- a/pkg/exchange/okex/query_order_test.go +++ b/pkg/exchange/okex/query_order_test.go @@ -22,7 +22,8 @@ func Test_QueryOrder(t *testing.T) { return } - e := New(key, secret, passphrase) + e, err := New(key, secret, passphrase) + assert.NoError(t, err) queryOrder := types.OrderQuery{ Symbol: "BTC-USDT",