mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
Merge pull request #7558 from wizrds/feat/queue-per-client-ws
Refactor broadcasting in Message Websocket
This commit is contained in:
commit
e3ca740704
|
@ -4,6 +4,7 @@ from typing import Any, Dict
|
||||||
from fastapi import APIRouter, Depends, WebSocketDisconnect
|
from fastapi import APIRouter, Depends, WebSocketDisconnect
|
||||||
from fastapi.websockets import WebSocket, WebSocketState
|
from fastapi.websockets import WebSocket, WebSocketState
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from websockets.exceptions import WebSocketException
|
||||||
|
|
||||||
from freqtrade.enums import RPCMessageType, RPCRequestType
|
from freqtrade.enums import RPCMessageType, RPCRequestType
|
||||||
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
from freqtrade.rpc.api_server.api_auth import validate_ws_token
|
||||||
|
@ -102,7 +103,6 @@ async def message_endpoint(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
channel = await channel_manager.on_connect(ws)
|
channel = await channel_manager.on_connect(ws)
|
||||||
|
|
||||||
if await is_websocket_alive(ws):
|
if await is_websocket_alive(ws):
|
||||||
|
|
||||||
logger.info(f"Consumer connected - {channel}")
|
logger.info(f"Consumer connected - {channel}")
|
||||||
|
@ -115,26 +115,31 @@ async def message_endpoint(
|
||||||
# Process the request here
|
# Process the request here
|
||||||
await _process_consumer_request(request, channel, rpc)
|
await _process_consumer_request(request, channel, rpc)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except (WebSocketDisconnect, WebSocketException):
|
||||||
# Handle client disconnects
|
# Handle client disconnects
|
||||||
logger.info(f"Consumer disconnected - {channel}")
|
logger.info(f"Consumer disconnected - {channel}")
|
||||||
await channel_manager.on_disconnect(ws)
|
except RuntimeError:
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"Consumer connection failed - {channel}")
|
|
||||||
logger.exception(e)
|
|
||||||
# Handle cases like -
|
# Handle cases like -
|
||||||
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
# RuntimeError('Cannot call "send" once a closed message has been sent')
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Consumer connection failed - {channel}: {e}")
|
||||||
|
logger.debug(e, exc_info=e)
|
||||||
|
finally:
|
||||||
await channel_manager.on_disconnect(ws)
|
await channel_manager.on_disconnect(ws)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
if channel:
|
||||||
|
await channel_manager.on_disconnect(ws)
|
||||||
await ws.close()
|
await ws.close()
|
||||||
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# WebSocket was closed
|
# WebSocket was closed
|
||||||
await channel_manager.on_disconnect(ws)
|
# Do nothing
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to serve - {ws.client}")
|
logger.error(f"Failed to serve - {ws.client}")
|
||||||
# Log tracebacks to keep track of what errors are happening
|
# Log tracebacks to keep track of what errors are happening
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
|
finally:
|
||||||
await channel_manager.on_disconnect(ws)
|
await channel_manager.on_disconnect(ws)
|
||||||
|
|
|
@ -198,10 +198,6 @@ class ApiServer(RPCHandler):
|
||||||
logger.debug(f"Found message of type: {message.get('type')}")
|
logger.debug(f"Found message of type: {message.get('type')}")
|
||||||
# Broadcast it
|
# Broadcast it
|
||||||
await self._ws_channel_manager.broadcast(message)
|
await self._ws_channel_manager.broadcast(message)
|
||||||
# Limit messages per sec.
|
|
||||||
# Could cause problems with queue size if too low, and
|
|
||||||
# problems with network traffik if too high.
|
|
||||||
await asyncio.sleep(0.001)
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -245,6 +241,7 @@ class ApiServer(RPCHandler):
|
||||||
use_colors=False,
|
use_colors=False,
|
||||||
log_config=None,
|
log_config=None,
|
||||||
access_log=True if verbosity != 'error' else False,
|
access_log=True if verbosity != 'error' else False,
|
||||||
|
ws_ping_interval=None # We do this explicitly ourselves
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self._server = UvicornServer(uvconfig)
|
self._server = UvicornServer(uvconfig)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import WebSocket as FastAPIWebSocket
|
from fastapi import WebSocket as FastAPIWebSocket
|
||||||
|
@ -34,6 +35,8 @@ class WebSocketChannel:
|
||||||
self._serializer_cls = serializer_cls
|
self._serializer_cls = serializer_cls
|
||||||
|
|
||||||
self._subscriptions: List[str] = []
|
self._subscriptions: List[str] = []
|
||||||
|
self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
|
||||||
|
self._relay_task = asyncio.create_task(self.relay())
|
||||||
|
|
||||||
# Internal event to signify a closed websocket
|
# Internal event to signify a closed websocket
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
@ -48,12 +51,18 @@ class WebSocketChannel:
|
||||||
def remote_addr(self):
|
def remote_addr(self):
|
||||||
return self._websocket.remote_addr
|
return self._websocket.remote_addr
|
||||||
|
|
||||||
async def send(self, data):
|
async def _send(self, data):
|
||||||
"""
|
"""
|
||||||
Send data on the wrapped websocket
|
Send data on the wrapped websocket
|
||||||
"""
|
"""
|
||||||
await self._wrapped_ws.send(data)
|
await self._wrapped_ws.send(data)
|
||||||
|
|
||||||
|
async def send(self, data):
|
||||||
|
"""
|
||||||
|
Add the data to the queue to be sent
|
||||||
|
"""
|
||||||
|
self.queue.put_nowait(data)
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
"""
|
"""
|
||||||
Receive data on the wrapped websocket
|
Receive data on the wrapped websocket
|
||||||
|
@ -72,6 +81,7 @@ class WebSocketChannel:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._closed = True
|
self._closed = True
|
||||||
|
self._relay_task.cancel()
|
||||||
|
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -95,6 +105,26 @@ class WebSocketChannel:
|
||||||
"""
|
"""
|
||||||
return message_type in self._subscriptions
|
return message_type in self._subscriptions
|
||||||
|
|
||||||
|
async def relay(self):
|
||||||
|
"""
|
||||||
|
Relay messages from the channel's queue and send them out. This is started
|
||||||
|
as a task.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
message = await self.queue.get()
|
||||||
|
try:
|
||||||
|
await self._send(message)
|
||||||
|
self.queue.task_done()
|
||||||
|
|
||||||
|
# Limit messages per sec.
|
||||||
|
# Could cause problems with queue size if too low, and
|
||||||
|
# problems with network traffik if too high.
|
||||||
|
# 0.001 = 1000/s
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
except RuntimeError:
|
||||||
|
# The connection was closed, just exit the task
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class ChannelManager:
|
class ChannelManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -155,11 +185,11 @@ class ChannelManager:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
message_type = data.get('type')
|
message_type = data.get('type')
|
||||||
for websocket, channel in self.channels.copy().items():
|
for websocket, channel in self.channels.copy().items():
|
||||||
try:
|
|
||||||
if channel.subscribed_to(message_type):
|
if channel.subscribed_to(message_type):
|
||||||
|
if not channel.queue.full():
|
||||||
await channel.send(data)
|
await channel.send(data)
|
||||||
except RuntimeError:
|
else:
|
||||||
# Handle cannot send after close cases
|
logger.info(f"Channel {channel} is too far behind, disconnecting")
|
||||||
await self.on_disconnect(websocket)
|
await self.on_disconnect(websocket)
|
||||||
|
|
||||||
async def send_direct(self, channel, data):
|
async def send_direct(self, channel, data):
|
||||||
|
|
|
@ -62,7 +62,7 @@ class ExternalMessageConsumer:
|
||||||
self.enabled = self._emc_config.get('enabled', False)
|
self.enabled = self._emc_config.get('enabled', False)
|
||||||
self.producers: List[Producer] = self._emc_config.get('producers', [])
|
self.producers: List[Producer] = self._emc_config.get('producers', [])
|
||||||
|
|
||||||
self.wait_timeout = self._emc_config.get('wait_timeout', 300) # in seconds
|
self.wait_timeout = self._emc_config.get('wait_timeout', 30) # in seconds
|
||||||
self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds
|
self.ping_timeout = self._emc_config.get('ping_timeout', 10) # in seconds
|
||||||
self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds
|
self.sleep_time = self._emc_config.get('sleep_time', 10) # in seconds
|
||||||
|
|
||||||
|
@ -174,6 +174,7 @@ class ExternalMessageConsumer:
|
||||||
:param producer: Dictionary containing producer info
|
:param producer: Dictionary containing producer info
|
||||||
:param lock: An asyncio Lock
|
:param lock: An asyncio Lock
|
||||||
"""
|
"""
|
||||||
|
channel = None
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
host, port = producer['host'], producer['port']
|
host, port = producer['host'], producer['port']
|
||||||
|
@ -182,7 +183,11 @@ class ExternalMessageConsumer:
|
||||||
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
|
ws_url = f"ws://{host}:{port}/api/v1/message/ws?token={token}"
|
||||||
|
|
||||||
# This will raise InvalidURI if the url is bad
|
# This will raise InvalidURI if the url is bad
|
||||||
async with websockets.connect(ws_url, max_size=self.message_size_limit) as ws:
|
async with websockets.connect(
|
||||||
|
ws_url,
|
||||||
|
max_size=self.message_size_limit,
|
||||||
|
ping_interval=None
|
||||||
|
) as ws:
|
||||||
channel = WebSocketChannel(ws, channel_id=name)
|
channel = WebSocketChannel(ws, channel_id=name)
|
||||||
|
|
||||||
logger.info(f"Producer connection success - {channel}")
|
logger.info(f"Producer connection success - {channel}")
|
||||||
|
@ -224,6 +229,10 @@ class ExternalMessageConsumer:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if channel:
|
||||||
|
await channel.close()
|
||||||
|
|
||||||
async def _receive_messages(
|
async def _receive_messages(
|
||||||
self,
|
self,
|
||||||
channel: WebSocketChannel,
|
channel: WebSocketChannel,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user