Fix test errors

This commit is contained in:
なるみ 2022-04-09 17:24:34 +08:00
parent b3091b9462
commit b85a13c35b
6 changed files with 58 additions and 45 deletions

View File

@ -1,2 +1,3 @@
from .bbgo import BBGO
from .bbgo import MarketService
from .bbgo import TradingService
from .stream import Stream

View File

@ -1,13 +1,39 @@
from typing import List
from . import bbgo_pb2
from .utils import create_stub
from . import bbgo_pb2_grpc
class BBGO(object):
class MarketService(object):
def __init__(self, host: str, port: int):
self.stub = create_stub(host, port)
def __init__(self, stub: bbgo_pb2_grpc.MarketDataServiceStub):
self.stub = stub
def subscribe(self, subscriptions: List[bbgo_pb2.Subscription]):
request = bbgo_pb2.SubscribeRequest(subscriptions=subscriptions)
request_iter = self.stub.Subscribe(request)
return request_iter
def query_klines(self,
exchange: str,
symbol: str,
limit: int = 30,
interval: int = 1,
timestamp: int = None) -> bbgo_pb2.QueryKLinesResponse:
request = bbgo_pb2.QueryKLinesRequest(exchange=exchange,
symbol=symbol,
limit=limit,
interval=interval,
timestamp=timestamp)
response = self.market_data_stub.QueryKLines(request)
return response
class TradingService(object):
def __init__(self, stub: bbgo_pb2_grpc.TradingServiceStub):
self.stub = stub
def submit_order(self,
exchange: str,
@ -87,18 +113,3 @@ class BBGO(object):
offset=offset)
response = self.stub.QueryTrades(request)
return response
def query_klines(self,
exchange: str,
symbol: str,
limit: int = 30,
interval: int = 1,
timestamp: int = None) -> bbgo_pb2.QueryKLinesResponse:
request = bbgo_pb2.QueryKLinesRequest(exchange=exchange,
symbol=symbol,
limit=limit,
interval=interval,
timestamp=timestamp)
response = self.stub.QueryKLines(request)
return response

View File

@ -34,18 +34,18 @@ class Stream(object):
async def subscribe(self):
async with grpc.aio.insecure_channel(self.address) as channel:
stub = bbgo_pb2_grpc.BBGOStub(channel)
stub = bbgo_pb2_grpc.MarketDataServiceStub(channel)
request = bbgo_pb2.SubscribeRequest(subscriptions=self.subscriptions)
async for response in stub.Subcribe(request):
async for response in stub.Subscribe(request):
self.dispatch(response)
async def subscribe_user_data(self):
async with grpc.aio.insecure_channel(self.address) as channel:
stub = bbgo_pb2_grpc.BBGOStub(channel)
stub = bbgo_pb2_grpc.UserDataServiceStub(channel)
request = bbgo_pb2.Empty()
async for response in stub.SubcribeUserData(request):
async for response in stub.SubscribeUserData(request):
self.dispatch_user_events(response)
def start(self):

View File

@ -29,9 +29,3 @@ def get_credentials_from_env():
private_key_certificate_chain_pairs = [(private_key, certificate_chain)]
server_credentials = grpc.ssl_server_credentials(private_key_certificate_chain_pairs)
return server_credentials
def create_stub(host, port):
address = f'{host}:{port}'
channel = grpc.insecure_channel(address)
return bbgo_pb2_grpc.BBGOStub(channel)

View File

@ -5,7 +5,7 @@ from bbgo import bbgo_pb2
from bbgo import bbgo_pb2_grpc
class TestServicer(bbgo_pb2_grpc.BBGOServicer):
class TestTradingServicer(bbgo_pb2_grpc.TradingServiceServicer):
def Subcribe(self, request, context):
i = 0
@ -28,7 +28,7 @@ class TestServicer(bbgo_pb2_grpc.BBGOServicer):
event=bbgo_pb2.Event.ORDER_UPDATE,
exchange='max',
symbol=f'user_{i}',
)
)
i += 1
time.sleep(random.random())

View File

@ -3,46 +3,53 @@ from concurrent import futures
import grpc
import pytest
from bbgo import BBGO
from bbgo import MarketService, TradingService
from bbgo import bbgo_pb2
from bbgo import bbgo_pb2_grpc
from tests.servicer import TestServicer
from tests.servicer import TestTradingServicer
@pytest.fixture
def grpc_address(host='[::]', port=50051):
def address(host='[::]', port=50051):
return f'{host}:{port}'
@pytest.fixture
def bbgo(host='[::]', port=50051):
return BBGO(host, port)
def channel(address):
return grpc.insecure_channel(address)
@pytest.fixture
def grpc_channel(grpc_address):
return grpc.insecure_channel(grpc_address)
def trading_service(channel):
trading_service_stub = bbgo_pb2_grpc.TradingServiceStub(channel)
return TradingService(trading_service_stub)
@pytest.fixture
def grpc_server(grpc_address, max_workers=1):
def market_service(channel):
market_service_stub = bbgo_pb2_grpc.MarketDataServiceStub(channel)
return MarketService(market_service_stub)
@pytest.fixture
def test_trading_servicer(address, max_workers=1):
server = grpc.server(futures.ThreadPoolExecutor(max_workers))
servicer = TestServicer()
bbgo_pb2_grpc.add_BBGOServicer_to_server(servicer, server)
server.add_insecure_port(grpc_address)
servicer = TestTradingServicer()
bbgo_pb2_grpc.add_TradingServiceServicer_to_server(servicer, server)
server.add_insecure_port(address)
server.start()
yield server
server.stop(grace=None)
def test_submit_order(bbgo, grpc_server):
def test_submit_order(trading_service, test_trading_servicer):
exchange = 'max'
symbol = 'BTCUSDT'
side = bbgo_pb2.Side.BUY
quantity = 0.01
order_type = bbgo_pb2.OrderType.LIMIT
response = bbgo.submit_order(
response = trading_service.submit_order(
exchange=exchange,
symbol=symbol,
side=side,