diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index f55b2dbd3..46909955d 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -4,6 +4,7 @@ from typing import Any, Dict from fastapi import APIRouter, Depends, WebSocketDisconnect from fastapi.websockets import WebSocket, WebSocketState from pydantic import ValidationError +from websockets.exceptions import WebSocketException from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.api_auth import validate_ws_token @@ -102,7 +103,6 @@ async def message_endpoint( """ try: channel = await channel_manager.on_connect(ws) - if await is_websocket_alive(ws): logger.info(f"Consumer connected - {channel}") @@ -115,26 +115,31 @@ async def message_endpoint( # Process the request here await _process_consumer_request(request, channel, rpc) - except WebSocketDisconnect: + except (WebSocketDisconnect, WebSocketException): # Handle client disconnects logger.info(f"Consumer disconnected - {channel}") - await channel_manager.on_disconnect(ws) - except Exception as e: - logger.info(f"Consumer connection failed - {channel}") - logger.exception(e) + except RuntimeError: # Handle cases like - # RuntimeError('Cannot call "send" once a closed message has been sent') + pass + except Exception as e: + logger.info(f"Consumer connection failed - {channel}: {e}") + logger.debug(e, exc_info=e) + finally: await channel_manager.on_disconnect(ws) else: + if channel: + await channel_manager.on_disconnect(ws) await ws.close() except RuntimeError: # WebSocket was closed - await channel_manager.on_disconnect(ws) - + # Do nothing + pass except Exception as e: logger.error(f"Failed to serve - {ws.client}") # Log tracebacks to keep track of what errors are happening logger.exception(e) + finally: await channel_manager.on_disconnect(ws) diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 53af91477..c6639f1a6 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -198,10 +198,6 @@ class ApiServer(RPCHandler): logger.debug(f"Found message of type: {message.get('type')}") # Broadcast it await self._ws_channel_manager.broadcast(message) - # Limit messages per sec. - # Could cause problems with queue size if too low, and - # problems with network traffik if too high. - await asyncio.sleep(0.001) except asyncio.CancelledError: pass @@ -245,6 +241,7 @@ class ApiServer(RPCHandler): use_colors=False, log_config=None, access_log=True if verbosity != 'error' else False, + ws_ping_interval=None # We do this explicitly ourselves ) try: self._server = UvicornServer(uvconfig) diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index 69a32e266..e9dbd63be 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -1,6 +1,7 @@ +import asyncio import logging from threading import RLock -from typing import List, Optional, Type +from typing import Any, Dict, List, Optional, Type from uuid import uuid4 from fastapi import WebSocket as FastAPIWebSocket @@ -34,6 +35,8 @@ class WebSocketChannel: self._serializer_cls = serializer_cls self._subscriptions: List[str] = [] + self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) + self._relay_task = asyncio.create_task(self.relay()) # Internal event to signify a closed websocket self._closed = False @@ -48,12 +51,18 @@ class WebSocketChannel: def remote_addr(self): return self._websocket.remote_addr - async def send(self, data): + async def _send(self, data): """ Send data on the wrapped websocket """ await self._wrapped_ws.send(data) + async def send(self, data): + """ + Add the data to the queue to be sent + """ + self.queue.put_nowait(data) + async def recv(self): """ Receive data on the wrapped websocket @@ -72,6 +81,7 @@ class WebSocketChannel: """ self._closed = True + self._relay_task.cancel() def is_closed(self) -> bool: """ @@ -95,6 +105,26 @@ class WebSocketChannel: """ return message_type in self._subscriptions + async def relay(self): + """ + Relay messages from the channel's queue and send them out. This is started + as a task. + """ + while True: + message = await self.queue.get() + try: + await self._send(message) + self.queue.task_done() + + # Limit messages per sec. + # Could cause problems with queue size if too low, and + # problems with network traffik if too high. + # 0.001 = 1000/s + await asyncio.sleep(0.001) + except RuntimeError: + # The connection was closed, just exit the task + return + class ChannelManager: def __init__(self): @@ -155,12 +185,12 @@ class ChannelManager: with self._lock: message_type = data.get('type') for websocket, channel in self.channels.copy().items(): - try: - if channel.subscribed_to(message_type): + if channel.subscribed_to(message_type): + if not channel.queue.full(): await channel.send(data) - except RuntimeError: - # Handle cannot send after close cases - await self.on_disconnect(websocket) + else: + logger.info(f"Channel {channel} is too far behind, disconnecting") + await self.on_disconnect(websocket) async def send_direct(self, channel, data): """ diff --git a/freqtrade/rpc/external_message_consumer.py b/freqtrade/rpc/external_message_consumer.py index f5ba4b490..01bc974ad 100644 --- a/freqtrade/rpc/external_message_consumer.py +++ b/freqtrade/rpc/external_message_consumer.py @@ -62,7 +62,7 @@ class ExternalMessageConsumer: self.enabled = self._emc_config.get('enabled', False) self.producers: List[Producer] = self._emc_config.get('producers', []) - self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds + self.wait_timeout = self._emc_config.get('wait_timeout', 30) # in seconds self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds @@ -174,6 +174,7 @@ class ExternalMessageConsumer: :param producer: Dictionary containing producer info :param lock: An asyncio Lock """ + channel = None while self._running: try: host, port = producer['host'], producer['port'] @@ -182,7 +183,11 @@ class ExternalMessageConsumer: ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}" # This will raise InvalidURI if the url is bad - async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws: + async with websockets.connect( + ws_url, + max_size=self.message_size_limit, + ping_interval=None + ) as ws: channel = WebSocketChannel(ws, channel_id=name) logger.info(f"Producer connection success - {channel}") @@ -224,6 +229,10 @@ class ExternalMessageConsumer: logger.exception(e) continue + finally: + if channel: + await channel.close() + async def _receive_messages( self, channel: WebSocketChannel,