initial rework separating server and client impl

This commit is contained in:
Timothy Pogue 2022-08-29 13:41:15 -06:00
parent 8c4e68b8eb
commit 7952e0df25
25 changed files with 1329 additions and 1068 deletions

3
.gitignore vendored
View File

@ -115,6 +115,3 @@ target/
!config_examples/config_freqai.example.json !config_examples/config_freqai.example.json
!config_examples/config_leader.example.json !config_examples/config_leader.example.json
!config_examples/config_follower.example.json !config_examples/config_follower.example.json
*-config.json
*.db*

View File

@ -1,87 +0,0 @@
{
"db_url": "sqlite:///follower.db",
"strategy": "SampleStrategy",
"max_open_trades": 3,
"stake_currency": "USDT",
"stake_amount": 100,
"tradable_balance_ratio": 0.99,
"fiat_display_currency": "USD",
"dry_run": true,
"cancel_open_orders_on_exit": false,
"trading_mode": "spot",
"margin_mode": "",
"unfilledtimeout": {
"entry": 10,
"exit": 10,
"exit_timeout_count": 0,
"unit": "minutes"
},
"entry_pricing": {
"price_side": "same",
"use_order_book": true,
"order_book_top": 1,
"price_last_balance": 0.0,
"check_depth_of_market": {
"enabled": false,
"bids_to_ask_delta": 1
}
},
"exit_pricing":{
"price_side": "same",
"use_order_book": true,
"order_book_top": 1
},
"exchange": {
"name": "kucoin",
"key": "",
"secret": "",
"password": "",
"ccxt_config": {},
"ccxt_async_config": {},
"pair_whitelist": [
],
"pair_blacklist": [
]
},
"pairlists": [
{
"method": "ExternalPairList", // ExternalPairList is required in follower mode
"number_assets": 5, // We can limit the amount of pairs to use from the leaders
}
],
"telegram": {
"enabled": false,
"token": "",
"chat_id": ""
},
"api_server": {
"enabled": true,
"listen_ip_address": "127.0.0.1",
"listen_port": 8081,
"verbosity": "error",
"enable_openapi": false,
"jwt_secret_key": "fcc24d31d6581ad2c90c3fc438c8a8b2ccce1393126959934568707f0bd2d647",
"CORS_origins": [],
"username": "freqtrader",
"password": "testing123"
},
"external_signal": {
"enabled": true,
"mode": "follower",
"leaders": [
{
"url": "ws://localhost:8080/signals/ws",
"api_token": "testtoken"
}
],
"wait_data_policy": "all", // ['all', 'first', none] defaults to all
"remove_signals_analyzed_df": true, // Remove entry/exit signals from Leader df, Defaults to false
},
"bot_name": "freqtrade",
"initial_state": "running",
"force_entry_enable": false,
"internals": {
"process_throttle_secs": 5,
}
}

View File

@ -1,97 +0,0 @@
{
"db_url": "sqlite:///leader.db",
"strategy": "SampleStrategy",
"max_open_trades": 3,
"stake_currency": "USDT",
"stake_amount": 100,
"tradable_balance_ratio": 0.99,
"fiat_display_currency": "USD",
"dry_run": true,
"cancel_open_orders_on_exit": false,
"trading_mode": "spot",
"margin_mode": "",
"unfilledtimeout": {
"entry": 10,
"exit": 10,
"exit_timeout_count": 0,
"unit": "minutes"
},
"entry_pricing": {
"price_side": "same",
"use_order_book": true,
"order_book_top": 1,
"price_last_balance": 0.0,
"check_depth_of_market": {
"enabled": false,
"bids_to_ask_delta": 1
}
},
"exit_pricing":{
"price_side": "same",
"use_order_book": true,
"order_book_top": 1
},
"exchange": {
"name": "kucoin",
"key": "",
"secret": "",
"password": "",
"ccxt_config": {},
"ccxt_async_config": {},
"pair_whitelist": [
],
"pair_blacklist": [
]
},
"pairlists": [
{
"method": "VolumePairList",
"number_assets": 20,
"sort_key": "quoteVolume",
"min_value": 0,
"refresh_period": 1800
}
],
"edge": {
"enabled": false,
"process_throttle_secs": 3600,
"calculate_since_number_of_days": 7,
"allowed_risk": 0.01,
"stoploss_range_min": -0.01,
"stoploss_range_max": -0.1,
"stoploss_range_step": -0.01,
"minimum_winrate": 0.60,
"minimum_expectancy": 0.20,
"min_trade_number": 10,
"max_trade_duration_minute": 1440,
"remove_pumps": false
},
"telegram": {
"enabled": false,
"token": "",
"chat_id": ""
},
"api_server": {
"enabled": true,
"listen_ip_address": "127.0.0.1",
"listen_port": 8080,
"verbosity": "error",
"enable_openapi": false,
"jwt_secret_key": "fcc24d31d6581ad2c90c3fc438c8a8b2ccce1393126959934568707f0bd2d647",
"CORS_origins": [],
"username": "freqtrader",
"password": "testing123"
},
"external_signal": {
"enabled": true,
"mode": "leader",
"api_token": "testtoken",
},
"bot_name": "freqtrade",
"initial_state": "running",
"force_entry_enable": false,
"internals": {
"process_throttle_secs": 5,
}
}

View File

@ -404,6 +404,7 @@ CONF_SCHEMA = {
}, },
'username': {'type': 'string'}, 'username': {'type': 'string'},
'password': {'type': 'string'}, 'password': {'type': 'string'},
'api_token': {'type': 'string'},
'jwt_secret_key': {'type': 'string'}, 'jwt_secret_key': {'type': 'string'},
'CORS_origins': {'type': 'array', 'items': {'type': 'string'}}, 'CORS_origins': {'type': 'array', 'items': {'type': 'string'}},
'verbosity': {'type': 'string', 'enum': ['error', 'info']}, 'verbosity': {'type': 'string', 'enum': ['error', 'info']},

View File

@ -1,7 +1,8 @@
from enum import Enum from enum import Enum
class RPCMessageType(Enum): # We need to inherit from str so we can use as a str
class RPCMessageType(str, Enum):
STATUS = 'status' STATUS = 'status'
WARNING = 'warning' WARNING = 'warning'
STARTUP = 'startup' STARTUP = 'startup'
@ -19,7 +20,7 @@ class RPCMessageType(Enum):
STRATEGY_MSG = 'strategy_msg' STRATEGY_MSG = 'strategy_msg'
EMIT_DATA = 'emit_data' WHITELIST = 'whitelist'
def __repr__(self): def __repr__(self):
return self.value return self.value

View File

@ -17,13 +17,13 @@ from freqtrade.constants import BuySell, LongShort
from freqtrade.data.converter import order_book_to_dataframe from freqtrade.data.converter import order_book_to_dataframe
from freqtrade.data.dataprovider import DataProvider from freqtrade.data.dataprovider import DataProvider
from freqtrade.edge import Edge from freqtrade.edge import Edge
from freqtrade.enums import (ExitCheckTuple, ExitType, LeaderMessageType, RPCMessageType, RunMode, from freqtrade.enums import (ExitCheckTuple, ExitType, RPCMessageType, RunMode, SignalDirection,
SignalDirection, State, TradingMode) State, TradingMode)
from freqtrade.exceptions import (DependencyException, ExchangeError, InsufficientFundsError, from freqtrade.exceptions import (DependencyException, ExchangeError, InsufficientFundsError,
InvalidOrderException, PricingError) InvalidOrderException, PricingError)
from freqtrade.exchange import timeframe_to_minutes, timeframe_to_seconds from freqtrade.exchange import timeframe_to_minutes, timeframe_to_seconds
from freqtrade.exchange.exchange import timeframe_to_next_date from freqtrade.exchange.exchange import timeframe_to_next_date
from freqtrade.misc import dataframe_to_json, safe_value_fallback, safe_value_fallback2 from freqtrade.misc import safe_value_fallback, safe_value_fallback2
from freqtrade.mixins import LoggingMixin from freqtrade.mixins import LoggingMixin
from freqtrade.persistence import Order, PairLocks, Trade, init_db from freqtrade.persistence import Order, PairLocks, Trade, init_db
from freqtrade.plugins.pairlistmanager import PairListManager from freqtrade.plugins.pairlistmanager import PairListManager
@ -75,8 +75,6 @@ class FreqtradeBot(LoggingMixin):
PairLocks.timeframe = self.config['timeframe'] PairLocks.timeframe = self.config['timeframe']
self.external_signal_controller = None
self.pairlists = PairListManager(self.exchange, self.config) self.pairlists = PairListManager(self.exchange, self.config)
# RPC runs in separate threads, can start handling external commands just after # RPC runs in separate threads, can start handling external commands just after
@ -194,27 +192,6 @@ class FreqtradeBot(LoggingMixin):
strategy_safe_wrapper(self.strategy.bot_loop_start, supress_error=True)() strategy_safe_wrapper(self.strategy.bot_loop_start, supress_error=True)()
if self.external_signal_controller:
if not self.external_signal_controller.is_leader():
# Run Follower mode analyzing
leader_pairs = self.pairlists._whitelist
self.strategy.analyze_external(self.active_pair_whitelist, leader_pairs)
else:
# We are leader, make sure to pass callback func to emit data
def emit_on_finish(pair, dataframe, timeframe, candle_type):
logger.debug(f"Emitting dataframe for {pair}")
return self.rpc.emit_data(
{
"data_type": LeaderMessageType.analyzed_df,
"data": {
"key": (pair, timeframe, candle_type),
"value": dataframe_to_json(dataframe)
}
}
)
self.strategy.analyze(self.active_pair_whitelist, finish_callback=emit_on_finish)
else:
self.strategy.analyze(self.active_pair_whitelist) self.strategy.analyze(self.active_pair_whitelist)
with self._exit_lock: with self._exit_lock:
@ -278,15 +255,7 @@ class FreqtradeBot(LoggingMixin):
self.pairlists.refresh_pairlist() self.pairlists.refresh_pairlist()
_whitelist = self.pairlists.whitelist _whitelist = self.pairlists.whitelist
# If external signal leader, broadcast whitelist data self.rpc.send_msg({'type': RPCMessageType.WHITELIST, 'msg': _whitelist})
# Should we broadcast before trade pairs are added?
if self.external_signal_controller:
if self.external_signal_controller.is_leader():
self.rpc.emit_data({
"data_type": LeaderMessageType.pairlist,
"data": _whitelist
})
# Calculating Edge positioning # Calculating Edge positioning
if self.edge: if self.edge:

