freqtrade_origin/freqtrade/rpc/api_server/webserver.py
2022-11-14 20:27:45 -07:00

312 lines
11 KiB
Python

import asyncio
import logging
from ipaddress import IPv4Address
from typing import Any, Dict
import orjson
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
# TODO: look into alternatives to janus for the thread-safe queue
from janus import Queue as ThreadedQueue
from starlette.responses import JSONResponse
from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException
from freqtrade.rpc.api_server.uvicorn_threaded import UvicornServer
from freqtrade.rpc.api_server.ws.message_stream import MessageStream
from freqtrade.rpc.api_server.ws_schemas import WSMessageSchemaType
from freqtrade.rpc.rpc import RPC, RPCException, RPCHandler
logger = logging.getLogger(__name__)
class FTJSONResponse(JSONResponse):
    """
    Default response class of the API - serializes the payload with orjson.
    """
    media_type = "application/json"

    def render(self, content: Any) -> bytes:
        """
        Serialize the response body using orjson.
        numpy objects are supported through OPT_SERIALIZE_NUMPY.
        NaN and Inf / -Inf are emitted as `null` (orjson default, see orjson docs).
        :param content: Object to serialize
        :return: JSON document as bytes
        """
        serialize_options = orjson.OPT_SERIALIZE_NUMPY
        return orjson.dumps(content, option=serialize_options)
