mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 09:11:55 +00:00
81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
import asyncio
|
|
from typing import Callable
|
|
from typing import List
|
|
|
|
import grpc
|
|
|
|
import bbgo_pb2
|
|
import bbgo_pb2_grpc
|
|
from bbgo.enums import ChannelType
|
|
from bbgo.enums import DepthType
|
|
|
|
from .data import Event
|
|
from .data import MarketDataEvent
|
|
from .data import Subscription
|
|
from .data import UserDataEvent
|
|
|
|
|
|
class Stream(object):
    """Client that streams market-data and user-data events over gRPC.

    Subscriptions and sessions are registered first (``subscribe`` and
    ``subscribe_user_data``); ``start`` then opens the streams and passes
    every received event to each registered event handler.
    """

    # Market-data subscriptions; all of them are sent in one SubscribeRequest.
    subscriptions: List[Subscription]

    def __init__(self, host: str, port: int):
        """Remember the server address and start with empty registries.

        Args:
            host: Host part of the gRPC target address.
            port: Port part of the gRPC target address.
        """
        self.host = host
        self.port = port

        self.subscriptions = []
        self.sessions = []
        self.event_handlers = []

    def subscribe(self, exchange: str, channel: str, symbol: str, depth: str = None, interval: str = None) -> None:
        """Register a market-data subscription; nothing is sent until ``start``.

        Args:
            exchange: Exchange name stored on the subscription.
            channel: Channel name, converted with ``ChannelType.from_str``.
            symbol: Symbol stored on the subscription.
            depth: Optional depth value, converted with ``DepthType(...)``.
            interval: Optional interval, stored unchanged.
        """
        subscription = Subscription(exchange=exchange, channel=ChannelType.from_str(channel), symbol=symbol)

        # NOTE(review): channel goes through ChannelType.from_str while depth
        # uses the DepthType constructor directly - confirm this is intended.
        if depth is not None:
            subscription.depth = DepthType(depth)

        if interval is not None:
            subscription.interval = interval

        self.subscriptions.append(subscription)

    def subscribe_user_data(self, session: str) -> None:
        """Register a session whose user-data stream is opened by ``start``."""
        self.sessions.append(session)

    def add_event_handler(self, event_handler: Callable) -> None:
        """Register a callable that is invoked with every received event."""
        self.event_handlers.append(event_handler)

    def fire_event_handlers(self, event: Event) -> None:
        """Call each registered handler with ``event``, in registration order.

        An exception raised by a handler is not caught here; it propagates to
        the caller and the remaining handlers are not invoked.
        """
        for event_handler in self.event_handlers:
            event_handler(event)

    @property
    def address(self) -> str:
        """Return the gRPC target in ``host:port`` form."""
        return f'{self.host}:{self.port}'

    async def _subscribe_market_data(self) -> None:
        """Open the market-data stream and dispatch each response as an event."""
        async with grpc.aio.insecure_channel(self.address) as channel:
            stub = bbgo_pb2_grpc.MarketDataServiceStub(channel)

            # A single request carries every registered subscription.
            request = bbgo_pb2.SubscribeRequest(subscriptions=[s.to_pb() for s in self.subscriptions])
            async for response in stub.Subscribe(request):
                event = MarketDataEvent.from_pb(response)
                self.fire_event_handlers(event)

    async def _subscribe_user_data(self, session: str) -> None:
        """Open the user-data stream for ``session`` and dispatch its events."""
        async with grpc.aio.insecure_channel(self.address) as channel:
            stub = bbgo_pb2_grpc.UserDataServiceStub(channel)

            request = bbgo_pb2.UserDataRequest(session=session)
            async for response in stub.Subscribe(request):
                event = UserDataEvent.from_pb(response)
                self.fire_event_handlers(event)

    def start(self) -> None:
        """Run the market-data stream and one user-data stream per session.

        Blocks until every stream coroutine has finished, or until the first
        one raises (the exception is re-raised to the caller). The event loop
        is closed on the way out in either case, so it cannot be reused
        afterwards.
        """
        # The market-data stream is always started, even with no subscriptions.
        coroutines = [self._subscribe_market_data()]
        for session in self.sessions:
            coroutines.append(self._subscribe_user_data(session))

        group = asyncio.gather(*coroutines)
        loop = asyncio.get_event_loop()
        try:
            loop.run_until_complete(group)
        finally:
            # Close the loop even when a stream fails or the run is
            # interrupted (e.g. KeyboardInterrupt); previously it was only
            # closed on a clean return.
            loop.close()
|