View File

@ -1,8 +1,10 @@
import logging
import secrets import secrets
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Union
import jwt import jwt
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, WebSocket, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi.security.http import HTTPBasic, HTTPBasicCredentials from fastapi.security.http import HTTPBasic, HTTPBasicCredentials
@ -10,6 +12,8 @@ from freqtrade.rpc.api_server.api_schemas import AccessAndRefreshToken, AccessTo
from freqtrade.rpc.api_server.deps import get_api_config from freqtrade.rpc.api_server.deps import get_api_config
logger = logging.getLogger(__name__)
ALGORITHM = "HS256" ALGORITHM = "HS256"
router_login = APIRouter() router_login = APIRouter()
@ -44,6 +48,24 @@ def get_user_from_token(token, secret_key: str, token_type: str = "access"):
return username return username
# This should be reimplemented to better realign with the existing tools provided
# by FastAPI regarding API Tokens
async def get_ws_token(
ws: WebSocket,
token: Union[str, None] = None,
api_config: Dict[str, Any] = Depends(get_api_config)
):
secret_ws_token = api_config['ws_token']
if token == secret_ws_token:
# Just return the token if it matches
return token
else:
logger.debug("Denying websocket request")
# If it doesn't match, close the websocket connection
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
def create_token(data: dict, secret_key: str, token_type: str = "access") -> str: def create_token(data: dict, secret_key: str, token_type: str = "access") -> str:
to_encode = data.copy() to_encode = data.copy()
if token_type == "access": if token_type == "access":

View File

@ -0,0 +1,52 @@
import logging
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from freqtrade.rpc.api_server.deps import get_channel_manager
from freqtrade.rpc.api_server.ws.utils import is_websocket_alive
logger = logging.getLogger(__name__)
# Private router, protected by API Key authentication
router = APIRouter()
@router.websocket("/message/ws")
async def message_endpoint(
ws: WebSocket,
channel_manager=Depends(get_channel_manager)
):
try:
if is_websocket_alive(ws):
logger.info(f"Consumer connected - {ws.client}")
# TODO:
# Return a channel ID, pass that instead of ws to the rest of the methods
channel = await channel_manager.on_connect(ws)
# Keep connection open until explicitly closed, and sleep
try:
while not channel.is_closed():
request = await channel.recv()
# This is where we'd parse the request. For now this should only
# be a list of topics to subscribe too. List[str]
# Maybe allow the consumer to update the topics subscribed
# during runtime?
logger.info(f"Consumer request - {request}")
except WebSocketDisconnect:
# Handle client disconnects
logger.info(f"Consumer disconnected - {ws.client}")
await channel_manager.on_disconnect(ws)
except Exception as e:
logger.info(f"Consumer connection failed - {ws.client}")
logger.exception(e)
# Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent')
await channel_manager.on_disconnect(ws)
except Exception:
logger.error(f"Failed to serve - {ws.client}")
await channel_manager.on_disconnect(ws)

View File

@ -41,6 +41,10 @@ def get_exchange(config=Depends(get_config)):
return ApiServer._exchange return ApiServer._exchange
def get_channel_manager():
return ApiServer._channel_manager
def is_webserver_mode(config=Depends(get_config)): def is_webserver_mode(config=Depends(get_config)):
if config['runmode'] != RunMode.WEBSERVER: if config['runmode'] != RunMode.WEBSERVER:
raise RPCException('Bot is not in the correct state') raise RPCException('Bot is not in the correct state')

View File

@ -1,15 +1,20 @@
import asyncio
import logging import logging
from ipaddress import IPv4Address from ipaddress import IPv4Address
from threading import Thread
from typing import Any, Dict from typing import Any, Dict
import orjson import orjson
import uvicorn import uvicorn
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
# Look into alternatives
from janus import Queue as ThreadedQueue
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws.channel import ChannelManager
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
@ -43,6 +48,10 @@ class ApiServer(RPCHandler):
_config: Dict[str, Any] = {} _config: Dict[str, Any] = {}
# Exchange - only available in webserver mode. # Exchange - only available in webserver mode.
_exchange = None _exchange = None
# websocket message queue stuff
_channel_manager = None
_thread = None
_loop = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
@ -64,10 +73,15 @@ class ApiServer(RPCHandler):
return return
self._standalone: bool = standalone self._standalone: bool = standalone
self._server = None self._server = None
self._queue = None
self._background_task = None
ApiServer.__initialized = True ApiServer.__initialized = True
api_config = self._config['api_server'] api_config = self._config['api_server']
ApiServer._channel_manager = ChannelManager()
self.app = FastAPI(title="Freqtrade API", self.app = FastAPI(title="Freqtrade API",
docs_url='/docs' if api_config.get('enable_openapi', False) else None, docs_url='/docs' if api_config.get('enable_openapi', False) else None,
redoc_url=None, redoc_url=None,
@ -95,6 +109,18 @@ class ApiServer(RPCHandler):
logger.info("Stopping API Server") logger.info("Stopping API Server")
self._server.cleanup() self._server.cleanup()
if self._thread and self._loop:
logger.info("Stopping API Server background tasks")
if self._background_task:
# Cancel the queue task
self._background_task.cancel()
# Finally stop the loop
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
@classmethod @classmethod
def shutdown(cls): def shutdown(cls):
cls.__initialized = False cls.__initialized = False
@ -104,7 +130,10 @@ class ApiServer(RPCHandler):
cls._rpc = None cls._rpc = None
def send_msg(self, msg: Dict[str, str]) -> None: def send_msg(self, msg: Dict[str, str]) -> None:
pass if self._queue:
logger.info(f"Adding message to queue: {msg}")
sync_q = self._queue.sync_q
sync_q.put(msg)
def handle_rpc_exception(self, request, exc): def handle_rpc_exception(self, request, exc):
logger.exception(f"API Error calling: {exc}") logger.exception(f"API Error calling: {exc}")
@ -114,10 +143,12 @@ class ApiServer(RPCHandler):
) )
def configure_app(self, app: FastAPI, config): def configure_app(self, app: FastAPI, config):
from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login from freqtrade.rpc.api_server.api_auth import (get_ws_token, http_basic_or_jwt_token,
router_login)
from freqtrade.rpc.api_server.api_backtest import router as api_backtest from freqtrade.rpc.api_server.api_backtest import router as api_backtest
from freqtrade.rpc.api_server.api_v1 import router as api_v1 from freqtrade.rpc.api_server.api_v1 import router as api_v1
from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public
from freqtrade.rpc.api_server.api_ws import router as ws_router
from freqtrade.rpc.api_server.web_ui import router_ui from freqtrade.rpc.api_server.web_ui import router_ui
app.include_router(api_v1_public, prefix="/api/v1") app.include_router(api_v1_public, prefix="/api/v1")
@ -128,6 +159,9 @@ class ApiServer(RPCHandler):
app.include_router(api_backtest, prefix="/api/v1", app.include_router(api_backtest, prefix="/api/v1",
dependencies=[Depends(http_basic_or_jwt_token)], dependencies=[Depends(http_basic_or_jwt_token)],
) )
app.include_router(ws_router, prefix="/api/v1",
dependencies=[Depends(get_ws_token)]
)
app.include_router(router_login, prefix="/api/v1", tags=["auth"]) app.include_router(router_login, prefix="/api/v1", tags=["auth"])
# UI Router MUST be last! # UI Router MUST be last!
app.include_router(router_ui, prefix='') app.include_router(router_ui, prefix='')
@ -142,6 +176,43 @@ class ApiServer(RPCHandler):
app.add_exception_handler(RPCException, self.handle_rpc_exception) app.add_exception_handler(RPCException, self.handle_rpc_exception)
def start_message_queue(self):
# Create a new loop, as it'll be just for the background thread
self._loop = asyncio.new_event_loop()
# Start the thread
if not self._thread:
self._thread = Thread(target=self._loop.run_forever)
self._thread.start()
else:
raise RuntimeError("Threaded loop is already running")
# Finally, submit the coro to the thread
self._background_task = asyncio.run_coroutine_threadsafe(
self._broadcast_queue_data(), loop=self._loop)
async def _broadcast_queue_data(self):
# Instantiate the queue in this coroutine so it's attached to our loop
self._queue = ThreadedQueue()
async_queue = self._queue.async_q
try:
while True:
logger.debug("Getting queue data...")
# Get data from queue
data = await async_queue.get()
logger.debug(f"Found data: {data}")
# Broadcast it
await self._channel_manager.broadcast(data)
# Sleep, make this configurable?
await asyncio.sleep(0.1)
except asyncio.CancelledError:
# Silently stop
pass
# For testing, shouldn't happen when stable
except Exception as e:
logger.info(f"Exception happened in background task: {e}")
def start_api(self): def start_api(self):
""" """
Start API ... should be run in thread. Start API ... should be run in thread.
@ -179,6 +250,7 @@ class ApiServer(RPCHandler):
if self._standalone: if self._standalone:
self._server.run() self._server.run()
else: else:
self.start_message_queue()
self._server.run_in_thread() self._server.run_in_thread()
except Exception: except Exception:
logger.exception("Api server failed to start.") logger.exception("Api server failed to start.")

