diff --git a/python/bbgo/stream.py b/python/bbgo/stream.py index 040b89ce4..f87d1cdbd 100644 --- a/python/bbgo/stream.py +++ b/python/bbgo/stream.py @@ -18,12 +18,12 @@ from .data import UserDataEvent class Stream(object): subscriptions: List[Subscription] - def __init__(self, host: str, port: int, user_data: bool = False): + def __init__(self, host: str, port: int): self.host = host self.port = port - self.user_data = user_data self.subscriptions = [] + self.sessions = [] self.event_handlers = [] def subscribe(self, exchange: str, channel: str, symbol: str, depth: str = None, interval: str = None): @@ -37,6 +37,9 @@ class Stream(object): self.subscriptions.append(subscription) + def subscribe_user_data(self, session: str): + self.sessions.append(session) + def add_event_handler(self, event_handler: Callable) -> None: self.event_handlers.append(event_handler) @@ -48,7 +51,7 @@ class Stream(object): def address(self): return f'{self.host}:{self.port}' - async def _subscribe(self): + async def _subscribe_market_data(self): async with grpc.aio.insecure_channel(self.address) as channel: stub = bbgo_pb2_grpc.MarketDataServiceStub(channel) @@ -57,19 +60,19 @@ class Stream(object): event = MarketDataEvent.from_pb(response) self.fire_event_handlers(event) - async def _subscribe_user_data(self): + async def _subscribe_user_data(self, session: str): async with grpc.aio.insecure_channel(self.address) as channel: stub = bbgo_pb2_grpc.UserDataServiceStub(channel) - request = bbgo_pb2.Empty() - async for response in stub.SubscribeUserData(request): + request = bbgo_pb2.UserDataRequest(session=session) + async for response in stub.Subscribe(request): event = UserDataEvent.from_pb(response) self.fire_event_handlers(event) def start(self): - coroutines = [self._subscribe()] - if self.user_data: - coroutines.append(self._subscribe_user_data()) + coroutines = [self._subscribe_market_data()] + for session in self.sessions: + coroutines.append(self._subscribe_user_data(session)) group = asyncio.gather(*coroutines) loop = asyncio.get_event_loop() diff --git a/python/examples/query_klines.py b/python/examples/query_klines.py index a8f0b477b..d03be201f 100644 --- a/python/examples/query_klines.py +++ b/python/examples/query_klines.py @@ -9,13 +9,16 @@ from bbgo import MarketService def main(host, port): service = MarketService(host, port) - klines, error = service.query_klines(exchange='binance', symbol='BTCUSDT', interval='1m', limit=10) + klines = service.query_klines( + exchange='binance', + symbol='BTCUSDT', + interval='1m', + limit=10, + ) for kline in klines: print(kline) - print(error) - if __name__ == '__main__': main() diff --git a/python/examples/stream.py b/python/examples/stream.py index 1dc0ed0bd..571ceb5d9 100644 --- a/python/examples/stream.py +++ b/python/examples/stream.py @@ -19,6 +19,7 @@ def main(host, port): stream = Stream(host, port) stream.subscribe('max', 'book', 'BTCUSDT', 'full') stream.subscribe('max', 'book', 'ETHUSDT', 'full') + stream.subscribe_user_data('max') stream.add_event_handler(LogBook()) stream.start() diff --git a/python/pyproject.toml b/python/pyproject.toml index 2071e0f2e..74b9cb314 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "bbgo" -version = "0.1.3" +version = "0.1.4" description = "" authors = ["なるみ "] packages = [