chore: update persistence to modern typing syntax

This commit is contained in:
Matthias 2024-10-04 06:55:05 +02:00
parent 2e69e38adb
commit e9a6ba03f9
6 changed files with 56 additions and 53 deletions

View File

@ -1,7 +1,8 @@
import json
import logging
from collections.abc import Sequence
from datetime import datetime
from typing import Any, ClassVar, List, Optional, Sequence
from typing import Any, ClassVar, Optional
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -85,7 +86,7 @@ class CustomDataWrapper:
"""
use_db = True
custom_data: List[_CustomData] = []
custom_data: list[_CustomData] = []
unserialized_types = ["bool", "float", "int", "str"]
@staticmethod
@ -116,7 +117,7 @@ class CustomDataWrapper:
_CustomData.session.commit()
@staticmethod
def get_custom_data(*, trade_id: int, key: Optional[str] = None) -> List[_CustomData]:
def get_custom_data(*, trade_id: int, key: Optional[str] = None) -> list[_CustomData]:
if CustomDataWrapper.use_db:
filters = [
_CustomData.ft_trade_id == trade_id,

View File

@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Optional
from sqlalchemy import inspect, select, text, update
@ -10,19 +10,19 @@ from freqtrade.persistence.trade_model import Order, Trade
logger = logging.getLogger(__name__)
def get_table_names_for_table(inspector, tabletype) -> List[str]:
def get_table_names_for_table(inspector, tabletype) -> list[str]:
return [t for t in inspector.get_table_names() if t.startswith(tabletype)]
def has_column(columns: List, searchname: str) -> bool:
def has_column(columns: list, searchname: str) -> bool:
return len(list(filter(lambda x: x["name"] == searchname, columns))) == 1
def get_column_def(columns: List, column: str, default: str) -> str:
def get_column_def(columns: list, column: str, default: str) -> str:
return default if not has_column(columns, column) else column
def get_backup_name(tabs: List[str], backup_prefix: str):
def get_backup_name(tabs: list[str], backup_prefix: str):
table_back_name = backup_prefix
for i, table_back_name in enumerate(tabs):
table_back_name = f"{backup_prefix}{i}"
@ -77,9 +77,9 @@ def migrate_trades_and_orders_table(
inspector,
engine,
trade_back_name: str,
cols: List,
cols: list,
order_back_name: str,
cols_order: List,
cols_order: list,
):
base_currency = get_column_def(cols, "base_currency", "null")
stake_currency = get_column_def(cols, "stake_currency", "null")
@ -230,7 +230,7 @@ def drop_orders_table(engine, table_back_name: str):
connection.execute(text("drop table orders"))
def migrate_orders_table(engine, table_back_name: str, cols_order: List):
def migrate_orders_table(engine, table_back_name: str, cols_order: list):
ft_fee_base = get_column_def(cols_order, "ft_fee_base", "null")
average = get_column_def(cols_order, "average", "null")
stop_price = get_column_def(cols_order, "stop_price", "null")
@ -262,7 +262,7 @@ def migrate_orders_table(engine, table_back_name: str, cols_order: List):
)
def migrate_pairlocks_table(decl_base, inspector, engine, pairlock_back_name: str, cols: List):
def migrate_pairlocks_table(decl_base, inspector, engine, pairlock_back_name: str, cols: list):
# Schema migration necessary
with engine.begin() as connection:
connection.execute(text(f"alter table pairlocks rename to {pairlock_back_name}"))

View File

@ -5,7 +5,7 @@ This module contains the class to persist trades into SQLite
import logging
import threading
from contextvars import ContextVar
from typing import Any, Dict, Final, Optional
from typing import Any, Final, Optional
from sqlalchemy import create_engine, inspect
from sqlalchemy.exc import NoSuchModuleError
@ -51,7 +51,7 @@ def init_db(db_url: str) -> None:
:param db_url: Database to use
:return: None
"""
kwargs: Dict[str, Any] = {}
kwargs: dict[str, Any] = {}
if db_url == "sqlite:///":
raise OperationalException(

View File

@ -1,5 +1,5 @@
from datetime import datetime, timezone
from typing import Any, ClassVar, Dict, Optional
from typing import Any, ClassVar, Optional
from sqlalchemy import ScalarResult, String, or_, select
from sqlalchemy.orm import Mapped, mapped_column
@ -64,7 +64,7 @@ class PairLock(ModelBase):
def get_all_locks() -> ScalarResult["PairLock"]:
return PairLock.session.scalars(select(PairLock))
def to_json(self) -> Dict[str, Any]:
def to_json(self) -> dict[str, Any]:
return {
"id": self.id,
"pair": self.pair,

View File

@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
from datetime import datetime, timezone
from typing import List, Optional, Sequence
from typing import Optional
from sqlalchemy import select
@ -19,7 +20,7 @@ class PairLocks:
"""
use_db = True
locks: List[PairLock] = []
locks: list[PairLock] = []
timeframe: str = ""

View File

@ -4,10 +4,11 @@ This module contains the class to persist trades into SQLite
import logging
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from math import isclose
from typing import Any, ClassVar, Dict, List, Optional, Sequence, cast
from typing import Any, ClassVar, Optional, cast
from sqlalchemy import (
Enum,
@ -215,8 +216,8 @@ class Order(ModelBase):
)
self.order_update_date = datetime.now(timezone.utc)
def to_ccxt_object(self, stopPriceName: str = "stopPrice") -> Dict[str, Any]:
order: Dict[str, Any] = {
def to_ccxt_object(self, stopPriceName: str = "stopPrice") -> dict[str, Any]:
order: dict[str, Any] = {
"id": self.order_id,
"symbol": self.ft_pair,
"price": self.price,
@ -243,7 +244,7 @@ class Order(ModelBase):
return order
def to_json(self, entry_side: str, minified: bool = False) -> Dict[str, Any]:
def to_json(self, entry_side: str, minified: bool = False) -> dict[str, Any]:
"""
:param minified: If True, only return a subset of the data is returned.
Only used for backtesting.
@ -308,7 +309,7 @@ class Order(ModelBase):
trade.adjust_stop_loss(trade.open_rate, trade.stop_loss_pct)
@staticmethod
def update_orders(orders: List["Order"], order: Dict[str, Any]):
def update_orders(orders: list["Order"], order: dict[str, Any]):
"""
Get all non-closed orders - useful when trying to batch-update orders
"""
@ -327,7 +328,7 @@ class Order(ModelBase):
@classmethod
def parse_from_ccxt_object(
cls,
order: Dict[str, Any],
order: dict[str, Any],
pair: str,
side: str,
amount: Optional[float] = None,
@ -373,17 +374,17 @@ class LocalTrade:
use_db: bool = False
# Trades container for backtesting
bt_trades: List["LocalTrade"] = []
bt_trades_open: List["LocalTrade"] = []
bt_trades: list["LocalTrade"] = []
bt_trades_open: list["LocalTrade"] = []
# Copy of trades_open - but indexed by pair
bt_trades_open_pp: Dict[str, List["LocalTrade"]] = defaultdict(list)
bt_trades_open_pp: dict[str, list["LocalTrade"]] = defaultdict(list)
bt_open_open_trade_count: int = 0
bt_total_profit: float = 0
realized_profit: float = 0
id: int = 0
orders: List[Order] = []
orders: list[Order] = []
exchange: str = ""
pair: str = ""
@ -569,7 +570,7 @@ class LocalTrade:
return ""
@property
def open_orders(self) -> List[Order]:
def open_orders(self) -> list[Order]:
"""
All open orders for this trade excluding stoploss orders
"""
@ -586,7 +587,7 @@ class LocalTrade:
return len(open_orders_wo_sl) > 0
@property
def open_sl_orders(self) -> List[Order]:
def open_sl_orders(self) -> list[Order]:
"""
All open stoploss orders for this trade
"""
@ -603,14 +604,14 @@ class LocalTrade:
return len(open_sl_orders) > 0
@property
def sl_orders(self) -> List[Order]:
def sl_orders(self) -> list[Order]:
"""
All stoploss orders for this trade
"""
return [o for o in self.orders if o.ft_order_side in ["stoploss"]]
@property
def open_orders_ids(self) -> List[str]:
def open_orders_ids(self) -> list[str]:
open_orders_ids_wo_sl = [
oo.order_id for oo in self.open_orders if oo.ft_order_side not in ["stoploss"]
]
@ -637,7 +638,7 @@ class LocalTrade:
f"open_rate={self.open_rate:.8f}, open_since={open_since})"
)
def to_json(self, minified: bool = False) -> Dict[str, Any]:
def to_json(self, minified: bool = False) -> dict[str, Any]:
"""
:param minified: If True, only return a subset of the data is returned.
Only used for backtesting.
@ -956,7 +957,7 @@ class LocalTrade:
else:
return False
def update_order(self, order: Dict) -> None:
def update_order(self, order: dict) -> None:
Order.update_orders(self.orders, order)
@property
@ -1280,7 +1281,7 @@ class LocalTrade:
else:
return None
def select_filled_orders(self, order_side: Optional[str] = None) -> List["Order"]:
def select_filled_orders(self, order_side: Optional[str] = None) -> list["Order"]:
"""
Finds filled orders for this order side.
Will not return open orders which already partially filled.
@ -1296,7 +1297,7 @@ class LocalTrade:
and o.status in NON_OPEN_EXCHANGE_STATES
]
def select_filled_or_open_orders(self) -> List["Order"]:
def select_filled_or_open_orders(self) -> list["Order"]:
"""
Finds filled or open orders
:param order_side: Side of the order (either 'buy', 'sell', or None)
@ -1341,7 +1342,7 @@ class LocalTrade:
return data[0]
return None
def get_all_custom_data(self) -> List[_CustomData]:
def get_all_custom_data(self) -> list[_CustomData]:
"""
Get all custom data for this trade
"""
@ -1399,7 +1400,7 @@ class LocalTrade:
is_open: Optional[bool] = None,
open_date: Optional[datetime] = None,
close_date: Optional[datetime] = None,
) -> List["LocalTrade"]:
) -> list["LocalTrade"]:
"""
Helper function to query Trades.
Returns a List of trades, filtered on the parameters given.
@ -1460,7 +1461,7 @@ class LocalTrade:
LocalTrade.bt_open_open_trade_count -= 1
@staticmethod
def get_open_trades() -> List[Any]:
def get_open_trades() -> list[Any]:
"""
Retrieve open trades
"""
@ -1609,10 +1610,10 @@ class Trade(ModelBase, LocalTrade):
id: Mapped[int] = mapped_column(Integer, primary_key=True) # type: ignore
orders: Mapped[List[Order]] = relationship(
orders: Mapped[list[Order]] = relationship(
"Order", order_by="Order.id", cascade="all, delete-orphan", lazy="selectin", innerjoin=True
) # type: ignore
custom_data: Mapped[List[_CustomData]] = relationship(
custom_data: Mapped[list[_CustomData]] = relationship(
"_CustomData", cascade="all, delete-orphan", lazy="raise"
)
@ -1759,7 +1760,7 @@ class Trade(ModelBase, LocalTrade):
is_open: Optional[bool] = None,
open_date: Optional[datetime] = None,
close_date: Optional[datetime] = None,
) -> List["LocalTrade"]:
) -> list["LocalTrade"]:
"""
Helper function to query Trades.j
Returns a List of trades, filtered on the parameters given.
@ -1778,7 +1779,7 @@ class Trade(ModelBase, LocalTrade):
trade_filter.append(Trade.close_date > close_date)
if is_open is not None:
trade_filter.append(Trade.is_open.is_(is_open))
return cast(List[LocalTrade], Trade.get_trades(trade_filter).all())
return cast(list[LocalTrade], Trade.get_trades(trade_filter).all())
else:
return LocalTrade.get_trades_proxy(
pair=pair, is_open=is_open, open_date=open_date, close_date=close_date
@ -1886,12 +1887,12 @@ class Trade(ModelBase, LocalTrade):
return total_open_stake_amount or 0
@staticmethod
def get_overall_performance(minutes=None) -> List[Dict[str, Any]]:
def get_overall_performance(minutes=None) -> list[dict[str, Any]]:
"""
Returns List of dicts containing all Trades, including profit and trade count
NOTE: Not supported in Backtesting.
"""
filters: List = [Trade.is_open.is_(False)]
filters: list = [Trade.is_open.is_(False)]
if minutes:
start_date = datetime.now(timezone.utc) - timedelta(minutes=minutes)
filters.append(Trade.close_date >= start_date)
@ -1921,14 +1922,14 @@ class Trade(ModelBase, LocalTrade):
]
@staticmethod
def get_enter_tag_performance(pair: Optional[str]) -> List[Dict[str, Any]]:
def get_enter_tag_performance(pair: Optional[str]) -> list[dict[str, Any]]:
"""
Returns List of dicts containing all Trades, based on buy tag performance
Can either be average for all pairs or a specific pair provided
NOTE: Not supported in Backtesting.
"""
filters: List = [Trade.is_open.is_(False)]
filters: list = [Trade.is_open.is_(False)]
if pair is not None:
filters.append(Trade.pair == pair)
@ -1956,14 +1957,14 @@ class Trade(ModelBase, LocalTrade):
]
@staticmethod
def get_exit_reason_performance(pair: Optional[str]) -> List[Dict[str, Any]]:
def get_exit_reason_performance(pair: Optional[str]) -> list[dict[str, Any]]:
"""
Returns List of dicts containing all Trades, based on exit reason performance
Can either be average for all pairs or a specific pair provided
NOTE: Not supported in Backtesting.
"""
filters: List = [Trade.is_open.is_(False)]
filters: list = [Trade.is_open.is_(False)]
if pair is not None:
filters.append(Trade.pair == pair)
sell_tag_perf = Trade.session.execute(
@ -1990,14 +1991,14 @@ class Trade(ModelBase, LocalTrade):
]
@staticmethod
def get_mix_tag_performance(pair: Optional[str]) -> List[Dict[str, Any]]:
def get_mix_tag_performance(pair: Optional[str]) -> list[dict[str, Any]]:
"""
Returns List of dicts containing all Trades, based on entry_tag + exit_reason performance
Can either be average for all pairs or a specific pair provided
NOTE: Not supported in Backtesting.
"""
filters: List = [Trade.is_open.is_(False)]
filters: list = [Trade.is_open.is_(False)]
if pair is not None:
filters.append(Trade.pair == pair)
mix_tag_perf = Trade.session.execute(
@ -2014,7 +2015,7 @@ class Trade(ModelBase, LocalTrade):
.order_by(desc("profit_sum_abs"))
).all()
resp: List[Dict] = []
resp: list[dict] = []
for _, enter_tag, exit_reason, profit, profit_abs, count in mix_tag_perf:
enter_tag = enter_tag if enter_tag is not None else "Other"
exit_reason = exit_reason if exit_reason is not None else "Other"
@ -2053,7 +2054,7 @@ class Trade(ModelBase, LocalTrade):
NOTE: Not supported in Backtesting.
:returns: Tuple containing (pair, profit_sum)
"""
filters: List = [Trade.is_open.is_(False)]
filters: list = [Trade.is_open.is_(False)]
if start_date:
filters.append(Trade.close_date >= start_date)