View File

@ -0,0 +1,146 @@
import logging
from threading import RLock
from typing import Type
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
from freqtrade.rpc.api_server.ws.serializer import ORJSONWebSocketSerializer, WebSocketSerializer
from freqtrade.rpc.api_server.ws.types import WebSocketType
logger = logging.getLogger(__name__)
class WebSocketChannel:
"""
Object to help facilitate managing a websocket connection
"""
def __init__(
self,
websocket: WebSocketType,
serializer_cls: Type[WebSocketSerializer] = ORJSONWebSocketSerializer
):
# The WebSocket object
self._websocket = WebSocketProxy(websocket)
# The Serializing class for the WebSocket object
self._serializer_cls = serializer_cls
# Internal event to signify a closed websocket
self._closed = False
# Wrap the WebSocket in the Serializing class
self._wrapped_ws = self._serializer_cls(self._websocket)
async def send(self, data):
"""
Send data on the wrapped websocket
"""
# logger.info(f"Serialized Send - {self._wrapped_ws._serialize(data)}")
await self._wrapped_ws.send(data)
async def recv(self):
"""
Receive data on the wrapped websocket
"""
return await self._wrapped_ws.recv()
async def ping(self):
"""
Ping the websocket
"""
return await self._websocket.ping()
async def close(self):
"""
Close the WebSocketChannel
"""
self._closed = True
def is_closed(self):
return self._closed
class ChannelManager:
def __init__(self):
self.channels = dict()
self._lock = RLock() # Re-entrant Lock
async def on_connect(self, websocket: WebSocketType):
"""
Wrap websocket connection into Channel and add to list
:param websocket: The WebSocket object to attach to the Channel
"""
if hasattr(websocket, "accept"):
try:
await websocket.accept()
except RuntimeError:
# The connection was closed before we could accept it
return
ws_channel = WebSocketChannel(websocket)
with self._lock:
self.channels[websocket] = ws_channel
return ws_channel
async def on_disconnect(self, websocket: WebSocketType):
"""
Call close on the channel if it's not, and remove from channel list
:param websocket: The WebSocket objet attached to the Channel
"""
with self._lock:
channel = self.channels.get(websocket)
if channel:
logger.debug(f"Disconnecting channel - {channel}")
if not channel.is_closed():
await channel.close()
del self.channels[websocket]
async def disconnect_all(self):
"""
Disconnect all Channels
"""
with self._lock:
for websocket, channel in self.channels.items():
if not channel.is_closed():
await channel.close()
self.channels = dict()
async def broadcast(self, data):
"""
Broadcast data on all Channels
:param data: The data to send
"""
with self._lock:
logger.debug(f"Broadcasting data: {data}")
for websocket, channel in self.channels.items():
try:
await channel.send(data)
except RuntimeError:
# Handle cannot send after close cases
await self.on_disconnect(websocket)
async def send_direct(self, channel, data):
"""
Send data directly through direct_channel only
:param direct_channel: The WebSocketChannel object to send data through
:param data: The data to send
"""
# We iterate over the channels to get reference to the websocket object
# so we can disconnect incase of failure
await channel.send(data)
def has_channels(self):
"""
Flag for more than 0 channels
"""
return len(self.channels) > 0

View File

@ -0,0 +1,61 @@
from typing import Union
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
from freqtrade.rpc.api_server.ws.types import WebSocketType
class WebSocketProxy:
"""
WebSocketProxy object to bring the FastAPIWebSocket and websockets.WebSocketClientProtocol
under the same API
"""
def __init__(self, websocket: WebSocketType):
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
async def send(self, data):
"""
Send data on the wrapped websocket
"""
if isinstance(data, str):
data = data.encode()
if hasattr(self._websocket, "send_bytes"):
await self._websocket.send_bytes(data)
else:
await self._websocket.send(data)
async def recv(self):
"""
Receive data on the wrapped websocket
"""
if hasattr(self._websocket, "receive_bytes"):
return await self._websocket.receive_bytes()
else:
return await self._websocket.recv()
async def ping(self):
"""
Ping the websocket, not supported by FastAPI WebSockets
"""
if hasattr(self._websocket, "ping"):
return await self._websocket.ping()
return False
async def close(self, code: int = 1000):
"""
Close the websocket connection, only supported by FastAPI WebSockets
"""
if hasattr(self._websocket, "close"):
return await self._websocket.close(code)
pass
async def accept(self):
"""
Accept the WebSocket connection, only support by FastAPI WebSockets
"""
if hasattr(self._websocket, "accept"):
return await self._websocket.accept()
pass

View File

@ -0,0 +1,65 @@
import json
import logging
from abc import ABC, abstractmethod
import msgpack
import orjson
from freqtrade.rpc.api_server.ws.proxy import WebSocketProxy
logger = logging.getLogger(__name__)
class WebSocketSerializer(ABC):
def __init__(self, websocket: WebSocketProxy):
self._websocket: WebSocketProxy = websocket
@abstractmethod
def _serialize(self, data):
raise NotImplementedError()
@abstractmethod
def _deserialize(self, data):
raise NotImplementedError()
async def send(self, data: bytes):
await self._websocket.send(self._serialize(data))
async def recv(self) -> bytes:
data = await self._websocket.recv()
return self._deserialize(data)
async def close(self, code: int = 1000):
await self._websocket.close(code)
# Going to explore using MsgPack as the serialization,
# as that might be the best method for sending pandas
# dataframes over the wire
class JSONWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data):
return json.dumps(data)
def _deserialize(self, data):
return json.loads(data)
class ORJSONWebSocketSerializer(WebSocketSerializer):
ORJSON_OPTIONS = orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY
def _serialize(self, data):
return orjson.dumps(data, option=self.ORJSON_OPTIONS)
def _deserialize(self, data):
return orjson.loads(data, option=self.ORJSON_OPTIONS)
class MsgPackWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data):
return msgpack.packb(data, use_bin_type=True)
def _deserialize(self, data):
return msgpack.unpackb(data, raw=False)

View File

@ -0,0 +1,8 @@
from typing import Any, Dict, TypeVar
from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)
MessageType = Dict[str, Any]

View File

@ -0,0 +1,12 @@
from fastapi import WebSocket
# fastapi does not make this available through it, so import directly from starlette
from starlette.websockets import WebSocketState
async def is_websocket_alive(ws: WebSocket) -> bool:
if (
ws.application_state == WebSocketState.CONNECTED and
ws.client_state == WebSocketState.CONNECTED
):
return True
return False

View File

@ -1,5 +1,5 @@
# flake8: noqa: F401 # # flake8: noqa: F401
from freqtrade.rpc.external_signal.controller import ExternalSignalController # from freqtrade.rpc.external_signal.controller import ExternalSignalController
#
#
__all__ = ('ExternalSignalController') # __all__ = ('ExternalSignalController')

View File

