diff --git a/freqtrade/rpc/api_server/api_ws.py b/freqtrade/rpc/api_server/api_ws.py index 2f490b8a8..f3f6b852d 100644 --- a/freqtrade/rpc/api_server/api_ws.py +++ b/freqtrade/rpc/api_server/api_ws.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import Any, Dict @@ -11,6 +10,7 @@ from freqtrade.enums import RPCMessageType, RPCRequestType from freqtrade.rpc.api_server.api_auth import validate_ws_token from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc from freqtrade.rpc.api_server.ws import WebSocketChannel +from freqtrade.rpc.api_server.ws.channel import ChannelManager from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema, WSRequestSchema, WSWhitelistMessage) from freqtrade.rpc.rpc import RPC @@ -37,7 +37,8 @@ async def is_websocket_alive(ws: WebSocket) -> bool: async def _process_consumer_request( request: Dict[str, Any], channel: WebSocketChannel, - rpc: RPC + rpc: RPC, + channel_manager: ChannelManager ): """ Validate and handle a request from a websocket consumer @@ -72,9 +73,9 @@ async def _process_consumer_request( whitelist = rpc._ws_request_whitelist() # Format response - response = WSWhitelistMessage(data=whitelist) + response = WSWhitelistMessage(data=whitelist).dict(exclude_none=True) # Send it back - await channel.send(response.dict(exclude_none=True)) + await channel_manager.send_direct(channel, response) elif type == RPCRequestType.ANALYZED_DF: limit = None @@ -88,10 +89,8 @@ async def _process_consumer_request( # For every dataframe, send as a separate message for _, message in analyzed_df.items(): - response = WSAnalyzedDFMessage(data=message) - await channel.send(response.dict(exclude_none=True)) - # Throttle the messages to 50/s - await asyncio.sleep(0.02) + response = WSAnalyzedDFMessage(data=message).dict(exclude_none=True) + await channel_manager.send_direct(channel, response) @router.websocket("/message/ws") @@ -116,7 +115,7 @@ async def message_endpoint( request = await channel.recv() # Process the request here - await _process_consumer_request(request, channel, rpc) + await _process_consumer_request(request, channel, rpc, channel_manager) except (WebSocketDisconnect, WebSocketException): # Handle client disconnects diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index 4a09fd78e..c6639f1a6 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -198,10 +198,6 @@ class ApiServer(RPCHandler): logger.debug(f"Found message of type: {message.get('type')}") # Broadcast it await self._ws_channel_manager.broadcast(message) - # Limit messages per sec. - # Could cause problems with queue size if too low, and - # problems with network traffik if too high. - await asyncio.sleep(0.001) except asyncio.CancelledError: pass diff --git a/freqtrade/rpc/api_server/ws/channel.py b/freqtrade/rpc/api_server/ws/channel.py index a1334bce9..4afca0d33 100644 --- a/freqtrade/rpc/api_server/ws/channel.py +++ b/freqtrade/rpc/api_server/ws/channel.py @@ -25,6 +25,7 @@ class WebSocketChannel: websocket: WebSocketType, channel_id: Optional[str] = None, drain_timeout: int = 3, + throttle: float = 0.01, serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer ): @@ -36,6 +37,7 @@ class WebSocketChannel: self._serializer_cls = serializer_cls self.drain_timeout = drain_timeout + self.throttle = throttle self._subscriptions: List[str] = [] self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32) @@ -50,6 +52,10 @@ class WebSocketChannel: def __repr__(self): return f"WebSocketChannel({self.channel_id}, {self.remote_addr})" + @property + def raw(self): + return self._websocket.raw + @property def remote_addr(self): return self._websocket.remote_addr @@ -131,7 +137,7 @@ class WebSocketChannel: # Could cause problems with queue size if too low, and # problems with network traffik if too high. # 0.01 = 100/s - await asyncio.sleep(0.01) + await asyncio.sleep(self.throttle) except RuntimeError: # The connection was closed, just exit the task return @@ -171,6 +177,7 @@ class ChannelManager: with self._lock: channel = self.channels.get(websocket) if channel: + logger.info(f"Disconnecting channel {channel}") if not channel.is_closed(): await channel.close() @@ -181,9 +188,8 @@ class ChannelManager: Disconnect all Channels """ with self._lock: - for websocket, channel in self.channels.copy().items(): - if not channel.is_closed(): - await channel.close() + for websocket in self.channels.copy().keys(): + await self.on_disconnect(websocket) self.channels = dict() @@ -195,11 +201,9 @@ class ChannelManager: """ with self._lock: message_type = data.get('type') - for websocket, channel in self.channels.copy().items(): + for channel in self.channels.copy().values(): if channel.subscribed_to(message_type): - if not await channel.send(data): - logger.info(f"Channel {channel} is too far behind, disconnecting") - await self.on_disconnect(websocket) + await self.send_direct(channel, data) async def send_direct(self, channel, data): """ @@ -208,7 +212,8 @@ class ChannelManager: :param direct_channel: The WebSocketChannel object to send data through :param data: The data to send """ - await channel.send(data) + if not await channel.send(data): + await self.on_disconnect(channel.raw) def has_channels(self): """ diff --git a/freqtrade/rpc/api_server/ws/proxy.py b/freqtrade/rpc/api_server/ws/proxy.py index 2e5a59f05..8518709aa 100644 --- a/freqtrade/rpc/api_server/ws/proxy.py +++ b/freqtrade/rpc/api_server/ws/proxy.py @@ -15,6 +15,10 @@ class WebSocketProxy: def __init__(self, websocket: WebSocketType): self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket + @property + def raw(self): + return self._websocket + @property def remote_addr(self) -> Tuple[Any, ...]: if isinstance(self._websocket, WebSocket): diff --git a/freqtrade/rpc/external_message_consumer.py b/freqtrade/rpc/external_message_consumer.py index 01bc974ad..e86f44c17 100644 --- a/freqtrade/rpc/external_message_consumer.py +++ b/freqtrade/rpc/external_message_consumer.py @@ -270,6 +270,11 @@ class ExternalMessageConsumer: logger.debug(f"Connection to {channel} still alive...") continue + except (websockets.exceptions.ConnectionClosed): + # Just eat the error and continue reconnecting + logger.warning(f"Disconnection in {channel} - retrying in {self.sleep_time}s") + await asyncio.sleep(self.sleep_time) + break except Exception as e: logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s") logger.debug(e, exc_info=e)