freqtrade_origin/freqtrade/rpc/api_server/ws/channel.py
2024-05-13 07:10:25 +02:00

241 lines
7.3 KiB
Python

import asyncio
import logging
import time
from collections import deque
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Deque, Dict, List, Optional, Type, Union
from uuid import uuid4
from fastapi import WebSocketDisconnect
from websockets.exceptions import ConnectionClosed
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (
HybridJSONWebSocketSerializer,
WebSocketSerializer,
)
from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
# Module-level logger, named after this module via ``__name__``
logger = logging.getLogger(__name__)
class WebSocketChannel:
    """
    Object to help facilitate managing a websocket connection

    Wraps the raw websocket in a ``WebSocketProxy`` and a serializer, keeps
    the list of message types the channel is subscribed to, records how long
    recent sends took, and owns the asyncio tasks started for the connection.
    """

    def __init__(
        self,
        websocket: WebSocketType,
        channel_id: Optional[str] = None,
        serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer,
        send_throttle: float = 0.01,
    ):
        """
        :param websocket: The underlying websocket connection to wrap
        :param channel_id: Identifier for this channel; when falsy, the first
            8 hex characters of a random UUID4 are used instead
        :param serializer_cls: Serializer class, instantiated with the proxied websocket
        :param send_throttle: Seconds to sleep after every successful send
        """
        self.channel_id = channel_id if channel_id else uuid4().hex[:8]
        self._websocket = WebSocketProxy(websocket)

        # Internal event to signify a closed websocket
        self._closed = asyncio.Event()
        # The async tasks created for the channel
        self._channel_tasks: List[asyncio.Task] = []

        # Deque for average send times (only the 10 most recent are kept)
        self._send_times: Deque[float] = deque([], maxlen=10)
        # High limit defaults to 3 to start
        self._send_high_limit: float = 3
        self._send_throttle = send_throttle

        # The subscribed message types
        self._subscriptions: List[str] = []

        # Wrap the WebSocket in the Serializing class
        self._wrapped_ws = serializer_cls(self._websocket)

    def __repr__(self) -> str:
        return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"

    @property
    def raw_websocket(self):
        """
        The raw websocket object as exposed by the proxy
        """
        return self._websocket.raw_websocket

    @property
    def remote_addr(self):
        """
        The remote address as reported by the proxy
        """
        return self._websocket.remote_addr

    @property
    def avg_send_time(self) -> float:
        """
        Average duration in seconds of the recorded sends,
        or 0.0 if nothing has been sent yet.
        """
        if not self._send_times:
            # No samples yet - avoid dividing by zero
            return 0.0
        return sum(self._send_times) / len(self._send_times)

    def _calc_send_limit(self) -> None:
        """
        Calculate the send high limit for this channel
        """
        # Only update if we have enough data
        if len(self._send_times) == self._send_times.maxlen:
            # At least 1s or twice the average of send times, with a
            # maximum of 3 seconds per message
            self._send_high_limit = min(max(self.avg_send_time * 2, 1), 3)

    async def send(
        self, message: Union[WSMessageSchemaType, Dict[str, Any]], timeout: bool = False
    ):
        """
        Send a message on the wrapped websocket. If the sending
        takes too long, a TimeoutError is logged and re-raised so
        the caller can disconnect the connection.

        :param message: The message to send
        :param timeout: Enforce send high limit, defaults to False
        :raises asyncio.TimeoutError: If ``timeout`` is set and the send takes
            longer than the current send high limit
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can
        started = time.perf_counter()
        try:
            # If the send times out, it will raise
            # a TimeoutError and bubble up to the
            # message_endpoint to close the connection
            await asyncio.wait_for(
                self._wrapped_ws.send(message),
                timeout=self._send_high_limit if timeout else None,
            )
        except asyncio.TimeoutError:
            logger.info(f"Connection for {self} timed out, disconnecting")
            raise

        # Only successful sends contribute to the average
        self._send_times.append(time.perf_counter() - started)
        self._calc_send_limit()

        # Explicitly give control back to event loop as
        # websockets.send does not
        # Also throttles how fast we send
        await asyncio.sleep(self._send_throttle)

    async def recv(self):
        """
        Receive a message on the wrapped websocket
        """
        return await self._wrapped_ws.recv()

    async def ping(self):
        """
        Ping the websocket
        """
        return await self._websocket.ping()

    async def accept(self):
        """
        Accept the underlying websocket connection,
        if the connection has been closed before we can
        accept, just close the channel.
        """
        try:
            return await self._websocket.accept()
        except RuntimeError:
            await self.close()

    async def close(self) -> None:
        """
        Close the WebSocketChannel
        """
        # Flag the channel as closed first so that is_closed() is already
        # True while the underlying close is awaited
        self._closed.set()

        try:
            await self._websocket.close()
        except RuntimeError:
            # Best effort: a RuntimeError from closing the websocket is ignored
            pass

    def is_closed(self) -> bool:
        """
        Closed flag
        """
        return self._closed.is_set()

    def set_subscriptions(self, subscriptions: List[str]) -> None:
        """
        Set which subscriptions this channel is subscribed to

        :param subscriptions: List of subscriptions, List[str]
        """
        self._subscriptions = subscriptions

    def subscribed_to(self, message_type: str) -> bool:
        """
        Check if this channel is subscribed to the message_type

        :param message_type: The message type to check
        """
        return message_type in self._subscriptions

    async def run_channel_tasks(self, *tasks, **kwargs):
        """
        Create and await on the channel tasks unless an exception
        was raised, then cancel them all.

        Nothing is run if the channel is already closed.

        :params *tasks: All coros or tasks to be run concurrently
        :param **kwargs: Any extra kwargs to pass to gather
        :returns: The result of ``asyncio.gather``, or None if the channel
            was closed or an exception was raised by one of the tasks
        """
        if self.is_closed():
            # These coroutines will never be awaited; close them so Python
            # does not warn "coroutine ... was never awaited". Tasks that
            # are already scheduled are left untouched.
            for task in tasks:
                if asyncio.iscoroutine(task):
                    task.close()
            return None

        # Wrap the coros into tasks if they aren't already
        self._channel_tasks = [
            task if isinstance(task, asyncio.Task) else asyncio.create_task(task)
            for task in tasks
        ]

        try:
            return await asyncio.gather(*self._channel_tasks, **kwargs)
        except Exception:
            # If an exception occurred, cancel the rest of the tasks
            await self.cancel_channel_tasks()
        return None

    async def cancel_channel_tasks(self) -> None:
        """
        Cancel and wait on all channel tasks
        """
        for task in self._channel_tasks:
            task.cancel()

            # Wait for tasks to finish cancelling
            try:
                await task
            except (
                asyncio.CancelledError,
                asyncio.TimeoutError,
                WebSocketDisconnect,
                ConnectionClosed,
                RuntimeError,
            ):
                # Expected outcomes of cancelling / a dropped connection
                pass
            except Exception as e:
                logger.info(f"Encountered unknown exception: {e}", exc_info=e)

        self._channel_tasks = []

    async def __aiter__(self):
        """
        Generator for received messages
        """
        # We can not catch any errors here as websocket.recv is
        # the first to catch any disconnects and bubble it up
        # so the connection is garbage collected right away
        while not self.is_closed():
            yield await self.recv()
@asynccontextmanager
async def create_channel(websocket: WebSocketType, **kwargs) -> AsyncIterator[WebSocketChannel]:
    """
    Async context manager that wraps a websocket in a WebSocketChannel,
    attempts to accept the connection, yields the channel, and always
    closes it again when the context is left.

    :param websocket: The websocket to wrap
    :param **kwargs: Extra keyword arguments forwarded to WebSocketChannel
    """
    ws_channel = WebSocketChannel(websocket, **kwargs)
    try:
        await ws_channel.accept()
        logger.info(f"Connected to channel - {ws_channel}")

        yield ws_channel
    finally:
        # Runs on normal exit as well as when the body of the context raised
        await ws_channel.close()
        logger.info(f"Disconnected from channel - {ws_channel}")