freqtrade_origin/freqtrade/rpc/api_server/ws/channel.py
2022-11-14 20:27:45 -07:00

349 lines
10 KiB
Python

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Type, Union
from uuid import uuid4
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import (HybridJSONWebSocketSerializer,
WebSocketSerializer)
from freqtrade.rpc.api_server.ws.types import WebSocketType
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
logger = logging.getLogger(__name__)
class WebSocketChannel:
    """
    Object to help facilitate managing a websocket connection

    The raw websocket is wrapped in a ``WebSocketProxy`` and then in a
    serializer; the channel additionally tracks whether it has been closed
    and which message types it is subscribed to.
    """

    def __init__(
        self,
        websocket: WebSocketType,
        channel_id: Optional[str] = None,
        serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer,
        throttle: float = 0.01
    ):
        """
        :param websocket: The websocket object to wrap
        :param channel_id: Identifier of this channel, a random 8 character
            hex string is generated if omitted
        :param serializer_cls: Serializer class wrapped around the proxied websocket
        :param throttle: Seconds to sleep before every outgoing message
        """
        self.channel_id = channel_id if channel_id else uuid4().hex[:8]
        self._websocket = WebSocketProxy(websocket)

        # Start with no subscriptions so `subscribed_to` can be called
        # safely before `set_subscriptions`
        self._subscriptions: List[str] = []

        # Internal event to signify a closed websocket
        self._closed = asyncio.Event()

        # Throttle how fast we send messages
        self._throttle = throttle

        # Wrap the WebSocket in the Serializing class
        self._wrapped_ws = serializer_cls(self._websocket)

    def __repr__(self):
        return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"

    @property
    def raw_websocket(self):
        """
        The raw websocket object as exposed by the proxy
        """
        return self._websocket.raw_websocket

    @property
    def remote_addr(self):
        """
        The remote address as exposed by the proxy
        """
        return self._websocket.remote_addr

    async def send(self, message: Union[WSMessageSchemaType, Dict[str, Any]]):
        """
        Send a message on the wrapped websocket
        :param message: The message to send
        """
        # Sleep first so consecutive sends are spaced by at least the throttle
        await asyncio.sleep(self._throttle)
        await self._wrapped_ws.send(message)

    async def recv(self):
        """
        Receive a message on the wrapped websocket
        :returns: Whatever the serializer's `recv` returns
        """
        return await self._wrapped_ws.recv()

    async def ping(self):
        """
        Ping the websocket
        """
        return await self._websocket.ping()

    async def accept(self):
        """
        Accept the underlying websocket connection
        """
        return await self._websocket.accept()

    async def close(self):
        """
        Close the WebSocketChannel

        Closing is best-effort: errors raised by the underlying websocket
        are logged at debug level and not propagated.
        """
        try:
            await self._websocket.close()
        except Exception:
            logger.debug("Error while closing channel %s", self.channel_id, exc_info=True)
        finally:
            # Always flag the channel as closed, even if the close call
            # failed or the task was cancelled while closing
            self._closed.set()

    def is_closed(self) -> bool:
        """
        Closed flag
        :returns: True once `close` has been called
        """
        return self._closed.is_set()

    def set_subscriptions(self, subscriptions: Optional[List[str]] = None) -> None:
        """
        Set which subscriptions this channel is subscribed to
        :param subscriptions: List of subscriptions, List[str].
            Omitting it (or passing None) clears the subscriptions
        """
        # A fresh list per call avoids sharing one default list between channels
        self._subscriptions = [] if subscriptions is None else subscriptions

    def subscribed_to(self, message_type: str) -> bool:
        """
        Check if this channel is subscribed to the message_type
        :param message_type: The message type to check
        :returns: True if message_type is in the current subscriptions
        """
        return message_type in self._subscriptions

    async def __aiter__(self):
        """
        Generator for received messages

        Yields messages until `recv` raises, which ends the iteration.
        """
        while True:
            try:
                yield await self.recv()
            except Exception:
                # Any receive error ends the iteration instead of propagating
                break

    @asynccontextmanager
    async def connect(self):
        """
        Context manager for safely opening and closing the websocket connection

        Accepts the connection on entry and always closes the channel on exit.
        """
        try:
            await self.accept()
            yield self
        finally:
            await self.close()
# class WebSocketChannel:
# """
# Object to help facilitate managing a websocket connection
# """
# def __init__(
# self,
# websocket: WebSocketType,
# channel_id: Optional[str] = None,
# drain_timeout: int = 3,
# throttle: float = 0.01,
# serializer_cls: Type[WebSocketSerializer] = HybridJSONWebSocketSerializer
# ):
# self.channel_id = channel_id if channel_id else uuid4().hex[:8]
# # The WebSocket object
# self._websocket = WebSocketProxy(websocket)
# self.drain_timeout = drain_timeout
# self.throttle = throttle
# self._subscriptions: List[str] = []
# # 32 is the size of the receiving queue in websockets package
# self.queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=32)
# self._relay_task = asyncio.create_task(self.relay())
# # Internal event to signify a closed websocket
# self._closed = asyncio.Event()
# # Wrap the WebSocket in the Serializing class
# self._wrapped_ws = serializer_cls(self._websocket)
# def __repr__(self):
# return f"WebSocketChannel({self.channel_id}, {self.remote_addr})"
# @property
# def raw_websocket(self):
# return self._websocket.raw_websocket
# @property
# def remote_addr(self):
# return self._websocket.remote_addr
# async def _send(self, data):
# """
# Send data on the wrapped websocket
# """
# await self._wrapped_ws.send(data)
# async def send(self, data) -> bool:
# """
# Add the data to the queue to be sent.
# :returns: True if data added to queue, False otherwise
# """
# # This block only runs if the queue is full, it will wait
# # until self.drain_timeout for the relay to drain the outgoing queue
# # We can't use asyncio.wait_for here because the queue may have been created with a
# # different eventloop
# start = time.time()
# while self.queue.full():
# await asyncio.sleep(1)
# if (time.time() - start) > self.drain_timeout:
# return False
# # If for some reason the queue is still full, just return False
# try:
# self.queue.put_nowait(data)
# except asyncio.QueueFull:
# return False
# # If we got here everything is ok
# return True
# async def recv(self):
# """
# Receive data on the wrapped websocket
# """
# return await self._wrapped_ws.recv()
# async def ping(self):
# """
# Ping the websocket
# """
# return await self._websocket.ping()
# async def close(self):
# """
# Close the WebSocketChannel
# """
# try:
# await self.raw_websocket.close()
# except Exception:
# pass
# self._closed.set()
# self._relay_task.cancel()
# def is_closed(self) -> bool:
# """
# Closed flag
# """
# return self._closed.is_set()
# def set_subscriptions(self, subscriptions: List[str] = []) -> None:
# """
# Set which subscriptions this channel is subscribed to
# :param subscriptions: List of subscriptions, List[str]
# """
# self._subscriptions = subscriptions
# def subscribed_to(self, message_type: str) -> bool:
# """
# Check if this channel is subscribed to the message_type
# :param message_type: The message type to check
# """
# return message_type in self._subscriptions
# async def relay(self):
# """
# Relay messages from the channel's queue and send them out. This is started
# as a task.
# """
# while not self._closed.is_set():
# message = await self.queue.get()
# try:
# await self._send(message)
# self.queue.task_done()
# # Limit messages per sec.
# # Could cause problems with queue size if too low, and
# # problems with network traffic if too high.
# # 0.01 = 100/s
# await asyncio.sleep(self.throttle)
# except RuntimeError:
# # The connection was closed, just exit the task
# return
# class ChannelManager:
# def __init__(self):
# self.channels = dict()
# self._lock = RLock() # Re-entrant Lock
# async def on_connect(self, websocket: WebSocketType):
# """
# Wrap websocket connection into Channel and add to list
# :param websocket: The WebSocket object to attach to the Channel
# """
# if isinstance(websocket, FastAPIWebSocket):
# try:
# await websocket.accept()
# except RuntimeError:
# # The connection was closed before we could accept it
# return
# ws_channel = WebSocketChannel(websocket)
# with self._lock:
# self.channels[websocket] = ws_channel
# return ws_channel
# async def on_disconnect(self, websocket: WebSocketType):
# """
# Call close on the channel if it's not already closed, and remove it from the channel list
# :param websocket: The WebSocket object attached to the Channel
# """
# with self._lock:
# channel = self.channels.get(websocket)
# if channel:
# logger.info(f"Disconnecting channel {channel}")
# if not channel.is_closed():
# await channel.close()
# del self.channels[websocket]
# async def disconnect_all(self):
# """
# Disconnect all Channels
# """
# with self._lock:
# for websocket in self.channels.copy().keys():
# await self.on_disconnect(websocket)
# async def broadcast(self, message: WSMessageSchemaType):
# """
# Broadcast a message on all Channels
# :param message: The message to send
# """
# with self._lock:
# for channel in self.channels.copy().values():
# if channel.subscribed_to(message.get('type')):
# await self.send_direct(channel, message)
# async def send_direct(
# self, channel: WebSocketChannel, message: Union[WSMessageSchemaType, Dict[str, Any]]):
# """
# Send a message directly through channel only
# :param channel: The WebSocketChannel object to send the message through
# :param message: The message to send
# """
# if not await channel.send(message):
# await self.on_disconnect(channel.raw_websocket)
# def has_channels(self):
# """
# Flag for more than 0 channels
# """
# return len(self.channels) > 0