diff --git a/freqtrade/persistence/custom_data.py b/freqtrade/persistence/custom_data.py index bf6056278..e8fa0d960 100644 --- a/freqtrade/persistence/custom_data.py +++ b/freqtrade/persistence/custom_data.py @@ -14,7 +14,7 @@ from freqtrade.util import dt_now logger = logging.getLogger(__name__) -class CustomData(ModelBase): +class _CustomData(ModelBase): """ CustomData database model Keeps records of metadata as key/value store @@ -41,6 +41,9 @@ class CustomData(ModelBase): created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=dt_now) updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + # Empty container value - not persisted, but filled with cd_value on query + value: Any = None + def __repr__(self): create_time = (self.created_at.strftime(DATETIME_PRINT_FORMAT) if self.created_at is not None else None) @@ -52,7 +55,7 @@ class CustomData(ModelBase): @classmethod def query_cd(cls, key: Optional[str] = None, - trade_id: Optional[int] = None) -> Sequence['CustomData']: + trade_id: Optional[int] = None) -> Sequence['_CustomData']: """ Get all CustomData, if trade_id is not specified return will be for generic values not tied to a trade @@ -60,11 +63,11 @@ class CustomData(ModelBase): """ filters = [] if trade_id is not None: - filters.append(CustomData.ft_trade_id == trade_id) + filters.append(_CustomData.ft_trade_id == trade_id) if key is not None: - filters.append(CustomData.cd_key.ilike(key)) + filters.append(_CustomData.cd_key.ilike(key)) - return CustomData.session.scalars(select(CustomData).filter(*filters)).all() + return _CustomData.session.scalars(select(_CustomData).filter(*filters)).all() class CustomDataWrapper: @@ -75,9 +78,15 @@ class CustomDataWrapper: """ use_db = True - custom_data: List[CustomData] = [] + custom_data: List[_CustomData] = [] unserialized_types = ['bool', 'float', 'int', 'str'] + @staticmethod + def _convert_custom_data(data: _CustomData) -> _CustomData: + if data.cd_type not in CustomDataWrapper.unserialized_types: + data.value = json.loads(data.cd_value) + return data + @staticmethod def reset_custom_data() -> None: """ @@ -88,17 +97,15 @@ class CustomDataWrapper: @staticmethod def get_custom_data(key: Optional[str] = None, - trade_id: Optional[int] = None) -> CustomData: + trade_id: Optional[int] = None) -> List[_CustomData]: if trade_id is None: trade_id = 0 if CustomDataWrapper.use_db: - filtered_custom_data = [] - for data_entry in CustomData.query_cd(trade_id=trade_id, key=key): - if data_entry.cd_type not in CustomDataWrapper.unserialized_types: - data_entry.cd_value = json.loads(data_entry.cd_value) - filtered_custom_data.append(data_entry) - return filtered_custom_data + filtered_custom_data = _CustomData.session.scalars(select(_CustomData).filter( + _CustomData.ft_trade_id == trade_id, + _CustomData.cd_key.ilike(key))).all() + else: filtered_custom_data = [ data_entry for data_entry in CustomDataWrapper.custom_data @@ -109,19 +116,19 @@ class CustomDataWrapper: data_entry for data_entry in filtered_custom_data if (data_entry.cd_key.casefold() == key.casefold()) ] - return filtered_custom_data + return [CustomDataWrapper._convert_custom_data(d) for d in filtered_custom_data] @staticmethod def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None: value_type = type(value).__name__ - value_db = None if value_type not in CustomDataWrapper.unserialized_types: try: value_db = json.dumps(value) except TypeError as e: logger.warning(f"could not serialize {key} value due to {e}") + return else: value_db = str(value) @@ -134,19 +141,19 @@ class CustomDataWrapper: data_entry.cd_value = value_db data_entry.updated_at = dt_now() else: - data_entry = CustomData( + data_entry = _CustomData( ft_trade_id=trade_id, cd_key=key, cd_type=value_type, cd_value=value_db, - created_at=dt_now() + created_at=dt_now(), ) + data_entry.value = value if CustomDataWrapper.use_db and value_db is not None: - data_entry.cd_value = value_db - CustomData.session.add(data_entry) - CustomData.session.commit() - elif not CustomDataWrapper.use_db: + _CustomData.session.add(data_entry) + _CustomData.session.commit() + else: cd_index = -1 for index, data_entry in enumerate(CustomDataWrapper.custom_data): if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key: @@ -161,11 +168,3 @@ class CustomDataWrapper: CustomDataWrapper.custom_data[cd_index] = data_entry else: CustomDataWrapper.custom_data.append(data_entry) - - @staticmethod - def get_all_custom_data() -> List[CustomData]: - - if CustomDataWrapper.use_db: - return list(CustomData.query_cd()) - else: - return CustomDataWrapper.custom_data diff --git a/freqtrade/persistence/models.py b/freqtrade/persistence/models.py index 189b80fa6..1a69b271c 100644 --- a/freqtrade/persistence/models.py +++ b/freqtrade/persistence/models.py @@ -13,7 +13,7 @@ from sqlalchemy.pool import StaticPool from freqtrade.exceptions import OperationalException from freqtrade.persistence.base import ModelBase -from freqtrade.persistence.custom_data import CustomData +from freqtrade.persistence.custom_data import _CustomData from freqtrade.persistence.key_value_store import _KeyValueStoreModel from freqtrade.persistence.migrations import check_migrate from freqtrade.persistence.pairlock import PairLock @@ -79,8 +79,8 @@ def init_db(db_url: str) -> None: Order.session = Trade.session PairLock.session = Trade.session _KeyValueStoreModel.session = Trade.session - CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True), - scopefunc=get_request_or_thread_id) + _CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True), + scopefunc=get_request_or_thread_id) previous_tables = inspect(engine).get_table_names() ModelBase.metadata.create_all(engine) diff --git a/freqtrade/persistence/trade_model.py b/freqtrade/persistence/trade_model.py index 7487c72b3..55a075cc9 100644 --- a/freqtrade/persistence/trade_model.py +++ b/freqtrade/persistence/trade_model.py @@ -23,7 +23,7 @@ from freqtrade.exchange import (ROUND_DOWN, ROUND_UP, amount_to_contract_precisi from freqtrade.leverage import interest from freqtrade.misc import safe_value_fallback from freqtrade.persistence.base import ModelBase, SessionType -from freqtrade.persistence.custom_data import CustomData, CustomDataWrapper +from freqtrade.persistence.custom_data import CustomDataWrapper, _CustomData from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts @@ -1206,10 +1206,28 @@ class LocalTrade: ] def set_custom_data(self, key: str, value: Any) -> None: + """ + Set custom data for this trade + :param key: key of the custom data + :param value: value of the custom data (must be JSON serializable) + """ CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id) - def get_custom_data(self, key: Optional[str]) -> List[_CustomData]: - return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id) + def get_custom_data(self, key: str) -> Optional[_CustomData]: + """ + Get custom data for this trade + :param key: key of the custom data + """ + data = CustomDataWrapper.get_custom_data(key=key, trade_id=self.id) + if data: + return data[0] + return None + + def get_all_custom_data(self) -> List[_CustomData]: + """ + Get all custom data for this trade + """ + return CustomDataWrapper.get_custom_data(trade_id=self.id) @property def nr_of_successful_entries(self) -> int: