diff --git a/python/bbgo/services.py b/python/bbgo/services.py index 13c83b483..f37e83dcc 100644 --- a/python/bbgo/services.py +++ b/python/bbgo/services.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Iterator from typing import List from typing import Tuple @@ -10,12 +12,14 @@ from .data import KLine from .data import MarketDataEvent from .data import Subscription from .data import UserDataEvent +from .utils import get_insecure_channel class UserDataService(object): + stub: bbgo_pb2_grpc.UserDataServiceStub - def __init__(self, stub: bbgo_pb2_grpc.UserDataServiceStub) -> None: - self.stub = stub + def __init__(self, host: str, port: int) -> None: + self.stub = bbgo_pb2_grpc.UserDataServiceStub(get_insecure_channel(host, port)) def subscribe(self, session: str) -> Iterator[UserDataEvent]: request = bbgo_pb2.UserDataRequest(session) @@ -26,9 +30,10 @@ class UserDataService(object): class MarketService(object): + stub: bbgo_pb2_grpc.MarketDataServiceStub - def __init__(self, stub: bbgo_pb2_grpc.MarketDataServiceStub) -> None: - self.stub = stub + def __init__(self, host: str, port: int) -> None: + self.stub = bbgo_pb2_grpc.MarketDataServiceStub(get_insecure_channel(host, port)) def subscribe(self, subscriptions: List[Subscription]) -> Iterator[MarketDataEvent]: request = bbgo_pb2.SubscribeRequest(subscriptions=[s.to_pb() for s in subscriptions]) @@ -63,9 +68,10 @@ class MarketService(object): class TradingService(object): + stub: bbgo_pb2_grpc.TradingServiceStub - def __init__(self, stub: bbgo_pb2_grpc.TradingServiceStub): - self.stub = stub + def __init__(self, host: str, port: int) -> None: + self.stub = bbgo_pb2_grpc.TradingServiceStub(get_insecure_channel(host, port)) def submit_order(self, exchange: str, diff --git a/python/bbgo/utils.py b/python/bbgo/utils.py index 1b93ef538..ffdb3049c 100644 --- a/python/bbgo/utils.py +++ b/python/bbgo/utils.py @@ -27,3 +27,17 @@ def get_credentials_from_env(): private_key_certificate_chain_pairs = [(private_key, certificate_chain)] server_credentials = grpc.ssl_server_credentials(private_key_certificate_chain_pairs) return server_credentials + + +def get_insecure_channel(host: str, port: int) -> grpc.Channel: + address = f'{host}:{port}' + return grpc.insecure_channel(address) + + +def get_insecure_channel_from_env() -> grpc.Channel: + host = os.environ.get('BBGO_GRPC_HOST') or '127.0.0.1' + port = os.environ.get('BBGO_GRPC_PORT') or 50051 + + address = get_insecure_channel(host, port) + + return grpc.insecure_channel(address) diff --git a/python/query_klines.py b/python/query_klines.py index 153a282da..a8f0b477b 100644 --- a/python/query_klines.py +++ b/python/query_klines.py @@ -1,7 +1,5 @@ import click -import grpc -import bbgo_pb2_grpc from bbgo import MarketService @@ -9,11 +7,7 @@ from bbgo import MarketService @click.option('--host', default='127.0.0.1') @click.option('--port', default=50051) def main(host, port): - address = f'{host}:{port}' - channel = grpc.insecure_channel(address) - - stub = bbgo_pb2_grpc.MarketDataServiceStub(channel) - service = MarketService(stub) + service = MarketService(host, port) klines, error = service.query_klines(exchange='binance', symbol='BTCUSDT', interval='1m', limit=10) diff --git a/python/subscribe.py b/python/subscribe.py index 3ed49f0cc..abd0ff963 100644 --- a/python/subscribe.py +++ b/python/subscribe.py @@ -1,8 +1,6 @@ import click -import grpc from loguru import logger -import bbgo_pb2_grpc from bbgo import MarketService from bbgo.data import Subscription from bbgo.enums import ChannelType @@ -16,11 +14,8 @@ def main(host, port): subscriptions = [ Subscription('binance', ChannelType.BOOK, symbol='BTCUSDT', depth=DepthType.FULL), ] - address = f'{host}:{port}' - channel = grpc.insecure_channel(address) - stub = bbgo_pb2_grpc.MarketDataServiceStub(channel) - service = MarketService(stub) + service = MarketService(host, port) response_iter = service.subscribe(subscriptions) for response in response_iter: logger.info(response)