@ -1,145 +1,145 @@
import logging # import logging
from threading import RLock # from threading import RLock
from typing import Type # from typing import Type
#
from freqtrade.rpc.external_signal.proxy import WebSocketProxy # from freqtrade.rpc.external_signal.proxy import WebSocketProxy
from freqtrade.rpc.external_signal.serializer import MsgPackWebSocketSerializer, WebSocketSerializer # from freqtrade.rpc.external_signal.serializer import MsgPackWebSocketSerializer
from freqtrade.rpc.external_signal.types import WebSocketType # from freqtrade.rpc.external_signal.types import WebSocketType
#
#
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
#
#
class WebSocketChannel: # class WebSocketChannel:
""" # """
Object to help facilitate managing a websocket connection # Object to help facilitate managing a websocket connection
""" # """
#
def __init__( # def __init__(
self, # self,
websocket: WebSocketType, # websocket: WebSocketType,
serializer_cls: Type[WebSocketSerializer] = MsgPackWebSocketSerializer # serializer_cls: Type[WebSocketSerializer] = MsgPackWebSocketSerializer
): # ):
# The WebSocket object # # The WebSocket object
self._websocket = WebSocketProxy(websocket) # self._websocket = WebSocketProxy(websocket)
# The Serializing class for the WebSocket object # # The Serializing class for the WebSocket object
self._serializer_cls = serializer_cls # self._serializer_cls = serializer_cls
#
# Internal event to signify a closed websocket # # Internal event to signify a closed websocket
self._closed = False # self._closed = False
#
# Wrap the WebSocket in the Serializing class # # Wrap the WebSocket in the Serializing class
self._wrapped_ws = self._serializer_cls(self._websocket) # self._wrapped_ws = self._serializer_cls(self._websocket)
#
async def send(self, data): # async def send(self, data):
""" # """
Send data on the wrapped websocket # Send data on the wrapped websocket
""" # """
# logger.info(f"Serialized Send - {self._wrapped_ws._serialize(data)}") # # logger.info(f"Serialized Send - {self._wrapped_ws._serialize(data)}")
await self._wrapped_ws.send(data) # await self._wrapped_ws.send(data)
#
async def recv(self): # async def recv(self):
""" # """
Receive data on the wrapped websocket # Receive data on the wrapped websocket
""" # """
return await self._wrapped_ws.recv() # return await self._wrapped_ws.recv()
#
async def ping(self): # async def ping(self):
""" # """
Ping the websocket # Ping the websocket
""" # """
return await self._websocket.ping() # return await self._websocket.ping()
#
async def close(self): # async def close(self):
""" # """
Close the WebSocketChannel # Close the WebSocketChannel
""" # """
#
self._closed = True # self._closed = True
#
def is_closed(self): # def is_closed(self):
return self._closed # return self._closed
#
#
class ChannelManager: # class ChannelManager:
def __init__(self): # def __init__(self):
self.channels = dict() # self.channels = dict()
self._lock = RLock() # Re-entrant Lock # self._lock = RLock() # Re-entrant Lock
#
async def on_connect(self, websocket: WebSocketType): # async def on_connect(self, websocket: WebSocketType):
""" # """
Wrap websocket connection into Channel and add to list # Wrap websocket connection into Channel and add to list
#
:param websocket: The WebSocket object to attach to the Channel # :param websocket: The WebSocket object to attach to the Channel
""" # """
if hasattr(websocket, "accept"): # if hasattr(websocket, "accept"):
try: # try:
await websocket.accept() # await websocket.accept()
except RuntimeError: # except RuntimeError:
# The connection was closed before we could accept it # # The connection was closed before we could accept it
return # return
#
ws_channel = WebSocketChannel(websocket) # ws_channel = WebSocketChannel(websocket)
#
with self._lock: # with self._lock:
self.channels[websocket] = ws_channel # self.channels[websocket] = ws_channel
#
return ws_channel # return ws_channel
#
async def on_disconnect(self, websocket: WebSocketType): # async def on_disconnect(self, websocket: WebSocketType):
""" # """
Call close on the channel if it's not, and remove from channel list # Call close on the channel if it's not, and remove from channel list
#
:param websocket: The WebSocket objet attached to the Channel # :param websocket: The WebSocket objet attached to the Channel
""" # """
with self._lock: # with self._lock:
channel = self.channels.get(websocket) # channel = self.channels.get(websocket)
if channel: # if channel:
logger.debug(f"Disconnecting channel - {channel}") # logger.debug(f"Disconnecting channel - {channel}")
#
if not channel.is_closed(): # if not channel.is_closed():
await channel.close() # await channel.close()
#
del self.channels[websocket] # del self.channels[websocket]
#
async def disconnect_all(self): # async def disconnect_all(self):
""" # """
Disconnect all Channels # Disconnect all Channels
""" # """
with self._lock: # with self._lock:
for websocket, channel in self.channels.items(): # for websocket, channel in self.channels.items():
if not channel.is_closed(): # if not channel.is_closed():
await channel.close() # await channel.close()
#
self.channels = dict() # self.channels = dict()
#
async def broadcast(self, data): # async def broadcast(self, data):
""" # """
Broadcast data on all Channels # Broadcast data on all Channels
#
:param data: The data to send # :param data: The data to send
""" # """
with self._lock: # with self._lock:
for websocket, channel in self.channels.items(): # for websocket, channel in self.channels.items():
try: # try:
await channel.send(data) # await channel.send(data)
except RuntimeError: # except RuntimeError:
# Handle cannot send after close cases # # Handle cannot send after close cases
await self.on_disconnect(websocket) # await self.on_disconnect(websocket)
#
async def send_direct(self, channel, data): # async def send_direct(self, channel, data):
""" # """
Send data directly through direct_channel only # Send data directly through direct_channel only
#
:param direct_channel: The WebSocketChannel object to send data through # :param direct_channel: The WebSocketChannel object to send data through
:param data: The data to send # :param data: The data to send
""" # """
# We iterate over the channels to get reference to the websocket object # # We iterate over the channels to get reference to the websocket object
# so we can disconnect incase of failure # # so we can disconnect incase of failure
await channel.send(data) # await channel.send(data)
#
def has_channels(self): # def has_channels(self):
""" # """
Flag for more than 0 channels # Flag for more than 0 channels
""" # """
return len(self.channels) > 0 # return len(self.channels) > 0

View File

