diff --git a/python/bbgo/__init__.py b/python/bbgo/__init__.py index 13cb41ff2..a914a86ca 100644 --- a/python/bbgo/__init__.py +++ b/python/bbgo/__init__.py @@ -1,2 +1,3 @@ -from .bbgo import BBGO +from .bbgo import MarketService +from .bbgo import TradingService from .stream import Stream diff --git a/python/bbgo/bbgo.py b/python/bbgo/bbgo.py index 662601708..a7ffb7bdb 100644 --- a/python/bbgo/bbgo.py +++ b/python/bbgo/bbgo.py @@ -1,13 +1,39 @@ from typing import List from . import bbgo_pb2 -from .utils import create_stub +from . import bbgo_pb2_grpc -class BBGO(object): +class MarketService(object): - def __init__(self, host: str, port: int): - self.stub = create_stub(host, port) + def __init__(self, stub: bbgo_pb2_grpc.MarketDataServiceStub): + self.stub = stub + + def subscribe(self, subscriptions: List[bbgo_pb2.Subscription]): + request = bbgo_pb2.SubscribeRequest(subscriptions=subscriptions) + request_iter = self.stub.Subscribe(request) + return request_iter + + def query_klines(self, + exchange: str, + symbol: str, + limit: int = 30, + interval: int = 1, + timestamp: int = None) -> bbgo_pb2.QueryKLinesResponse: + request = bbgo_pb2.QueryKLinesRequest(exchange=exchange, + symbol=symbol, + limit=limit, + interval=interval, + timestamp=timestamp) + + response = self.market_data_stub.QueryKLines(request) + return response + + +class TradingService(object): + + def __init__(self, stub: bbgo_pb2_grpc.TradingServiceStub): + self.stub = stub def submit_order(self, exchange: str, @@ -87,18 +113,3 @@ class BBGO(object): offset=offset) response = self.stub.QueryTrades(request) return response - - def query_klines(self, - exchange: str, - symbol: str, - limit: int = 30, - interval: int = 1, - timestamp: int = None) -> bbgo_pb2.QueryKLinesResponse: - request = bbgo_pb2.QueryKLinesRequest(exchange=exchange, - symbol=symbol, - limit=limit, - interval=interval, - timestamp=timestamp) - - response = self.stub.QueryKLines(request) - return response diff --git a/python/bbgo/stream.py b/python/bbgo/stream.py index 9ffd0d5d1..99bf37229 100644 --- a/python/bbgo/stream.py +++ b/python/bbgo/stream.py @@ -34,18 +34,18 @@ class Stream(object): async def subscribe(self): async with grpc.aio.insecure_channel(self.address) as channel: - stub = bbgo_pb2_grpc.BBGOStub(channel) + stub = bbgo_pb2_grpc.MarketDataServiceStub(channel) request = bbgo_pb2.SubscribeRequest(subscriptions=self.subscriptions) - async for response in stub.Subcribe(request): + async for response in stub.Subscribe(request): self.dispatch(response) async def subscribe_user_data(self): async with grpc.aio.insecure_channel(self.address) as channel: - stub = bbgo_pb2_grpc.BBGOStub(channel) + stub = bbgo_pb2_grpc.UserDataServiceStub(channel) request = bbgo_pb2.Empty() - async for response in stub.SubcribeUserData(request): + async for response in stub.SubscribeUserData(request): self.dispatch_user_events(response) def start(self): diff --git a/python/bbgo/utils.py b/python/bbgo/utils.py index d03461c1d..a9fa4a964 100644 --- a/python/bbgo/utils.py +++ b/python/bbgo/utils.py @@ -29,9 +29,3 @@ def get_credentials_from_env(): private_key_certificate_chain_pairs = [(private_key, certificate_chain)] server_credentials = grpc.ssl_server_credentials(private_key_certificate_chain_pairs) return server_credentials - - -def create_stub(host, port): - address = f'{host}:{port}' - channel = grpc.insecure_channel(address) - return bbgo_pb2_grpc.BBGOStub(channel) diff --git a/python/tests/servicer.py b/python/tests/servicer.py index 82d1f5c22..d0478f17e 100644 --- a/python/tests/servicer.py +++ b/python/tests/servicer.py @@ -5,7 +5,7 @@ from bbgo import bbgo_pb2 from bbgo import bbgo_pb2_grpc -class TestServicer(bbgo_pb2_grpc.BBGOServicer): +class TestTradingServicer(bbgo_pb2_grpc.TradingServiceServicer): def Subcribe(self, request, context): i = 0 @@ -28,7 +28,7 @@ class TestServicer(bbgo_pb2_grpc.BBGOServicer): event=bbgo_pb2.Event.ORDER_UPDATE, exchange='max', symbol=f'user_{i}', - ) + ) i += 1 time.sleep(random.random()) diff --git a/python/tests/test_grpc.py b/python/tests/test_grpc.py index 2b2af2ba4..b98d2a174 100644 --- a/python/tests/test_grpc.py +++ b/python/tests/test_grpc.py @@ -3,46 +3,53 @@ from concurrent import futures import grpc import pytest -from bbgo import BBGO +from bbgo import MarketService, TradingService from bbgo import bbgo_pb2 from bbgo import bbgo_pb2_grpc -from tests.servicer import TestServicer +from tests.servicer import TestTradingServicer @pytest.fixture -def grpc_address(host='[::]', port=50051): +def address(host='[::]', port=50051): return f'{host}:{port}' @pytest.fixture -def bbgo(host='[::]', port=50051): - return BBGO(host, port) +def channel(address): + return grpc.insecure_channel(address) @pytest.fixture -def grpc_channel(grpc_address): - return grpc.insecure_channel(grpc_address) +def trading_service(channel): + trading_service_stub = bbgo_pb2_grpc.TradingServiceStub(channel) + return TradingService(trading_service_stub) @pytest.fixture -def grpc_server(grpc_address, max_workers=1): +def market_service(channel): + market_service_stub = bbgo_pb2_grpc.MarketDataServiceStub(channel) + return MarketService(market_service_stub) + + +@pytest.fixture +def test_trading_servicer(address, max_workers=1): server = grpc.server(futures.ThreadPoolExecutor(max_workers)) - servicer = TestServicer() - bbgo_pb2_grpc.add_BBGOServicer_to_server(servicer, server) - server.add_insecure_port(grpc_address) + servicer = TestTradingServicer() + bbgo_pb2_grpc.add_TradingServiceServicer_to_server(servicer, server) + server.add_insecure_port(address) server.start() yield server server.stop(grace=None) -def test_submit_order(bbgo, grpc_server): +def test_submit_order(trading_service, test_trading_servicer): exchange = 'max' symbol = 'BTCUSDT' side = bbgo_pb2.Side.BUY quantity = 0.01 order_type = bbgo_pb2.OrderType.LIMIT - response = bbgo.submit_order( + response = trading_service.submit_order( exchange=exchange, symbol=symbol, side=side,