Simplify custom_data stuff

This commit is contained in:
Matthias 2024-02-12 20:25:26 +01:00
parent b7904b8e80
commit 8dda28351e
3 changed files with 52 additions and 35 deletions

View File

@ -14,7 +14,7 @@ from freqtrade.util import dt_now
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CustomData(ModelBase): class _CustomData(ModelBase):
""" """
CustomData database model CustomData database model
Keeps records of metadata as key/value store Keeps records of metadata as key/value store
@ -41,6 +41,9 @@ class CustomData(ModelBase):
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=dt_now) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=dt_now)
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# Empty container value - not persisted, but filled with cd_value on query
value: Any = None
def __repr__(self): def __repr__(self):
create_time = (self.created_at.strftime(DATETIME_PRINT_FORMAT) create_time = (self.created_at.strftime(DATETIME_PRINT_FORMAT)
if self.created_at is not None else None) if self.created_at is not None else None)
@ -52,7 +55,7 @@ class CustomData(ModelBase):
@classmethod @classmethod
def query_cd(cls, key: Optional[str] = None, def query_cd(cls, key: Optional[str] = None,
trade_id: Optional[int] = None) -> Sequence['CustomData']: trade_id: Optional[int] = None) -> Sequence['_CustomData']:
""" """
Get all CustomData, if trade_id is not specified Get all CustomData, if trade_id is not specified
return will be for generic values not tied to a trade return will be for generic values not tied to a trade
@ -60,11 +63,11 @@ class CustomData(ModelBase):
""" """
filters = [] filters = []
if trade_id is not None: if trade_id is not None:
filters.append(CustomData.ft_trade_id == trade_id) filters.append(_CustomData.ft_trade_id == trade_id)
if key is not None: if key is not None:
filters.append(CustomData.cd_key.ilike(key)) filters.append(_CustomData.cd_key.ilike(key))
return CustomData.session.scalars(select(CustomData).filter(*filters)).all() return _CustomData.session.scalars(select(_CustomData).filter(*filters)).all()
class CustomDataWrapper: class CustomDataWrapper:
@ -75,9 +78,15 @@ class CustomDataWrapper:
""" """
use_db = True use_db = True
custom_data: List[CustomData] = [] custom_data: List[_CustomData] = []
unserialized_types = ['bool', 'float', 'int', 'str'] unserialized_types = ['bool', 'float', 'int', 'str']
@staticmethod
def _convert_custom_data(data: _CustomData) -> _CustomData:
if data.cd_type not in CustomDataWrapper.unserialized_types:
data.value = json.loads(data.cd_value)
return data
@staticmethod @staticmethod
def reset_custom_data() -> None: def reset_custom_data() -> None:
""" """
@ -88,17 +97,15 @@ class CustomDataWrapper:
@staticmethod @staticmethod
def get_custom_data(key: Optional[str] = None, def get_custom_data(key: Optional[str] = None,
trade_id: Optional[int] = None) -> CustomData: trade_id: Optional[int] = None) -> List[_CustomData]:
if trade_id is None: if trade_id is None:
trade_id = 0 trade_id = 0
if CustomDataWrapper.use_db: if CustomDataWrapper.use_db:
filtered_custom_data = [] filtered_custom_data = _CustomData.session.scalars(select(_CustomData).filter(
for data_entry in CustomData.query_cd(trade_id=trade_id, key=key): _CustomData.ft_trade_id == trade_id,
if data_entry.cd_type not in CustomDataWrapper.unserialized_types: _CustomData.cd_key.ilike(key))).all()
data_entry.cd_value = json.loads(data_entry.cd_value)
filtered_custom_data.append(data_entry)
return filtered_custom_data
else: else:
filtered_custom_data = [ filtered_custom_data = [
data_entry for data_entry in CustomDataWrapper.custom_data data_entry for data_entry in CustomDataWrapper.custom_data
@ -109,19 +116,19 @@ class CustomDataWrapper:
data_entry for data_entry in filtered_custom_data data_entry for data_entry in filtered_custom_data
if (data_entry.cd_key.casefold() == key.casefold()) if (data_entry.cd_key.casefold() == key.casefold())
] ]
return filtered_custom_data return [CustomDataWrapper._convert_custom_data(d) for d in filtered_custom_data]
@staticmethod @staticmethod
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None: def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
value_type = type(value).__name__ value_type = type(value).__name__
value_db = None
if value_type not in CustomDataWrapper.unserialized_types: if value_type not in CustomDataWrapper.unserialized_types:
try: try:
value_db = json.dumps(value) value_db = json.dumps(value)
except TypeError as e: except TypeError as e:
logger.warning(f"could not serialize {key} value due to {e}") logger.warning(f"could not serialize {key} value due to {e}")
return
else: else:
value_db = str(value) value_db = str(value)
@ -134,19 +141,19 @@ class CustomDataWrapper:
data_entry.cd_value = value_db data_entry.cd_value = value_db
data_entry.updated_at = dt_now() data_entry.updated_at = dt_now()
else: else:
data_entry = CustomData( data_entry = _CustomData(
ft_trade_id=trade_id, ft_trade_id=trade_id,
cd_key=key, cd_key=key,
cd_type=value_type, cd_type=value_type,
cd_value=value_db, cd_value=value_db,
created_at=dt_now() created_at=dt_now(),
) )
data_entry.value = value
if CustomDataWrapper.use_db and value_db is not None: if CustomDataWrapper.use_db and value_db is not None:
data_entry.cd_value = value_db _CustomData.session.add(data_entry)
CustomData.session.add(data_entry) _CustomData.session.commit()
CustomData.session.commit() else:
elif not CustomDataWrapper.use_db:
cd_index = -1 cd_index = -1
for index, data_entry in enumerate(CustomDataWrapper.custom_data): for index, data_entry in enumerate(CustomDataWrapper.custom_data):
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key: if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
@ -161,11 +168,3 @@ class CustomDataWrapper:
CustomDataWrapper.custom_data[cd_index] = data_entry CustomDataWrapper.custom_data[cd_index] = data_entry
else: else:
CustomDataWrapper.custom_data.append(data_entry) CustomDataWrapper.custom_data.append(data_entry)
@staticmethod
def get_all_custom_data() -> List[CustomData]:
if CustomDataWrapper.use_db:
return list(CustomData.query_cd())
else:
return CustomDataWrapper.custom_data