@ -1,449 +1,449 @@
""" # """
This module manages replicate mode communication # This module manages replicate mode communication
""" # """
import asyncio # import asyncio
import logging # import logging
import secrets # import secrets
import socket # import socket
from threading import Thread # from threading import Thread
from typing import Any, Callable, Coroutine, Dict, Union # from typing import Any, Callable, Coroutine, Dict, Union
#
import websockets # import websockets
from fastapi import Depends # from fastapi import Depends
from fastapi import WebSocket as FastAPIWebSocket # from fastapi import WebSocket as FastAPIWebSocket
from fastapi import WebSocketDisconnect, status # from fastapi import WebSocketDisconnect, status
from janus import Queue as ThreadedQueue # from janus import Queue as ThreadedQueue
#
from freqtrade.enums import ExternalSignalModeType, LeaderMessageType, RPCMessageType # from freqtrade.enums import ExternalSignalModeType, LeaderMessageType, RPCMessageType
from freqtrade.rpc import RPC, RPCHandler # from freqtrade.rpc import RPC, RPCHandler
from freqtrade.rpc.external_signal.channel import ChannelManager # from freqtrade.rpc.external_signal.channel import ChannelManager
from freqtrade.rpc.external_signal.types import MessageType # from freqtrade.rpc.external_signal.types import MessageType
from freqtrade.rpc.external_signal.utils import is_websocket_alive # from freqtrade.rpc.external_signal.utils import is_websocket_alive
#
#
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
#
#
class ExternalSignalController(RPCHandler): # class ExternalSignalController(RPCHandler):
""" This class handles all websocket communication """ # """ This class handles all websocket communication """
#
def __init__( # def __init__(
self, # self,
rpc: RPC, # rpc: RPC,
config: Dict[str, Any], # config: Dict[str, Any],
api_server: Union[Any, None] = None # api_server: Union[Any, None] = None
) -> None: # ) -> None:
""" # """
Init the ExternalSignalController class, and init the super class RPCHandler # Init the ExternalSignalController class, and init the super class RPCHandler
:param rpc: instance of RPC Helper class # :param rpc: instance of RPC Helper class
:param config: Configuration object # :param config: Configuration object
:param api_server: The ApiServer object # :param api_server: The ApiServer object
:return: None # :return: None
""" # """
super().__init__(rpc, config) # super().__init__(rpc, config)
#
self.freqtrade = rpc._freqtrade # self.freqtrade = rpc._freqtrade
self.api_server = api_server # self.api_server = api_server
#
if not self.api_server: # if not self.api_server:
raise RuntimeError("The API server must be enabled for external signals to work") # raise RuntimeError("The API server must be enabled for external signals to work")
#
self._loop = None # self._loop = None
self._running = False # self._running = False
self._thread = None # self._thread = None
self._queue = None # self._queue = None
#
self._main_task = None # self._main_task = None
self._sub_tasks = None # self._sub_tasks = None
#
self._message_handlers = { # self._message_handlers = {
LeaderMessageType.pairlist: self._rpc._handle_pairlist_message, # LeaderMessageType.pairlist: self._rpc._handle_pairlist_message,
LeaderMessageType.analyzed_df: self._rpc._handle_analyzed_df_message, # LeaderMessageType.analyzed_df: self._rpc._handle_analyzed_df_message,
LeaderMessageType.default: self._rpc._handle_default_message # LeaderMessageType.default: self._rpc._handle_default_message
} # }
#
self.channel_manager = ChannelManager() # self.channel_manager = ChannelManager()
self.external_signal_config = config.get('external_signal', {}) # self.external_signal_config = config.get('external_signal', {})
#
# What the config should look like # # What the config should look like
# "external_signal": { # # "external_signal": {
# "enabled": true, # # "enabled": true,
# "mode": "follower", # # "mode": "follower",
# "leaders": [ # # "leaders": [
# { # # {
# "url": "ws://localhost:8080/signals/ws", # # "url": "ws://localhost:8080/signals/ws",
# "api_token": "test" # # "api_token": "test"
# } # # }
# ] # # ]
# } # # }
#
# "external_signal": { # # "external_signal": {
# "enabled": true, # # "enabled": true,
# "mode": "leader", # # "mode": "leader",
# "api_token": "test" # # "api_token": "test"
# } # # }
#
self.mode = ExternalSignalModeType[ # self.mode = ExternalSignalModeType[
self.external_signal_config.get('mode', 'leader').lower() # self.external_signal_config.get('mode', 'leader').lower()
] # ]
#
self.leaders_list = self.external_signal_config.get('leaders', []) # self.leaders_list = self.external_signal_config.get('leaders', [])
self.push_throttle_secs = self.external_signal_config.get('push_throttle_secs', 0.1) # self.push_throttle_secs = self.external_signal_config.get('push_throttle_secs', 0.1)
#
self.reply_timeout = self.external_signal_config.get('follower_reply_timeout', 10) # self.reply_timeout = self.external_signal_config.get('follower_reply_timeout', 10)
self.ping_timeout = self.external_signal_config.get('follower_ping_timeout', 2) # self.ping_timeout = self.external_signal_config.get('follower_ping_timeout', 2)
self.sleep_time = self.external_signal_config.get('follower_sleep_time', 5) # self.sleep_time = self.external_signal_config.get('follower_sleep_time', 5)
#
# Validate external_signal_config here? # # Validate external_signal_config here?
#
if self.mode == ExternalSignalModeType.follower and len(self.leaders_list) == 0: # if self.mode == ExternalSignalModeType.follower and len(self.leaders_list) == 0:
raise ValueError("You must specify at least 1 leader in follower mode.") # raise ValueError("You must specify at least 1 leader in follower mode.")
#
# This is only used by the leader, the followers use the tokens specified # # This is only used by the leader, the followers use the tokens specified
# in each of the leaders # # in each of the leaders
# If you do not specify an API key in the config, one will be randomly # # If you do not specify an API key in the config, one will be randomly
# generated and logged on startup # # generated and logged on startup
default_api_key = secrets.token_urlsafe(16) # default_api_key = secrets.token_urlsafe(16)
self.secret_api_key = self.external_signal_config.get('api_token', default_api_key) # self.secret_api_key = self.external_signal_config.get('api_token', default_api_key)
#
self.start() # self.start()
#
def is_leader(self): # def is_leader(self):
""" # """
Leader flag # Leader flag
""" # """
return self.enabled() and self.mode == ExternalSignalModeType.leader # return self.enabled() and self.mode == ExternalSignalModeType.leader
#
def enabled(self): # def enabled(self):
""" # """
Enabled flag # Enabled flag
""" # """
return self.external_signal_config.get('enabled', False) # return self.external_signal_config.get('enabled', False)
#
def num_leaders(self): # def num_leaders(self):
""" # """
The number of leaders we should be connected to # The number of leaders we should be connected to
""" # """
return len(self.leaders_list) # return len(self.leaders_list)
#
def start_threaded_loop(self): # def start_threaded_loop(self):
""" # """
Start the main internal loop in another thread to run coroutines # Start the main internal loop in another thread to run coroutines
""" # """
self._loop = asyncio.new_event_loop() # self._loop = asyncio.new_event_loop()
#
if not self._thread: # if not self._thread:
self._thread = Thread(target=self._loop.run_forever) # self._thread = Thread(target=self._loop.run_forever)
self._thread.start() # self._thread.start()
self._running = True # self._running = True
else: # else:
raise RuntimeError("A loop is already running") # raise RuntimeError("A loop is already running")
#
def submit_coroutine(self, coroutine: Coroutine): # def submit_coroutine(self, coroutine: Coroutine):
""" # """
Submit a coroutine to the threaded loop # Submit a coroutine to the threaded loop
""" # """
if not self._running: # if not self._running:
raise RuntimeError("Cannot schedule new futures after shutdown") # raise RuntimeError("Cannot schedule new futures after shutdown")
#
if not self._loop or not self._loop.is_running(): # if not self._loop or not self._loop.is_running():
raise RuntimeError("Loop must be started before any function can" # raise RuntimeError("Loop must be started before any function can"
" be submitted") # " be submitted")
#
return asyncio.run_coroutine_threadsafe(coroutine, self._loop) # return asyncio.run_coroutine_threadsafe(coroutine, self._loop)
#
def start(self): # def start(self):
""" # """
Start the controller main loop # Start the controller main loop
""" # """
self.start_threaded_loop() # self.start_threaded_loop()
self._main_task = self.submit_coroutine(self.main()) # self._main_task = self.submit_coroutine(self.main())
#
async def shutdown(self): # async def shutdown(self):
""" # """
Shutdown all tasks and close up # Shutdown all tasks and close up
""" # """
logger.info("Stopping rpc.externalsignalcontroller") # logger.info("Stopping rpc.externalsignalcontroller")
#
# Flip running flag # # Flip running flag
self._running = False # self._running = False
#
# Cancel sub tasks # # Cancel sub tasks
for task in self._sub_tasks: # for task in self._sub_tasks:
task.cancel() # task.cancel()
#
# Then disconnect all channels # # Then disconnect all channels
await self.channel_manager.disconnect_all() # await self.channel_manager.disconnect_all()
#
def cleanup(self) -> None: # def cleanup(self) -> None:
""" # """
Cleanup pending module resources. # Cleanup pending module resources.
""" # """
if self._thread: # if self._thread:
if self._loop.is_running(): # if self._loop.is_running():
self._main_task.cancel() # self._main_task.cancel()
self._thread.join() # self._thread.join()
#
async def main(self): # async def main(self):
""" # """
Main coro # Main coro
#
Start the loop based on what mode we're in # Start the loop based on what mode we're in
""" # """
try: # try:
if self.mode == ExternalSignalModeType.leader: # if self.mode == ExternalSignalModeType.leader:
logger.info("Starting rpc.externalsignalcontroller in Leader mode") # logger.info("Starting rpc.externalsignalcontroller in Leader mode")
#
await self.run_leader_mode() # await self.run_leader_mode()
elif self.mode == ExternalSignalModeType.follower: # elif self.mode == ExternalSignalModeType.follower:
logger.info("Starting rpc.externalsignalcontroller in Follower mode") # logger.info("Starting rpc.externalsignalcontroller in Follower mode")
#
await self.run_follower_mode() # await self.run_follower_mode()
#
except asyncio.CancelledError: # except asyncio.CancelledError:
# We're cancelled # # We're cancelled
await self.shutdown() # await self.shutdown()
except Exception as e: # except Exception as e:
# Log the error # # Log the error
logger.error(f"Exception occurred in main task: {e}") # logger.error(f"Exception occurred in main task: {e}")
logger.exception(e) # logger.exception(e)
finally: # finally:
# This coroutine is the last thing to be ended, so it should stop the loop # # This coroutine is the last thing to be ended, so it should stop the loop
self._loop.stop() # self._loop.stop()
#
def log_api_token(self): # def log_api_token(self):
""" # """
Log the API token # Log the API token
""" # """
logger.info("-" * 15) # logger.info("-" * 15)
logger.info(f"API_KEY: {self.secret_api_key}") # logger.info(f"API_KEY: {self.secret_api_key}")
logger.info("-" * 15) # logger.info("-" * 15)
#
def send_msg(self, msg: MessageType) -> None: # def send_msg(self, msg: MessageType) -> None:
""" # """
Support RPC calls # Support RPC calls
""" # """
if msg["type"] == RPCMessageType.EMIT_DATA: # if msg["type"] == RPCMessageType.EMIT_DATA:
message = msg.get("message") # message = msg.get("message")
if message: # if message:
self.send_message(message) # self.send_message(message)
else: # else:
logger.error(f"Message is empty! {msg}") # logger.error(f"Message is empty! {msg}")
#
def send_message(self, msg: MessageType) -> None: # def send_message(self, msg: MessageType) -> None:
""" # """
Broadcast message over all channels if there are any # Broadcast message over all channels if there are any
""" # """
#
if self.channel_manager.has_channels(): # if self.channel_manager.has_channels():
self._send_message(msg) # self._send_message(msg)
else: # else:
logger.debug("No listening followers, skipping...") # logger.debug("No listening followers, skipping...")
pass # pass
#
def _send_message(self, msg: MessageType): # def _send_message(self, msg: MessageType):
""" # """
Add data to the internal queue to be broadcasted. This func will block # Add data to the internal queue to be broadcasted. This func will block
if the queue is full. This is meant to be called in the main thread. # if the queue is full. This is meant to be called in the main thread.
""" # """
if self._queue: # if self._queue:
queue = self._queue.sync_q # queue = self._queue.sync_q
queue.put(msg) # This will block if the queue is full # queue.put(msg) # This will block if the queue is full
else: # else:
logger.warning("Can not send data, leader loop has not started yet!") # logger.warning("Can not send data, leader loop has not started yet!")
#
async def send_initial_data(self, channel): # async def send_initial_data(self, channel):
logger.info("Sending initial data through channel") # logger.info("Sending initial data through channel")
#
data = self._rpc._initial_leader_data() # data = self._rpc._initial_leader_data()
#
for message in data: # for message in data:
await channel.send(message) # await channel.send(message)
#
async def _handle_leader_message(self, message: MessageType): # async def _handle_leader_message(self, message: MessageType):
""" # """
Handle message received from a Leader # Handle message received from a Leader
""" # """
type = message.get("data_type", LeaderMessageType.default) # type = message.get("data_type", LeaderMessageType.default)
data = message.get("data") # data = message.get("data")
#
handler: Callable = self._message_handlers[type] # handler: Callable = self._message_handlers[type]
handler(type, data) # handler(type, data)
#
# ---------------------------------------------------------------------- # # ----------------------------------------------------------------------
#
async def run_leader_mode(self): # async def run_leader_mode(self):
""" # """
Main leader coroutine # Main leader coroutine
#
This starts all of the leader coros and registers the endpoint on # This starts all of the leader coros and registers the endpoint on
the ApiServer # the ApiServer
""" # """
self.register_leader_endpoint() # self.register_leader_endpoint()
self.log_api_token() # self.log_api_token()
#
self._sub_tasks = [ # self._sub_tasks = [
self._loop.create_task(self._broadcast_queue_data()) # self._loop.create_task(self._broadcast_queue_data())
] # ]
#
return await asyncio.gather(*self._sub_tasks) # return await asyncio.gather(*self._sub_tasks)
#
async def run_follower_mode(self): # async def run_follower_mode(self):
""" # """
Main follower coroutine # Main follower coroutine
#
This starts all of the follower connection coros # This starts all of the follower connection coros
""" # """
#
rpc_lock = asyncio.Lock() # rpc_lock = asyncio.Lock()
#
self._sub_tasks = [ # self._sub_tasks = [
self._loop.create_task(self._handle_leader_connection(leader, rpc_lock)) # self._loop.create_task(self._handle_leader_connection(leader, rpc_lock))
for leader in self.leaders_list # for leader in self.leaders_list
] # ]
#
return await asyncio.gather(*self._sub_tasks) # return await asyncio.gather(*self._sub_tasks)
#
async def _broadcast_queue_data(self): # async def _broadcast_queue_data(self):
""" # """
Loop over queue data and broadcast it # Loop over queue data and broadcast it
""" # """
# Instantiate the queue in this coroutine so it's attached to our loop # # Instantiate the queue in this coroutine so it's attached to our loop
self._queue = ThreadedQueue() # self._queue = ThreadedQueue()
async_queue = self._queue.async_q # async_queue = self._queue.async_q
#
try: # try:
while self._running: # while self._running:
# Get data from queue # # Get data from queue
data = await async_queue.get() # data = await async_queue.get()
#
# Broadcast it to everyone # # Broadcast it to everyone
await self.channel_manager.broadcast(data) # await self.channel_manager.broadcast(data)
#
# Sleep # # Sleep
await asyncio.sleep(self.push_throttle_secs) # await asyncio.sleep(self.push_throttle_secs)
#
except asyncio.CancelledError: # except asyncio.CancelledError:
# Silently stop # # Silently stop
pass # pass
#
async def get_api_token( # async def get_api_token(
self, # self,
websocket: FastAPIWebSocket, # websocket: FastAPIWebSocket,
token: Union[str, None] = None # token: Union[str, None] = None
): # ):
""" # """
Extract the API key from query param. Must match the # Extract the API key from query param. Must match the
set secret_api_key or the websocket connection will be closed. # set secret_api_key or the websocket connection will be closed.
""" # """
if token == self.secret_api_key: # if token == self.secret_api_key:
return token # return token
else: # else:
logger.info("Denying websocket request...") # logger.info("Denying websocket request...")
await websocket.close(code=status.WS_1008_POLICY_VIOLATION) # await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
#
def register_leader_endpoint(self, path: str = "/signals/ws"): # def register_leader_endpoint(self, path: str = "/signals/ws"):
""" # """
Attach and start the main leader loop to the ApiServer # Attach and start the main leader loop to the ApiServer
#
:param path: The endpoint path # :param path: The endpoint path
""" # """
if not self.api_server: # if not self.api_server:
raise RuntimeError("The leader needs the ApiServer to be active") # raise RuntimeError("The leader needs the ApiServer to be active")
#
# The endpoint function for running the main leader loop # # The endpoint function for running the main leader loop
@self.api_server.app.websocket(path) # @self.api_server.app.websocket(path)
async def leader_endpoint( # async def leader_endpoint(
websocket: FastAPIWebSocket, # websocket: FastAPIWebSocket,
api_key: str = Depends(self.get_api_token) # api_key: str = Depends(self.get_api_token)
): # ):
await self.leader_endpoint_loop(websocket) # await self.leader_endpoint_loop(websocket)
#
async def leader_endpoint_loop(self, websocket: FastAPIWebSocket): # async def leader_endpoint_loop(self, websocket: FastAPIWebSocket):
""" # """
The WebSocket endpoint served by the ApiServer. This handles connections, # The WebSocket endpoint served by the ApiServer. This handles connections,
and adding them to the channel manager. # and adding them to the channel manager.
""" # """
try: # try:
if is_websocket_alive(websocket): # if is_websocket_alive(websocket):
logger.info(f"Follower connected - {websocket.client}") # logger.info(f"Follower connected - {websocket.client}")
channel = await self.channel_manager.on_connect(websocket) # channel = await self.channel_manager.on_connect(websocket)
#
# Send initial data here # # Send initial data here
# Data is being broadcasted right away as soon as startup, # # Data is being broadcasted right away as soon as startup,
# we may not have to send initial data at all. Further testing # # we may not have to send initial data at all. Further testing
# required. # # required.
await self.send_initial_data(channel) # await self.send_initial_data(channel)
#
# Keep connection open until explicitly closed, and sleep # # Keep connection open until explicitly closed, and sleep
try: # try:
while not channel.is_closed(): # while not channel.is_closed():
request = await channel.recv() # request = await channel.recv()
logger.info(f"Follower request - {request}") # logger.info(f"Follower request - {request}")
#
except WebSocketDisconnect: # except WebSocketDisconnect:
# Handle client disconnects # # Handle client disconnects
logger.info(f"Follower disconnected - {websocket.client}") # logger.info(f"Follower disconnected - {websocket.client}")
await self.channel_manager.on_disconnect(websocket) # await self.channel_manager.on_disconnect(websocket)
except Exception as e: # except Exception as e:
logger.info(f"Follower connection failed - {websocket.client}") # logger.info(f"Follower connection failed - {websocket.client}")
logger.exception(e) # logger.exception(e)
# Handle cases like - # # Handle cases like -
# RuntimeError('Cannot call "send" once a closed message has been sent') # # RuntimeError('Cannot call "send" once a closed message has been sent')
await self.channel_manager.on_disconnect(websocket) # await self.channel_manager.on_disconnect(websocket)
#
except Exception: # except Exception:
logger.error(f"Failed to serve - {websocket.client}") # logger.error(f"Failed to serve - {websocket.client}")
await self.channel_manager.on_disconnect(websocket) # await self.channel_manager.on_disconnect(websocket)
#
async def _handle_leader_connection(self, leader, lock): # async def _handle_leader_connection(self, leader, lock):
""" # """
Given a leader, connect and wait on data. If connection is lost, # Given a leader, connect and wait on data. If connection is lost,
it will attempt to reconnect. # it will attempt to reconnect.
""" # """
try: # try:
url, token = leader["url"], leader["api_token"] # url, token = leader["url"], leader["api_token"]
websocket_url = f"{url}?token={token}" # websocket_url = f"{url}?token={token}"
#
logger.info(f"Attempting to connect to Leader at: {url}") # logger.info(f"Attempting to connect to Leader at: {url}")
while True: # while True:
try: # try:
async with websockets.connect(websocket_url) as ws: # async with websockets.connect(websocket_url) as ws:
channel = await self.channel_manager.on_connect(ws) # channel = await self.channel_manager.on_connect(ws)
logger.info(f"Connection to Leader at {url} successful") # logger.info(f"Connection to Leader at {url} successful")
while True: # while True:
try: # try:
data = await asyncio.wait_for( # data = await asyncio.wait_for(
channel.recv(), # channel.recv(),
timeout=self.reply_timeout # timeout=self.reply_timeout
) # )
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): # except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
# We haven't received data yet. Check the connection and continue. # # We haven't received data yet. Check the connection and continue.
try: # try:
# ping # # ping
ping = await channel.ping() # ping = await channel.ping()
await asyncio.wait_for(ping, timeout=self.ping_timeout) # await asyncio.wait_for(ping, timeout=self.ping_timeout)
logger.debug(f"Connection to {url} still alive...") # logger.debug(f"Connection to {url} still alive...")
continue # continue
except Exception: # except Exception:
logger.info( # logger.info(
f"Ping error {url} - retrying in {self.sleep_time}s") # f"Ping error {url} - retrying in {self.sleep_time}s")
asyncio.sleep(self.sleep_time) # asyncio.sleep(self.sleep_time)
break # break
#
async with lock: # async with lock:
# Acquire lock so only 1 coro handling at a time # # Acquire lock so only 1 coro handling at a time
# as we call the RPC module in the main thread # # as we call the RPC module in the main thread
await self._handle_leader_message(data) # await self._handle_leader_message(data)
#
except (socket.gaierror, ConnectionRefusedError): # except (socket.gaierror, ConnectionRefusedError):
logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s") # logger.info(f"Connection Refused - retrying connection in {self.sleep_time}s")
await asyncio.sleep(self.sleep_time) # await asyncio.sleep(self.sleep_time)
continue # continue
except websockets.exceptions.InvalidStatusCode as e: # except websockets.exceptions.InvalidStatusCode as e:
logger.error(f"Connection Refused - {e}") # logger.error(f"Connection Refused - {e}")
await asyncio.sleep(self.sleep_time) # await asyncio.sleep(self.sleep_time)
continue # continue
#
except asyncio.CancelledError: # except asyncio.CancelledError:
pass # pass