class ApiServer(RPCHandler):
    """
    RPC handler exposing the bot through a FastAPI application served by uvicorn.

    The class is a singleton: `__new__` always hands out the same instance and most
    state (rpc, config, backtesting objects, message stream) is stored on the class.
    Messages passed to `send_msg` are put on a janus queue and forwarded to the
    message stream by a background task running in the server's event loop.
    """

    __instance = None
    __initialized = False

    _rpc: RPC
    # Backtesting type: Backtesting
    _bt = None
    _bt_data = None
    _bt_timerange = None
    _bt_last_config: Config = {}
    _has_rpc: bool = False
    _bgtask_running: bool = False
    _config: Config = {}
    # Exchange - only available in webserver mode.
    _exchange = None
    # Websocket message stream - created in the startup event, dropped on shutdown.
    _message_stream = None

    def __new__(cls, *args, **kwargs):
        """
        This class is a singleton.
        We'll only have one instance of it around.
        """
        if ApiServer.__instance is None:
            ApiServer.__instance = object.__new__(cls)
            ApiServer.__initialized = False
        return ApiServer.__instance

    def __init__(self, config: Config, standalone: bool = False) -> None:
        """
        Build the FastAPI application and start the server.
        The configuration is always stored on the class. The remaining setup is skipped
        when the singleton is already initialized and either this call or the existing
        instance is standalone.
        :param config: Configuration - must contain an 'api_server' section
        :param standalone: Passed on to start_api(), which then calls the server's run()
            directly instead of run_in_thread()
        """
        ApiServer._config = config
        if self.__initialized and (standalone or self._standalone):
            return
        self._standalone: bool = standalone
        self._server = None
        self._ws_queue = None
        self._ws_publisher_task = None

        ApiServer.__initialized = True

        api_config = self._config['api_server']

        self.app = FastAPI(
            title="Freqtrade API",
            docs_url='/docs' if api_config.get('enable_openapi', False) else None,
            redoc_url=None,
            default_response_class=FTJSONResponse,
        )
        self.configure_app(self.app, self._config)

        self.start_api()

    def add_rpc_handler(self, rpc: RPC) -> None:
        """
        Attach rpc handler
        :param rpc: RPC instance to store on the class
        :raises OperationalException: if a handler is already attached
        """
        if not self._has_rpc:
            ApiServer._rpc = rpc
            ApiServer._has_rpc = True
        else:
            # This should not happen assuming we didn't mess up.
            raise OperationalException('RPC Handler already attached.')

    def cleanup(self) -> None:
        """ Cleanup pending module resources """
        ApiServer._has_rpc = False
        try:
            del ApiServer._rpc
        except AttributeError:
            # `_rpc` is only an annotation until add_rpc_handler() ran, so there is
            # nothing to detach if no handler was attached or cleanup already ran.
            pass
        if self._server and not self._standalone:
            logger.info("Stopping API Server")
            self._server.cleanup()

    @classmethod
    def shutdown(cls) -> None:
        """
        Reset the singleton state so the next instantiation creates and initializes
        a fresh instance.
        """
        cls.__initialized = False
        cls.__instance = None
        cls._has_rpc = False
        cls._rpc = None

    def send_msg(self, msg: Dict[str, Any]) -> None:
        """
        Put a message on the websocket queue (sync side of the janus queue).
        Messages are dropped while no queue exists.
        :param msg: Message to publish
        """
        # Read the attribute once - the shutdown event / publisher task reset it to None,
        # which could otherwise happen between the check and the use.
        ws_queue = self._ws_queue
        if ws_queue:
            ws_queue.sync_q.put(msg)

    def handle_rpc_exception(self, request, exc):
        """
        Exception handler registered for RPCException.
        Logs the error and answers with HTTP 502 and a JSON error description.
        :param request: Request that caused the exception
        :param exc: Raised RPCException
        :return: JSONResponse containing the error message
        """
        logger.exception(f"API Error calling: {exc}")
        return JSONResponse(
            status_code=502,
            content={'error': f"Error querying {request.url.path}: {exc.message}"}
        )

    def configure_app(self, app: FastAPI, config: Config) -> None:
        """
        Register routers, CORS middleware, the RPC exception handler and the
        startup / shutdown event handlers on the application.
        :param app: FastAPI application to configure
        :param config: Configuration - 'api_server.CORS_origins' is used for CORS
        """
        # Router modules are imported here instead of at module level
        # (presumably to avoid circular imports - TODO confirm).
        from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login
        from freqtrade.rpc.api_server.api_backtest import router as api_backtest
        from freqtrade.rpc.api_server.api_v1 import router as api_v1
        from freqtrade.rpc.api_server.api_v1 import router_public as api_v1_public
        from freqtrade.rpc.api_server.api_ws import router as ws_router
        from freqtrade.rpc.api_server.web_ui import router_ui

        app.include_router(api_v1_public, prefix="/api/v1")

        app.include_router(api_v1, prefix="/api/v1",
                           dependencies=[Depends(http_basic_or_jwt_token)],
                           )
        app.include_router(api_backtest, prefix="/api/v1",
                           dependencies=[Depends(http_basic_or_jwt_token)],
                           )
        app.include_router(ws_router, prefix="/api/v1")
        app.include_router(router_login, prefix="/api/v1", tags=["auth"])
        # UI Router MUST be last!
        app.include_router(router_ui, prefix='')

        app.add_middleware(
            CORSMiddleware,
            allow_origins=config['api_server'].get('CORS_origins', []),
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        app.add_exception_handler(RPCException, self.handle_rpc_exception)
        app.add_event_handler(
            event_type="startup",
            func=self._api_startup_event
        )
        app.add_event_handler(
            event_type="shutdown",
            func=self._api_shutdown_event
        )

    async def _api_startup_event(self) -> None:
        """
        Startup event handler.
        Creates the message stream, the queue and the publisher task if they don't exist.
        """
        if not ApiServer._message_stream:
            ApiServer._message_stream = MessageStream()

        if not self._ws_queue:
            self._ws_queue = ThreadedQueue()

        if not self._ws_publisher_task:
            self._ws_publisher_task = asyncio.create_task(
                self._publish_messages()
            )

    async def _api_shutdown_event(self) -> None:
        """
        Shutdown event handler.
        Drops the message stream and the queue and cancels the publisher task.
        """
        if ApiServer._message_stream:
            ApiServer._message_stream = None

        if self._ws_queue:
            self._ws_queue = None

        if self._ws_publisher_task:
            self._ws_publisher_task.cancel()
            # Forget the cancelled task, otherwise a later startup event on this
            # instance would skip creating a new publisher.
            self._ws_publisher_task = None

    async def _publish_messages(self) -> None:
        """
        Background task that reads messages from the queue and adds them
        to the message stream
        """
        try:
            async_queue = self._ws_queue.async_q
            message_stream = ApiServer._message_stream

            # `message_stream` is a local reference, so this loop only ends when the
            # task is cancelled or an exception is raised.
            while message_stream:
                message: WSMessageSchemaType = await async_queue.get()
                message_stream.publish(message)

                # Make sure to throttle how fast we
                # publish messages as some clients will be
                # slower than others
                await asyncio.sleep(0.01)
                async_queue.task_done()
        finally:
            # Stop send_msg() from queueing messages nobody will publish.
            self._ws_queue = None

    def start_api(self) -> None:
        """
        Start API ... should be run in thread.
        Logs security warnings for a non-loopback listen address, a missing password
        and a default `jwt_secret_key`, then starts the uvicorn server.
        Failures while creating / starting the server are logged, not raised.
        """
        api_config = self._config['api_server']
        rest_ip = api_config['listen_ip_address']
        rest_port = api_config['listen_port']

        logger.info(f'Starting HTTP Server at {rest_ip}:{rest_port}')
        if not IPv4Address(rest_ip).is_loopback:
            logger.warning("SECURITY WARNING - Local Rest Server listening to external connections")
            logger.warning("SECURITY WARNING - This is insecure please set to your loopback,"
                           "e.g 127.0.0.1 in config.json")

        if not api_config.get('password'):
            logger.warning("SECURITY WARNING - No password for local REST Server defined. "
                           "Please make sure that this is intentional!")

        # Must be a tuple of the known default keys - a single string would turn
        # `in` into a substring test.
        if (api_config.get('jwt_secret_key', 'super-secret')
                in ('super-secret', 'somethingrandom')):
            logger.warning("SECURITY WARNING - `jwt_secret_key` seems to be default."
                           "Others may be able to log into your bot.")

        logger.info('Starting Local Rest Server.')
        verbosity = api_config.get('verbosity', 'error')

        uvconfig = uvicorn.Config(self.app,
                                  port=rest_port,
                                  host=rest_ip,
                                  use_colors=False,
                                  log_config=None,
                                  access_log=verbosity != 'error',
                                  ws_ping_interval=None  # We do this explicitly ourselves
                                  )
        try:
            self._server = UvicornServer(uvconfig)
            if self._standalone:
                self._server.run()
            else:
                self._server.run_in_thread()
        except Exception:
            logger.exception("Api server failed to start.")