mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
removed sleep calls, better channel sending
This commit is contained in:
parent
2b6d00dde4
commit
3d7a311caa
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
|
@ -11,6 +10,7 @@ from freqtrade.enums import RPCMessageType, RPCRequestType
|
|||
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||
from freqtrade.rpc.api_server.deps import get_channel_manager, get_rpc
|
||||
from freqtrade.rpc.api_server.ws import WebSocketChannel
|
||||
from freqtrade.rpc.api_server.ws.channel import ChannelManager
|
||||
from freqtrade.rpc.api_server.ws_schemas import (WSAnalyzedDFMessage, WSMessageSchema,
|
||||
WSRequestSchema, WSWhitelistMessage)
|
||||
from freqtrade.rpc.rpc import RPC
|
||||
|
@ -37,7 +37,8 @@ async def is_websocket_alive(ws: WebSocket) -> bool:
|
|||
async def _process_consumer_request(
|
||||
request: Dict[str, Any],
|
||||
channel: WebSocketChannel,
|
||||
rpc: RPC
|
||||
rpc: RPC,
|
||||
channel_manager: ChannelManager
|
||||
):
|
||||
"""
|
||||
Validate and handle a request from a websocket consumer
|
||||
|
@ -72,9 +73,9 @@ async def _process_consumer_request(
|
|||
whitelist = rpc._ws_request_whitelist()
|
||||
|
||||
# Format response
|
||||
response = WSWhitelistMessage(data=whitelist)
|
||||
response = WSWhitelistMessage(data=whitelist).dict(exclude_none=True)
|
||||
# Send it back
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
await channel_manager.send_direct(channel, response)
|
||||
|
||||
elif type == RPCRequestType.ANALYZED_DF:
|
||||
limit = None
|
||||
|
@ -88,10 +89,8 @@ async def _process_consumer_request(
|
|||
|
||||
# For every dataframe, send as a separate message
|
||||
for _, message in analyzed_df.items():
|
||||
response = WSAnalyzedDFMessage(data=message)
|
||||
await channel.send(response.dict(exclude_none=True))
|
||||
# Throttle the messages to 50/s
|
||||
await asyncio.sleep(0.02)
|
||||
response = WSAnalyzedDFMessage(data=message).dict(exclude_none=True)
|
||||
await channel_manager.send_direct(channel, response)
|
||||
|
||||
|
||||
@router.websocket("/message/ws")
|
||||
|
@ -116,7 +115,7 @@ async def message_endpoint(
|
|||
request = await channel.recv()
|
||||
|
||||
# Process the request here
|
||||
await _process_consumer_request(request, channel, rpc)
|
||||
await _process_consumer_request(request, channel, rpc, channel_manager)
|
||||
|
||||
except (WebSocketDisconnect, WebSocketException):
|
||||
# Handle client disconnects
|
||||
|
|
|
@ -198,10 +198,6 @@ class ApiServer(RPCHandler):
|
|||
logger.debug(f"Found message of type: {message.get('type')}")
|
||||
# Broadcast it
|
||||
await self._ws_channel_manager.broadcast(message)
|
||||
# Limit messages per sec.
|
||||
# Could cause problems with queue size if too low, and
|
||||
# problems with network traffik if too high.
|
||||
await asyncio.sleep(0.001)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ class WebSocketChannel:
|
|||
websocket: WebSocketType,
|
||||
channel_id: Optional[str] = None,
|
||||
drain_timeout: int = 3,
|
||||
throttle: float = 0.01,
|
||||
serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
|
||||
):
|
||||
|
||||
|
@ -36,6 +37,7 @@ class WebSocketChannel:
|
|||
self._serializer_cls = serializer_cls
|
||||
|
||||
self.drain_timeout = drain_timeout
|
||||
self.throttle = throttle
|
||||
|
||||
self._subscriptions: List[str] = []
|
||||
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
|
||||
|
@ -50,6 +52,10 @@ class WebSocketChannel:
|
|||
def __repr__(self):
|
||||
return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
|
||||
|
||||
@property
|
||||
def raw(self):
|
||||
return self._websocket.raw
|
||||
|
||||
@property
|
||||
def remote_addr(self):
|
||||
return self._websocket.remote_addr
|
||||
|
@ -131,7 +137,7 @@ class WebSocketChannel:
|
|||
# Could cause problems with queue size if too low, and
|
||||
# problems with network traffik if too high.
|
||||
# 0.01 = 100/s
|
||||
await asyncio.sleep(0.01)
|
||||
await asyncio.sleep(self.throttle)
|
||||
except RuntimeError:
|
||||
# The connection was closed, just exit the task
|
||||
return
|
||||
|
@ -171,6 +177,7 @@ class ChannelManager:
|
|||
with self._lock:
|
||||
channel = self.channels.get(websocket)
|
||||
if channel:
|
||||
logger.info(f"Disconnecting channel {channel}")
|
||||
if not channel.is_closed():
|
||||
await channel.close()
|
||||
|
||||
|
@ -181,9 +188,8 @@ class ChannelManager:
|
|||
Disconnect all Channels
|
||||
"""
|
||||
with self._lock:
|
||||
for websocket, channel in self.channels.copy().items():
|
||||
if not channel.is_closed():
|
||||
await channel.close()
|
||||
for websocket in self.channels.copy().keys():
|
||||
await self.on_disconnect(websocket)
|
||||
|
||||
self.channels = dict()
|
||||
|
||||
|
@ -195,11 +201,9 @@ class ChannelManager:
|
|||
"""
|
||||
with self._lock:
|
||||
message_type = data.get('type')
|
||||
for websocket, channel in self.channels.copy().items():
|
||||
for channel in self.channels.copy().values():
|
||||
if channel.subscribed_to(message_type):
|
||||
if not await channel.send(data):
|
||||
logger.info(f"Channel {channel} is too far behind, disconnecting")
|
||||
await self.on_disconnect(websocket)
|
||||
await self.send_direct(channel, data)
|
||||
|
||||
async def send_direct(self, channel, data):
|
||||
"""
|
||||
|
@ -208,7 +212,8 @@ class ChannelManager:
|
|||
:param direct_channel: The WebSocketChannel object to send data through
|
||||
:param data: The data to send
|
||||
"""
|
||||
await channel.send(data)
|
||||
if not await channel.send(data):
|
||||
await self.on_disconnect(channel.raw)
|
||||
|
||||
def has_channels(self):
|
||||
"""
|
||||
|
|
|
@ -15,6 +15,10 @@ class WebSocketProxy:
|
|||
def __init__(self, websocket: WebSocketType):
|
||||
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
|
||||
|
||||
@property
|
||||
def raw(self):
|
||||
return self._websocket
|
||||
|
||||
@property
|
||||
def remote_addr(self) -> Tuple[Any, ...]:
|
||||
if isinstance(self._websocket, WebSocket):
|
||||
|
|
|
@ -270,6 +270,11 @@ class ExternalMessageConsumer:
|
|||
logger.debug(f"Connection to {channel} still alive...")
|
||||
|
||||
continue
|
||||
except (websockets.exceptions.ConnectionClosed):
|
||||
# Just eat the error and continue reconnecting
|
||||
logger.warning(f"Disconnection in {channel} - retrying in {self.sleep_time}s")
|
||||
await asyncio.sleep(self.sleep_time)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Ping error {channel} - retrying in {self.sleep_time}s")
|
||||
logger.debug(e, exc_info=e)
|
||||
|
|
Loading…
Reference in New Issue
Block a user