View File

@ -1,61 +1,61 @@
from typing import Union # from typing import Union
#
from fastapi import WebSocket as FastAPIWebSocket # from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket # from websockets import WebSocketClientProtocol as WebSocket
#
from freqtrade.rpc.external_signal.types import WebSocketType # from freqtrade.rpc.external_signal.types import WebSocketType
#
#
class WebSocketProxy: # class WebSocketProxy:
""" # """
WebSocketProxy object to bring the FastAPIWebSocket and websockets.WebSocketClientProtocol # WebSocketProxy object to bring the FastAPIWebSocket and websockets.WebSocketClientProtocol
under the same API # under the same API
""" # """
#
def __init__(self, websocket: WebSocketType): # def __init__(self, websocket: WebSocketType):
self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket # self._websocket: Union[FastAPIWebSocket, WebSocket] = websocket
#
async def send(self, data): # async def send(self, data):
""" # """
Send data on the wrapped websocket # Send data on the wrapped websocket
""" # """
if isinstance(data, str): # if isinstance(data, str):
data = data.encode() # data = data.encode()
#
if hasattr(self._websocket, "send_bytes"): # if hasattr(self._websocket, "send_bytes"):
await self._websocket.send_bytes(data) # await self._websocket.send_bytes(data)
else: # else:
await self._websocket.send(data) # await self._websocket.send(data)
#
async def recv(self): # async def recv(self):
""" # """
Receive data on the wrapped websocket # Receive data on the wrapped websocket
""" # """
if hasattr(self._websocket, "receive_bytes"): # if hasattr(self._websocket, "receive_bytes"):
return await self._websocket.receive_bytes() # return await self._websocket.receive_bytes()
else: # else:
return await self._websocket.recv() # return await self._websocket.recv()
#
async def ping(self): # async def ping(self):
""" # """
Ping the websocket, not supported by FastAPI WebSockets # Ping the websocket, not supported by FastAPI WebSockets
""" # """
if hasattr(self._websocket, "ping"): # if hasattr(self._websocket, "ping"):
return await self._websocket.ping() # return await self._websocket.ping()
return False # return False
#
async def close(self, code: int = 1000): # async def close(self, code: int = 1000):
""" # """
Close the websocket connection, only supported by FastAPI WebSockets # Close the websocket connection, only supported by FastAPI WebSockets
""" # """
if hasattr(self._websocket, "close"): # if hasattr(self._websocket, "close"):
return await self._websocket.close(code) # return await self._websocket.close(code)
pass # pass
#
async def accept(self): # async def accept(self):
""" # """
Accept the WebSocket connection, only support by FastAPI WebSockets # Accept the WebSocket connection, only support by FastAPI WebSockets
""" # """
if hasattr(self._websocket, "accept"): # if hasattr(self._websocket, "accept"):
return await self._websocket.accept() # return await self._websocket.accept()
pass # pass

View File

@ -1,65 +1,65 @@
import json # import json
import logging # import logging
from abc import ABC, abstractmethod # from abc import ABC, abstractmethod
#
import msgpack # import msgpack
import orjson # import orjson
#
from freqtrade.rpc.external_signal.proxy import WebSocketProxy # from freqtrade.rpc.external_signal.proxy import WebSocketProxy
#
#
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
#
#
class WebSocketSerializer(ABC): # class WebSocketSerializer(ABC):
def __init__(self, websocket: WebSocketProxy): # def __init__(self, websocket: WebSocketProxy):
self._websocket: WebSocketProxy = websocket # self._websocket: WebSocketProxy = websocket
#
@abstractmethod # @abstractmethod
def _serialize(self, data): # def _serialize(self, data):
raise NotImplementedError() # raise NotImplementedError()
#
@abstractmethod # @abstractmethod
def _deserialize(self, data): # def _deserialize(self, data):
raise NotImplementedError() # raise NotImplementedError()
#
async def send(self, data: bytes): # async def send(self, data: bytes):
await self._websocket.send(self._serialize(data)) # await self._websocket.send(self._serialize(data))
#
async def recv(self) -> bytes: # async def recv(self) -> bytes:
data = await self._websocket.recv() # data = await self._websocket.recv()
#
return self._deserialize(data) # return self._deserialize(data)
#
async def close(self, code: int = 1000): # async def close(self, code: int = 1000):
await self._websocket.close(code) # await self._websocket.close(code)
#
# Going to explore using MsgPack as the serialization, # # Going to explore using MsgPack as the serialization,
# as that might be the best method for sending pandas # # as that might be the best method for sending pandas
# dataframes over the wire # # dataframes over the wire
#
#
class JSONWebSocketSerializer(WebSocketSerializer): # class JSONWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data): # def _serialize(self, data):
return json.dumps(data) # return json.dumps(data)
#
def _deserialize(self, data): # def _deserialize(self, data):
return json.loads(data) # return json.loads(data)
#
#
class ORJSONWebSocketSerializer(WebSocketSerializer): # class ORJSONWebSocketSerializer(WebSocketSerializer):
ORJSON_OPTIONS = orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY # ORJSON_OPTIONS = orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY
#
def _serialize(self, data): # def _serialize(self, data):
return orjson.dumps(data, option=self.ORJSON_OPTIONS) # return orjson.dumps(data, option=self.ORJSON_OPTIONS)
#
def _deserialize(self, data): # def _deserialize(self, data):
return orjson.loads(data, option=self.ORJSON_OPTIONS) # return orjson.loads(data, option=self.ORJSON_OPTIONS)
#
#
class MsgPackWebSocketSerializer(WebSocketSerializer): # class MsgPackWebSocketSerializer(WebSocketSerializer):
def _serialize(self, data): # def _serialize(self, data):
return msgpack.packb(data, use_bin_type=True) # return msgpack.packb(data, use_bin_type=True)
#
def _deserialize(self, data): # def _deserialize(self, data):
return msgpack.unpackb(data, raw=False) # return msgpack.unpackb(data, raw=False)

View File

@ -1,8 +1,8 @@
from typing import Any, Dict, TypeVar # from typing import Any, Dict, TypeVar
#
from fastapi import WebSocket as FastAPIWebSocket # from fastapi import WebSocket as FastAPIWebSocket
from websockets import WebSocketClientProtocol as WebSocket # from websockets import WebSocketClientProtocol as WebSocket
#
#
WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket) # WebSocketType = TypeVar("WebSocketType", FastAPIWebSocket, WebSocket)
MessageType = Dict[str, Any] # MessageType = Dict[str, Any]

View File

@ -1,22 +1,10 @@
from pandas import DataFrame # from starlette.websockets import WebSocket, WebSocketState
from starlette.websockets import WebSocket, WebSocketState #
#
from freqtrade.enums.signaltype import SignalTagType, SignalType # async def is_websocket_alive(ws: WebSocket) -> bool:
# if (
# ws.application_state == WebSocketState.CONNECTED and
async def is_websocket_alive(ws: WebSocket) -> bool: # ws.client_state == WebSocketState.CONNECTED
if ( # ):
ws.application_state == WebSocketState.CONNECTED and # return True
ws.client_state == WebSocketState.CONNECTED # return False
):
return True
return False
def remove_entry_exit_signals(dataframe: DataFrame):
dataframe[SignalType.ENTER_LONG.value] = 0
dataframe[SignalType.EXIT_LONG.value] = 0
dataframe[SignalType.ENTER_SHORT.value] = 0
dataframe[SignalType.EXIT_SHORT.value] = 0
dataframe[SignalTagType.ENTER_TAG.value] = None
dataframe[SignalTagType.EXIT_TAG.value] = None