View File

@ -13,7 +13,7 @@ from sqlalchemy.pool import StaticPool
from freqtrade.exceptions import OperationalException from freqtrade.exceptions import OperationalException
from freqtrade.persistence.base import ModelBase from freqtrade.persistence.base import ModelBase
from freqtrade.persistence.custom_data import CustomData from freqtrade.persistence.custom_data import _CustomData
from freqtrade.persistence.key_value_store import _KeyValueStoreModel from freqtrade.persistence.key_value_store import _KeyValueStoreModel
from freqtrade.persistence.migrations import check_migrate from freqtrade.persistence.migrations import check_migrate
from freqtrade.persistence.pairlock import PairLock from freqtrade.persistence.pairlock import PairLock
@ -79,7 +79,7 @@ def init_db(db_url: str) -> None:
Order.session = Trade.session Order.session = Trade.session
PairLock.session = Trade.session PairLock.session = Trade.session
_KeyValueStoreModel.session = Trade.session _KeyValueStoreModel.session = Trade.session
CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True), _CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True),
scopefunc=get_request_or_thread_id) scopefunc=get_request_or_thread_id)
previous_tables = inspect(engine).get_table_names() previous_tables = inspect(engine).get_table_names()

View File

@ -23,7 +23,7 @@ from freqtrade.exchange import (ROUND_DOWN, ROUND_UP, amount_to_contract_precisi
from freqtrade.leverage import interest from freqtrade.leverage import interest
from freqtrade.misc import safe_value_fallback from freqtrade.misc import safe_value_fallback
from freqtrade.persistence.base import ModelBase, SessionType from freqtrade.persistence.base import ModelBase, SessionType
from freqtrade.persistence.custom_data import CustomData, CustomDataWrapper from freqtrade.persistence.custom_data import CustomDataWrapper, _CustomData
from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts
@ -1206,10 +1206,28 @@ class LocalTrade:
] ]
def set_custom_data(self, key: str, value: Any) -> None: def set_custom_data(self, key: str, value: Any) -> None:
"""
Set custom data for this trade
:param key: key of the custom data
:param value: value of the custom data (must be JSON serializable)
"""
CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id) CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id)
def get_custom_data(self, key: Optional[str]) -> List[_CustomData]: def get_custom_data(self, key: str) -> Optional[_CustomData]:
return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id) """
Get custom data for this trade
:param key: key of the custom data
"""
data = CustomDataWrapper.get_custom_data(key=key, trade_id=self.id)
if data:
return data[0]
return None
def get_all_custom_data(self) -> List[_CustomData]:
"""
Get all custom data for this trade
"""
return CustomDataWrapper.get_custom_data(trade_id=self.id)
@property @property
def nr_of_successful_entries(self) -> int: def nr_of_successful_entries(self) -> int: