Merge pull request #7558 from wizrds/feat/queue-per-client-ws

Refactor broadcasting in Message Websocket
This commit is contained in:
Matthias 2022-10-13 09:52:29 +02:00 committed by GitHub
commit e3ca740704
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 21 deletions

View File

@ -4,6 +4,7 @@ from typing import Any, Dict
from fastapi import APIRouter, Depends, WebSocketDisconnect from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi.websockets import WebSocket, WebSocketState from fastapi.websockets import WebSocket, WebSocketState
from pydantic import ValidationError from pydantic import ValidationError
from websockets.exceptions import WebSocketException
from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.enums import RPCMessageType, RPCRequestType
from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.api_auth import validate_ws_token
@ -102,7 +103,6 @@ async def message_endpoint(
""" """
try: try:
channel = await channel_manager.on_connect(ws) channel = await channel_manager.on_connect(ws)
if await is_websocket_alive(ws): if await is_websocket_alive(ws):
logger.info(f"Consumer connected - {channel}") logger.info(f"Consumer connected - {channel}")
@ -115,26 +115,31 @@ async def message_endpoint(
# Process the request here # Process the request here
await _process_consumer_request(request, channel, rpc) await _process_consumer_request(request, channel, rpc)
except WebSocketDisconnect: except (WebSocketDisconnect, WebSocketException):
# Handle client disconnects # Handle client disconnects
logger.info(f"Consumer disconnected - {channel}") logger.info(f"Consumer disconnected - {channel}")
await channel_manager.on_disconnect(ws) except RuntimeError:
except Exception as e:
logger.info(f"Consumer connection failed - {channel}")
logger.exception(e)
# Handle cases like - # Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent') # RuntimeError('Cannot call "send" once a closed message has been sent')
pass
except Exception as e:
logger.info(f"Consumer connection failed - {channel}: {e}")
logger.debug(e, exc_info=e)
finally:
await channel_manager.on_disconnect(ws) await channel_manager.on_disconnect(ws)
else: else:
if channel:
await channel_manager.on_disconnect(ws)
await ws.close() await ws.close()
except RuntimeError: except RuntimeError:
# WebSocket was closed # WebSocket was closed
await channel_manager.on_disconnect(ws) # Do nothing
pass
except Exception as e: except Exception as e:
logger.error(f"Failed to serve - {ws.client}") logger.error(f"Failed to serve - {ws.client}")
# Log tracebacks to keep track of what errors are happening # Log tracebacks to keep track of what errors are happening
logger.exception(e) logger.exception(e)
finally:
await channel_manager.on_disconnect(ws) await channel_manager.on_disconnect(ws)

View File

@ -198,10 +198,6 @@ class ApiServer(RPCHandler):
logger.debug(f"Found message of type: {message.get('type')}") logger.debug(f"Found message of type: {message.get('type')}")
# Broadcast it # Broadcast it
await self._ws_channel_manager.broadcast(message) await self._ws_channel_manager.broadcast(message)
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
await asyncio.sleep(0.001)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
@ -245,6 +241,7 @@ class ApiServer(RPCHandler):
use_colors=False, use_colors=False,
log_config=None, log_config=None,
access_log=True if verbosity != 'error' else False, access_log=True if verbosity != 'error' else False,
ws_ping_interval=None # We do this explicitly ourselves
) )
try: try:
self._server = UvicornServer(uvconfig) self._server = UvicornServer(uvconfig)

View File

@ -1,6 +1,7 @@
import asyncio
import logging import logging
from threading import RLock from threading import RLock
from typing import List, Optional, Type from typing import Any, Dict, List, Optional, Type
from uuid import uuid4 from uuid import uuid4
from fastapi import WebSocket as FastAPIWebSocket from fastapi import WebSocket as FastAPIWebSocket
@ -34,6 +35,8 @@ class WebSocketChannel:
self._serializer_cls = serializer_cls self._serializer_cls = serializer_cls
self._subscriptions: List[str] = [] self._subscriptions: List[str] = []
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
self._relay_task = asyncio.create_task(self.relay())
# Internal event to signify a closed websocket # Internal event to signify a closed websocket
self._closed = False self._closed = False
@ -48,12 +51,18 @@ class WebSocketChannel:
def remote_addr(self): def remote_addr(self):
return self._websocket.remote_addr return self._websocket.remote_addr
async def send(self, data): async def _send(self, data):
""" """
Send data on the wrapped websocket Send data on the wrapped websocket
""" """
await self._wrapped_ws.send(data) await self._wrapped_ws.send(data)
async def send(self, data):
"""
Add the data to the queue to be sent
"""
self.queue.put_nowait(data)
async def recv(self): async def recv(self):
""" """
Receive data on the wrapped websocket Receive data on the wrapped websocket
@ -72,6 +81,7 @@ class WebSocketChannel:
""" """
self._closed = True self._closed = True
self._relay_task.cancel()
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
@ -95,6 +105,26 @@ class WebSocketChannel:
""" """
return message_type in self._subscriptions return message_type in self._subscriptions
async def relay(self):
"""
Relay messages from the channel's queue and send them out. This is started
as a task.
"""
while True:
message = await self.queue.get()
try:
await self._send(message)
self.queue.task_done()
# Limit messages per sec.
# Could cause problems with queue size if too low, and
# problems with network traffik if too high.
# 0.001 = 1000/s
await asyncio.sleep(0.001)
except RuntimeError:
# The connection was closed, just exit the task
return
class ChannelManager: class ChannelManager:
def __init__(self): def __init__(self):
@ -155,12 +185,12 @@ class ChannelManager:
with self._lock: with self._lock:
message_type = data.get('type') message_type = data.get('type')
for websocket, channel in self.channels.copy().items(): for websocket, channel in self.channels.copy().items():
try: if channel.subscribed_to(message_type):
if channel.subscribed_to(message_type): if not channel.queue.full():
await channel.send(data) await channel.send(data)
except RuntimeError: else:
# Handle cannot send after close cases logger.info(f"Channel {channel} is too far behind, disconnecting")
await self.on_disconnect(websocket) await self.on_disconnect(websocket)
async def send_direct(self, channel, data): async def send_direct(self, channel, data):
""" """

View File

@ -62,7 +62,7 @@ class ExternalMessageConsumer:
self.enabled = self._emc_config.get('enabled', False) self.enabled = self._emc_config.get('enabled', False)
self.producers: List[Producer] = self._emc_config.get('producers', []) self.producers: List[Producer] = self._emc_config.get('producers', [])
self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds self.wait_timeout = self._emc_config.get('wait_timeout', 30) # in seconds
self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds
self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds
@ -174,6 +174,7 @@ class ExternalMessageConsumer:
:param producer: Dictionary containing producer info :param producer: Dictionary containing producer info
:param lock: An asyncio Lock :param lock: An asyncio Lock
""" """
channel = None
while self._running: while self._running:
try: try:
host, port = producer['host'], producer['port'] host, port = producer['host'], producer['port']
@ -182,7 +183,11 @@ class ExternalMessageConsumer:
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}" ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
# This will raise InvalidURI if the url is bad # This will raise InvalidURI if the url is bad
async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws: async with websockets.connect(
ws_url,
max_size=self.message_size_limit,
ping_interval=None
) as ws:
channel = WebSocketChannel(ws, channel_id=name) channel = WebSocketChannel(ws, channel_id=name)
logger.info(f"Producer connection success - {channel}") logger.info(f"Producer connection success - {channel}")
@ -224,6 +229,10 @@ class ExternalMessageConsumer:
logger.exception(e) logger.exception(e)
continue continue
finally:
if channel:
await channel.close()
async def _receive_messages( async def _receive_messages(
self, self,
channel: WebSocketChannel, channel: WebSocketChannel,