View File

@ -19,13 +19,12 @@ from freqtrade.configuration.timerange import TimeRange
from freqtrade.constants import CANCEL_REASON, DATETIME_PRINT_FORMAT from freqtrade.constants import CANCEL_REASON, DATETIME_PRINT_FORMAT
from freqtrade.data.history import load_data from freqtrade.data.history import load_data
from freqtrade.data.metrics import calculate_max_drawdown from freqtrade.data.metrics import calculate_max_drawdown
from freqtrade.enums import (CandleType, ExitCheckTuple, ExitType, LeaderMessageType, from freqtrade.enums import (CandleType, ExitCheckTuple, ExitType, SignalDirection, State,
SignalDirection, State, TradingMode) TradingMode)
from freqtrade.exceptions import ExchangeError, PricingError from freqtrade.exceptions import ExchangeError, PricingError
from freqtrade.exchange import timeframe_to_minutes, timeframe_to_msecs from freqtrade.exchange import timeframe_to_minutes, timeframe_to_msecs
from freqtrade.loggers import bufferHandler from freqtrade.loggers import bufferHandler
from freqtrade.misc import (decimals_per_coin, json_to_dataframe, remove_entry_exit_signals, from freqtrade.misc import decimals_per_coin, shorten_date
shorten_date)
from freqtrade.persistence import PairLocks, Trade from freqtrade.persistence import PairLocks, Trade
from freqtrade.persistence.models import PairLock from freqtrade.persistence.models import PairLock
from freqtrade.plugins.pairlist.pairlist_helpers import expand_pairlist from freqtrade.plugins.pairlist.pairlist_helpers import expand_pairlist
@ -1090,65 +1089,65 @@ class RPC:
'last_process_loc': last_p.astimezone(tzlocal()).strftime(DATETIME_PRINT_FORMAT), 'last_process_loc': last_p.astimezone(tzlocal()).strftime(DATETIME_PRINT_FORMAT),
'last_process_ts': int(last_p.timestamp()), 'last_process_ts': int(last_p.timestamp()),
} }
#
# ------------------------------ EXTERNAL SIGNALS ----------------------- # # ------------------------------ EXTERNAL SIGNALS -----------------------
#
def _initial_leader_data(self): # def _initial_leader_data(self):
# We create a list of Messages to send to the follower on connect # # We create a list of Messages to send to the follower on connect
data = [] # data = []
#
# Send Pairlist data # # Send Pairlist data
data.append({ # data.append({
"data_type": LeaderMessageType.pairlist, # "data_type": LeaderMessageType.pairlist,
"data": self._freqtrade.pairlists._whitelist # "data": self._freqtrade.pairlists._whitelist
}) # })
#
return data # return data
#
def _handle_pairlist_message(self, type, data): # def _handle_pairlist_message(self, type, data):
""" # """
Handles the emitted pairlists from the Leaders # Handles the emitted pairlists from the Leaders
#
:param type: The data_type of the data # :param type: The data_type of the data
:param data: The data # :param data: The data
""" # """
pairlist = data # pairlist = data
#
logger.debug(f"Handling Pairlist message: {pairlist}") # logger.debug(f"Handling Pairlist message: {pairlist}")
#
external_pairlist = self._freqtrade.pairlists._pairlist_handlers[0] # external_pairlist = self._freqtrade.pairlists._pairlist_handlers[0]
external_pairlist.add_pairlist_data(pairlist) # external_pairlist.add_pairlist_data(pairlist)
#
def _handle_analyzed_df_message(self, type, data): # def _handle_analyzed_df_message(self, type, data):
""" # """
Handles the analyzed dataframes from the Leaders # Handles the analyzed dataframes from the Leaders
#
:param type: The data_type of the data # :param type: The data_type of the data
:param data: The data # :param data: The data
""" # """
key, value = data["key"], data["value"] # key, value = data["key"], data["value"]
pair, timeframe, candle_type = key # pair, timeframe, candle_type = key
#
# Skip any pairs that we don't have in the pairlist? # # Skip any pairs that we don't have in the pairlist?
# leader_pairlist = self._freqtrade.pairlists._whitelist # # leader_pairlist = self._freqtrade.pairlists._whitelist
# if pair not in leader_pairlist: # # if pair not in leader_pairlist:
# return # # return
#
dataframe = json_to_dataframe(value) # dataframe = json_to_dataframe(value)
#
if self._config.get('external_signal', {}).get('remove_signals_analyzed_df', False): # if self._config.get('external_signal', {}).get('remove_signals_analyzed_df', False):
dataframe = remove_entry_exit_signals(dataframe) # dataframe = remove_entry_exit_signals(dataframe)
#
logger.debug(f"Handling analyzed dataframe for {pair}") # logger.debug(f"Handling analyzed dataframe for {pair}")
logger.debug(dataframe.tail()) # logger.debug(dataframe.tail())
#
# Add the dataframe to the dataprovider # # Add the dataframe to the dataprovider
dataprovider = self._freqtrade.dataprovider # dataprovider = self._freqtrade.dataprovider
dataprovider.add_external_df(pair, timeframe, dataframe, candle_type) # dataprovider.add_external_df(pair, timeframe, dataframe, candle_type)
#
def _handle_default_message(self, type, data): # def _handle_default_message(self, type, data):
""" # """
Default leader message handler, just logs it. We should never have to # Default leader message handler, just logs it. We should never have to
run this unless the leader sends us some weird message. # run this unless the leader sends us some weird message.
""" # """
logger.debug(f"Received message from Leader of type {type}: {data}") # logger.debug(f"Received message from Leader of type {type}: {data}")

View File

@ -51,14 +51,14 @@ class RPCManager:
# Enable External Signals mode # Enable External Signals mode
# For this to be enabled, the API server must also be enabled # For this to be enabled, the API server must also be enabled
if config.get('external_signal', {}).get('enabled', False): # if config.get('external_signal', {}).get('enabled', False):
logger.info('Enabling RPC.ExternalSignalController') # logger.info('Enabling RPC.ExternalSignalController')
from freqtrade.rpc.external_signal import ExternalSignalController # from freqtrade.rpc.external_signal import ExternalSignalController
external_signals = ExternalSignalController(self._rpc, config, apiserver) # external_signals = ExternalSignalController(self._rpc, config, apiserver)
self.registered_modules.append(external_signals) # self.registered_modules.append(external_signals)
#
# Attach the controller to FreqTrade # # Attach the controller to FreqTrade
freqtrade.external_signal_controller = external_signals # freqtrade.external_signal_controller = external_signals
def cleanup(self) -> None: def cleanup(self) -> None:
""" Stops all enabled rpc modules """ """ Stops all enabled rpc modules """
@ -78,7 +78,6 @@ class RPCManager:
'status': 'stopping bot' 'status': 'stopping bot'
} }
""" """
if msg.get("type") != RPCMessageType.EMIT_DATA:
logger.info('Sending rpc message: %s', msg) logger.info('Sending rpc message: %s', msg)
if 'pair' in msg: if 'pair' in msg:
msg.update({ msg.update({
@ -138,12 +137,3 @@ class RPCManager:
'type': RPCMessageType.STARTUP, 'type': RPCMessageType.STARTUP,
'status': f'Using Protections: \n{prots}' 'status': f'Using Protections: \n{prots}'
}) })
def emit_data(self, data: Dict[str, Any]):
"""
Send a message via RPC with type RPCMessageType.EMIT_DATA
"""
self.send_msg({
"type": RPCMessageType.EMIT_DATA,
"message": data
})

58
scripts/test_ws_client.py Normal file
View File

@ -0,0 +1,58 @@
import asyncio
import logging
import socket
import websockets
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
async def _client():
try:
while True:
try:
url = "ws://localhost:8080/api/v1/message/ws?token=testtoken"
async with websockets.connect(url) as ws:
logger.info("Connection successful")
while True:
try:
data = await asyncio.wait_for(
ws.recv(),
timeout=5
)
logger.info(f"Data received - {data}")
except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
# We haven't received data yet. Check the connection and continue.
try:
# ping
ping = await ws.ping()
await asyncio.wait_for(ping, timeout=2)
logger.debug(f"Connection to {url} still alive...")
continue
except Exception:
logger.info(
f"Ping error {url} - retrying in 5s")
asyncio.sleep(2)
break
except (socket.gaierror, ConnectionRefusedError):
logger.info("Connection Refused - retrying connection in 5s")
await asyncio.sleep(2)
continue
except websockets.exceptions.InvalidStatusCode as e:
logger.error(f"Connection Refused - {e}")
await asyncio.sleep(2)
continue
except (asyncio.CancelledError, KeyboardInterrupt):
pass
def main():
asyncio.run(_client())
if __name__ == "__main__":
main()