chore: update persistence to modern typing syntax

This commit is contained in:
Matthias 2024-10-04 06:55:05 +02:00
parent 2e69e38adb
commit e9a6ba03f9
6 changed files with 56 additions and 53 deletions

View File

@ -1,7 +1,8 @@
import json import json
import logging import logging
from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from typing import Any, ClassVar, List, Optional, Sequence from typing import Any, ClassVar, Optional
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -85,7 +86,7 @@ class CustomDataWrapper:
""" """
use_db = True use_db = True
custom_data: List[_CustomData] = [] custom_data: list[_CustomData] = []
unserialized_types = ["bool", "float", "int", "str"] unserialized_types = ["bool", "float", "int", "str"]
@staticmethod @staticmethod
@ -116,7 +117,7 @@ class CustomDataWrapper:
_CustomData.session.commit() _CustomData.session.commit()
@staticmethod @staticmethod
def get_custom_data(*, trade_id: int, key: Optional[str] = None) -> List[_CustomData]: def get_custom_data(*, trade_id: int, key: Optional[str] = None) -> list[_CustomData]:
if CustomDataWrapper.use_db: if CustomDataWrapper.use_db:
filters = [ filters = [
_CustomData.ft_trade_id == trade_id, _CustomData.ft_trade_id == trade_id,

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import List, Optional from typing import Optional
from sqlalchemy import inspect, select, text, update from sqlalchemy import inspect, select, text, update
@ -10,19 +10,19 @@ from freqtrade.persistence.trade_model import Order, Trade
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_table_names_for_table(inspector, tabletype) -> List[str]: def get_table_names_for_table(inspector, tabletype) -> list[str]:
return [t for t in inspector.get_table_names() if t.startswith(tabletype)] return [t for t in inspector.get_table_names() if t.startswith(tabletype)]
def has_column(columns: List, searchname: str) -> bool: def has_column(columns: list, searchname: str) -> bool:
return len(list(filter(lambda x: x["name"] == searchname, columns))) == 1 return len(list(filter(lambda x: x["name"] == searchname, columns))) == 1
def get_column_def(columns: List, column: str, default: str) -> str: def get_column_def(columns: list, column: str, default: str) -> str:
return default if not has_column(columns, column) else column return default if not has_column(columns, column) else column
def get_backup_name(tabs: List[str], backup_prefix: str): def get_backup_name(tabs: list[str], backup_prefix: str):
table_back_name = backup_prefix table_back_name = backup_prefix
for i, table_back_name in enumerate(tabs): for i, table_back_name in enumerate(tabs):
table_back_name = f"{backup_prefix}{i}" table_back_name = f"{backup_prefix}{i}"
@ -77,9 +77,9 @@ def migrate_trades_and_orders_table(
inspector, inspector,
engine, engine,
trade_back_name: str, trade_back_name: str,
cols: List, cols: list,
order_back_name: str, order_back_name: str,
cols_order: List, cols_order: list,
): ):
base_currency = get_column_def(cols, "base_currency", "null") base_currency = get_column_def(cols, "base_currency", "null")
stake_currency = get_column_def(cols, "stake_currency", "null") stake_currency = get_column_def(cols, "stake_currency", "null")
@ -230,7 +230,7 @@ def drop_orders_table(engine, table_back_name: str):
connection.execute(text("drop table orders")) connection.execute(text("drop table orders"))
def migrate_orders_table(engine, table_back_name: str, cols_order: List): def migrate_orders_table(engine, table_back_name: str, cols_order: list):
ft_fee_base = get_column_def(cols_order, "ft_fee_base", "null") ft_fee_base = get_column_def(cols_order, "ft_fee_base", "null")
average = get_column_def(cols_order, "average", "null") average = get_column_def(cols_order, "average", "null")
stop_price = get_column_def(cols_order, "stop_price", "null") stop_price = get_column_def(cols_order, "stop_price", "null")
@ -262,7 +262,7 @@ def migrate_orders_table(engine, table_back_name: str, cols_order: List):
) )
def migrate_pairlocks_table(decl_base, inspector, engine, pairlock_back_name: str, cols: List): def migrate_pairlocks_table(decl_base, inspector, engine, pairlock_back_name: str, cols: list):
# Schema migration necessary # Schema migration necessary
with engine.begin() as connection: with engine.begin() as connection:
connection.execute(text(f"alter table pairlocks rename to {pairlock_back_name}")) connection.execute(text(f"alter table pairlocks rename to {pairlock_back_name}"))

View File

@ -5,7 +5,7 @@ This module contains the class to persist trades into SQLite
import logging import logging
import threading import threading
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Dict, Final, Optional from typing import Any, Final, Optional
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
from sqlalchemy.exc import NoSuchModuleError from sqlalchemy.exc import NoSuchModuleError
@ -51,7 +51,7 @@ def init_db(db_url: str) -> None:
:param db_url: Database to use :param db_url: Database to use
:return: None :return: None
""" """
kwargs: Dict[str, Any] = {} kwargs: dict[str, Any] = {}
if db_url == "sqlite:///": if db_url == "sqlite:///":
raise OperationalException( raise OperationalException(

View File

@ -1,5 +1,5 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, ClassVar, Dict, Optional from typing import Any, ClassVar, Optional
from sqlalchemy import ScalarResult, String, or_, select from sqlalchemy import ScalarResult, String, or_, select
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -64,7 +64,7 @@ class PairLock(ModelBase):
def get_all_locks() -> ScalarResult["PairLock"]: def get_all_locks() -> ScalarResult["PairLock"]:
return PairLock.session.scalars(select(PairLock)) return PairLock.session.scalars(select(PairLock))
def to_json(self) -> Dict[str, Any]: def to_json(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"pair": self.pair, "pair": self.pair,

View File

@ -1,6 +1,7 @@
import logging import logging
from collections.abc import Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional, Sequence from typing import Optional
from sqlalchemy import select from sqlalchemy import select
@ -19,7 +20,7 @@ class PairLocks:
""" """
use_db = True use_db = True
locks: List[PairLock] = [] locks: list[PairLock] = []
timeframe: str = "" timeframe: str = ""

View File

@ -4,10 +4,11 @@ This module contains the class to persist trades into SQLite
import logging import logging
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from math import isclose from math import isclose
from typing import Any, ClassVar, Dict, List, Optional, Sequence, cast from typing import Any, ClassVar, Optional, cast
from sqlalchemy import ( from sqlalchemy import (
Enum, Enum,
@ -215,8 +216,8 @@ class Order(ModelBase):
) )
self.order_update_date = datetime.now(timezone.utc) self.order_update_date = datetime.now(timezone.utc)
def to_ccxt_object(self, stopPriceName: str = "stopPrice") -> Dict[str, Any]: def to_ccxt_object(self, stopPriceName: str = "stopPrice") -> dict[str, Any]:
order: Dict[str, Any] = { order: dict[str, Any] = {
"id": self.order_id, "id": self.order_id,
"symbol": self.ft_pair, "symbol": self.ft_pair,
"price": self.price, "price": self.price,
@ -243,7 +244,7 @@ class Order(ModelBase):
return order return order
def to_json(self, entry_side: str, minified: bool = False) -> Dict[str, Any]: def to_json(self, entry_side: str, minified: bool = False) -> dict[str, Any]:
""" """
:param minified: If True, only return a subset of the data is returned. :param minified: If True, only return a subset of the data is returned.
Only used for backtesting. Only used for backtesting.
@ -308,7 +309,7 @@ class Order(ModelBase):
trade.adjust_stop_loss(trade.open_rate, trade.stop_loss_pct) trade.adjust_stop_loss(trade.open_rate, trade.stop_loss_pct)
@staticmethod @staticmethod
def update_orders(orders: List["Order"], order: Dict[str, Any]): def update_orders(orders: list["Order"], order: dict[str, Any]):
""" """
Get all non-closed orders - useful when trying to batch-update orders Get all non-closed orders - useful when trying to batch-update orders
""" """
@ -327,7 +328,7 @@ class Order(ModelBase):
@classmethod @classmethod
def parse_from_ccxt_object( def parse_from_ccxt_object(
cls, cls,
order: Dict[str, Any], order: dict[str, Any],
pair: str, pair: str,
side: str, side: str,
amount: Optional[float] = None, amount: Optional[float] = None,
@ -373,17 +374,17 @@ class LocalTrade:
use_db: bool = False use_db: bool = False
# Trades container for backtesting # Trades container for backtesting
bt_trades: List["LocalTrade"] = [] bt_trades: list["LocalTrade"] = []
bt_trades_open: List["LocalTrade"] = [] bt_trades_open: list["LocalTrade"] = []
# Copy of trades_open - but indexed by pair # Copy of trades_open - but indexed by pair
bt_trades_open_pp: Dict[str, List["LocalTrade"]] = defaultdict(list) bt_trades_open_pp: dict[str, list["LocalTrade"]] = defaultdict(list)
bt_open_open_trade_count: int = 0 bt_open_open_trade_count: int = 0
bt_total_profit: float = 0 bt_total_profit: float = 0
realized_profit: float = 0 realized_profit: float = 0
id: int = 0 id: int = 0
orders: List[Order] = [] orders: list[Order] = []
exchange: str = "" exchange: str = ""
pair: str = "" pair: str = ""
@ -569,7 +570,7 @@ class LocalTrade:
return "" return ""
@property @property
def open_orders(self) -> List[Order]: def open_orders(self) -> list[Order]:
""" """
All open orders for this trade excluding stoploss orders All open orders for this trade excluding stoploss orders
""" """
@ -586,7 +587,7 @@ class LocalTrade:
return len(open_orders_wo_sl) > 0 return len(open_orders_wo_sl) > 0
@property @property
def open_sl_orders(self) -> List[Order]: def open_sl_orders(self) -> list[Order]:
""" """
All open stoploss orders for this trade All open stoploss orders for this trade
""" """
@ -603,14 +604,14 @@ class LocalTrade:
return len(open_sl_orders) > 0 return len(open_sl_orders) > 0
@property @property
def sl_orders(self) -> List[Order]: def sl_orders(self) -> list[Order]:
""" """
All stoploss orders for this trade All stoploss orders for this trade
""" """
return [o for o in self.orders if o.ft_order_side in ["stoploss"]] return [o for o in self.orders if o.ft_order_side in ["stoploss"]]
@property @property
def open_orders_ids(self) -> List[str]: def open_orders_ids(self) -> list[str]:
open_orders_ids_wo_sl = [ open_orders_ids_wo_sl = [
oo.order_id for oo in self.open_orders if oo.ft_order_side not in ["stoploss"] oo.order_id for oo in self.open_orders if oo.ft_order_side not in ["stoploss"]
] ]
@ -637,7 +638,7 @@ class LocalTrade:
f"open_rate={self.open_rate:.8f}, open_since={open_since})" f"open_rate={self.open_rate:.8f}, open_since={open_since})"
) )
def to_json(self, minified: bool = False) -> Dict[str, Any]: def to_json(self, minified: bool = False) -> dict[str, Any]:
""" """
:param minified: If True, only return a subset of the data is returned. :param minified: If True, only return a subset of the data is returned.
Only used for backtesting. Only used for backtesting.
@ -956,7 +957,7 @@ class LocalTrade:
else: else:
return False return False
def update_order(self, order: Dict) -> None: def update_order(self, order: dict) -> None:
Order.update_orders(self.orders, order) Order.update_orders(self.orders, order)
@property @property
@ -1280,7 +1281,7 @@ class LocalTrade:
else: else:
return None return None
def select_filled_orders(self, order_side: Optional[str] = None) -> List["Order"]: def select_filled_orders(self, order_side: Optional[str] = None) -> list["Order"]:
""" """
Finds filled orders for this order side. Finds filled orders for this order side.
Will not return open orders which already partially filled. Will not return open orders which already partially filled.
@ -1296,7 +1297,7 @@ class LocalTrade:
and o.status in NON_OPEN_EXCHANGE_STATES and o.status in NON_OPEN_EXCHANGE_STATES
] ]
def select_filled_or_open_orders(self) -> List["Order"]: def select_filled_or_open_orders(self) -> list["Order"]:
""" """
Finds filled or open orders Finds filled or open orders
:param order_side: Side of the order (either 'buy', 'sell', or None) :param order_side: Side of the order (either 'buy', 'sell', or None)
@ -1341,7 +1342,7 @@ class LocalTrade:
return data[0] return data[0]
return None return None
def get_all_custom_data(self) -> List[_CustomData]: def get_all_custom_data(self) -> list[_CustomData]:
""" """
Get all custom data for this trade Get all custom data for this trade
""" """
@ -1399,7 +1400,7 @@ class LocalTrade:
is_open: Optional[bool] = None, is_open: Optional[bool] = None,
open_date: Optional[datetime] = None, open_date: Optional[datetime] = None,
close_date: Optional[datetime] = None, close_date: Optional[datetime] = None,
) -> List["LocalTrade"]: ) -> list["LocalTrade"]:
""" """
Helper function to query Trades. Helper function to query Trades.
Returns a List of trades, filtered on the parameters given. Returns a List of trades, filtered on the parameters given.
@ -1460,7 +1461,7 @@ class LocalTrade:
LocalTrade.bt_open_open_trade_count -= 1 LocalTrade.bt_open_open_trade_count -= 1
@staticmethod @staticmethod
def get_open_trades() -> List[Any]: def get_open_trades() -> list[Any]:
""" """
Retrieve open trades Retrieve open trades
""" """
@ -1609,10 +1610,10 @@ class Trade(ModelBase, LocalTrade):
id: Mapped[int] = mapped_column(Integer, primary_key=True) # type: ignore id: Mapped[int] = mapped_column(Integer, primary_key=True) # type: ignore
orders: Mapped[List[Order]] = relationship( orders: Mapped[list[Order]] = relationship(
"Order", order_by="Order.id", cascade="all, delete-orphan", lazy="selectin", innerjoin=True "Order", order_by="Order.id", cascade="all, delete-orphan", lazy="selectin", innerjoin=True
) # type: ignore ) # type: ignore
custom_data: Mapped[List[_CustomData]] = relationship( custom_data: Mapped[list[_CustomData]] = relationship(
"_CustomData", cascade="all, delete-orphan", lazy="raise" "_CustomData", cascade="all, delete-orphan", lazy="raise"
) )
@ -1759,7 +1760,7 @@ class Trade(ModelBase, LocalTrade):
is_open: Optional[bool] = None, is_open: Optional[bool] = None,
open_date: Optional[datetime] = None, open_date: Optional[datetime] = None,
close_date: Optional[datetime] = None, close_date: Optional[datetime] = None,
) -> List["LocalTrade"]: ) -> list["LocalTrade"]:
""" """
Helper function to query Trades.j Helper function to query Trades.j
Returns a List of trades, filtered on the parameters given. Returns a List of trades, filtered on the parameters given.
@ -1778,7 +1779,7 @@ class Trade(ModelBase, LocalTrade):
trade_filter.append(Trade.close_date > close_date) trade_filter.append(Trade.close_date > close_date)
if is_open is not None: if is_open is not None:
trade_filter.append(Trade.is_open.is_(is_open)) trade_filter.append(Trade.is_open.is_(is_open))
return cast(List[LocalTrade], Trade.get_trades(trade_filter).all()) return cast(list[LocalTrade], Trade.get_trades(trade_filter).all())
else: else:
return LocalTrade.get_trades_proxy( return LocalTrade.get_trades_proxy(
pair=pair, is_open=is_open, open_date=open_date, close_date=close_date pair=pair, is_open=is_open, open_date=open_date, close_date=close_date
@ -1886,12 +1887,12 @@ class Trade(ModelBase, LocalTrade):
return total_open_stake_amount or 0 return total_open_stake_amount or 0
@staticmethod @staticmethod
def get_overall_performance(minutes=None) -> List[Dict[str, Any]]: def get_overall_performance(minutes=None) -> list[dict[str, Any]]:
""" """
Returns List of dicts containing all Trades, including profit and trade count Returns List of dicts containing all Trades, including profit and trade count
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
""" """
filters: List = [Trade.is_open.is_(False)] filters: list = [Trade.is_open.is_(False)]
if minutes: if minutes:
start_date = datetime.now(timezone.utc) - timedelta(minutes=minutes) start_date = datetime.now(timezone.utc) - timedelta(minutes=minutes)
filters.append(Trade.close_date >= start_date) filters.append(Trade.close_date >= start_date)
@ -1921,14 +1922,14 @@ class Trade(ModelBase, LocalTrade):
] ]
@staticmethod @staticmethod
def get_enter_tag_performance(pair: Optional[str]) -> List[Dict[str, Any]]: def get_enter_tag_performance(pair: Optional[str]) -> list[dict[str, Any]]:
""" """
Returns List of dicts containing all Trades, based on buy tag performance Returns List of dicts containing all Trades, based on buy tag performance
Can either be average for all pairs or a specific pair provided Can either be average for all pairs or a specific pair provided
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
""" """
filters: List = [Trade.is_open.is_(False)] filters: list = [Trade.is_open.is_(False)]
if pair is not None: if pair is not None:
filters.append(Trade.pair == pair) filters.append(Trade.pair == pair)
@ -1956,14 +1957,14 @@ class Trade(ModelBase, LocalTrade):
] ]
@staticmethod @staticmethod
def get_exit_reason_performance(pair: Optional[str]) -> List[Dict[str, Any]]: def get_exit_reason_performance(pair: Optional[str]) -> list[dict[str, Any]]:
""" """
Returns List of dicts containing all Trades, based on exit reason performance Returns List of dicts containing all Trades, based on exit reason performance
Can either be average for all pairs or a specific pair provided Can either be average for all pairs or a specific pair provided
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
""" """
filters: List = [Trade.is_open.is_(False)] filters: list = [Trade.is_open.is_(False)]
if pair is not None: if pair is not None:
filters.append(Trade.pair == pair) filters.append(Trade.pair == pair)
sell_tag_perf = Trade.session.execute( sell_tag_perf = Trade.session.execute(
@ -1990,14 +1991,14 @@ class Trade(ModelBase, LocalTrade):
] ]
@staticmethod @staticmethod
def get_mix_tag_performance(pair: Optional[str]) -> List[Dict[str, Any]]: def get_mix_tag_performance(pair: Optional[str]) -> list[dict[str, Any]]:
""" """
Returns List of dicts containing all Trades, based on entry_tag + exit_reason performance Returns List of dicts containing all Trades, based on entry_tag + exit_reason performance
Can either be average for all pairs or a specific pair provided Can either be average for all pairs or a specific pair provided
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
""" """
filters: List = [Trade.is_open.is_(False)] filters: list = [Trade.is_open.is_(False)]
if pair is not None: if pair is not None:
filters.append(Trade.pair == pair) filters.append(Trade.pair == pair)
mix_tag_perf = Trade.session.execute( mix_tag_perf = Trade.session.execute(
@ -2014,7 +2015,7 @@ class Trade(ModelBase, LocalTrade):
.order_by(desc("profit_sum_abs")) .order_by(desc("profit_sum_abs"))
).all() ).all()
resp: List[Dict] = [] resp: list[dict] = []
for _, enter_tag, exit_reason, profit, profit_abs, count in mix_tag_perf: for _, enter_tag, exit_reason, profit, profit_abs, count in mix_tag_perf:
enter_tag = enter_tag if enter_tag is not None else "Other" enter_tag = enter_tag if enter_tag is not None else "Other"
exit_reason = exit_reason if exit_reason is not None else "Other" exit_reason = exit_reason if exit_reason is not None else "Other"
@ -2053,7 +2054,7 @@ class Trade(ModelBase, LocalTrade):
NOTE: Not supported in Backtesting. NOTE: Not supported in Backtesting.
:returns: Tuple containing (pair, profit_sum) :returns: Tuple containing (pair, profit_sum)
""" """
filters: List = [Trade.is_open.is_(False)] filters: list = [Trade.is_open.is_(False)]
if start_date: if start_date:
filters.append(Trade.close_date >= start_date) filters.append(Trade.close_date >= start_date)