mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
Simplify custom_data stuff
This commit is contained in:
parent
b7904b8e80
commit
8dda28351e
|
@ -14,7 +14,7 @@ from freqtrade.util import dt_now
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CustomData(ModelBase):
|
class _CustomData(ModelBase):
|
||||||
"""
|
"""
|
||||||
CustomData database model
|
CustomData database model
|
||||||
Keeps records of metadata as key/value store
|
Keeps records of metadata as key/value store
|
||||||
|
@ -41,6 +41,9 @@ class CustomData(ModelBase):
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=dt_now)
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=dt_now)
|
||||||
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
# Empty container value - not persisted, but filled with cd_value on query
|
||||||
|
value: Any = None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
create_time = (self.created_at.strftime(DATETIME_PRINT_FORMAT)
|
create_time = (self.created_at.strftime(DATETIME_PRINT_FORMAT)
|
||||||
if self.created_at is not None else None)
|
if self.created_at is not None else None)
|
||||||
|
@ -52,7 +55,7 @@ class CustomData(ModelBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def query_cd(cls, key: Optional[str] = None,
|
def query_cd(cls, key: Optional[str] = None,
|
||||||
trade_id: Optional[int] = None) -> Sequence['CustomData']:
|
trade_id: Optional[int] = None) -> Sequence['_CustomData']:
|
||||||
"""
|
"""
|
||||||
Get all CustomData, if trade_id is not specified
|
Get all CustomData, if trade_id is not specified
|
||||||
return will be for generic values not tied to a trade
|
return will be for generic values not tied to a trade
|
||||||
|
@ -60,11 +63,11 @@ class CustomData(ModelBase):
|
||||||
"""
|
"""
|
||||||
filters = []
|
filters = []
|
||||||
if trade_id is not None:
|
if trade_id is not None:
|
||||||
filters.append(CustomData.ft_trade_id == trade_id)
|
filters.append(_CustomData.ft_trade_id == trade_id)
|
||||||
if key is not None:
|
if key is not None:
|
||||||
filters.append(CustomData.cd_key.ilike(key))
|
filters.append(_CustomData.cd_key.ilike(key))
|
||||||
|
|
||||||
return CustomData.session.scalars(select(CustomData).filter(*filters)).all()
|
return _CustomData.session.scalars(select(_CustomData).filter(*filters)).all()
|
||||||
|
|
||||||
|
|
||||||
class CustomDataWrapper:
|
class CustomDataWrapper:
|
||||||
|
@ -75,9 +78,15 @@ class CustomDataWrapper:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
use_db = True
|
use_db = True
|
||||||
custom_data: List[CustomData] = []
|
custom_data: List[_CustomData] = []
|
||||||
unserialized_types = ['bool', 'float', 'int', 'str']
|
unserialized_types = ['bool', 'float', 'int', 'str']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_custom_data(data: _CustomData) -> _CustomData:
|
||||||
|
if data.cd_type not in CustomDataWrapper.unserialized_types:
|
||||||
|
data.value = json.loads(data.cd_value)
|
||||||
|
return data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def reset_custom_data() -> None:
|
def reset_custom_data() -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -88,17 +97,15 @@ class CustomDataWrapper:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_custom_data(key: Optional[str] = None,
|
def get_custom_data(key: Optional[str] = None,
|
||||||
trade_id: Optional[int] = None) -> CustomData:
|
trade_id: Optional[int] = None) -> List[_CustomData]:
|
||||||
if trade_id is None:
|
if trade_id is None:
|
||||||
trade_id = 0
|
trade_id = 0
|
||||||
|
|
||||||
if CustomDataWrapper.use_db:
|
if CustomDataWrapper.use_db:
|
||||||
filtered_custom_data = []
|
filtered_custom_data = _CustomData.session.scalars(select(_CustomData).filter(
|
||||||
for data_entry in CustomData.query_cd(trade_id=trade_id, key=key):
|
_CustomData.ft_trade_id == trade_id,
|
||||||
if data_entry.cd_type not in CustomDataWrapper.unserialized_types:
|
_CustomData.cd_key.ilike(key))).all()
|
||||||
data_entry.cd_value = json.loads(data_entry.cd_value)
|
|
||||||
filtered_custom_data.append(data_entry)
|
|
||||||
return filtered_custom_data
|
|
||||||
else:
|
else:
|
||||||
filtered_custom_data = [
|
filtered_custom_data = [
|
||||||
data_entry for data_entry in CustomDataWrapper.custom_data
|
data_entry for data_entry in CustomDataWrapper.custom_data
|
||||||
|
@ -109,19 +116,19 @@ class CustomDataWrapper:
|
||||||
data_entry for data_entry in filtered_custom_data
|
data_entry for data_entry in filtered_custom_data
|
||||||
if (data_entry.cd_key.casefold() == key.casefold())
|
if (data_entry.cd_key.casefold() == key.casefold())
|
||||||
]
|
]
|
||||||
return filtered_custom_data
|
return [CustomDataWrapper._convert_custom_data(d) for d in filtered_custom_data]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
|
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
|
||||||
|
|
||||||
value_type = type(value).__name__
|
value_type = type(value).__name__
|
||||||
value_db = None
|
|
||||||
|
|
||||||
if value_type not in CustomDataWrapper.unserialized_types:
|
if value_type not in CustomDataWrapper.unserialized_types:
|
||||||
try:
|
try:
|
||||||
value_db = json.dumps(value)
|
value_db = json.dumps(value)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
logger.warning(f"could not serialize {key} value due to {e}")
|
logger.warning(f"could not serialize {key} value due to {e}")
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
value_db = str(value)
|
value_db = str(value)
|
||||||
|
|
||||||
|
@ -134,19 +141,19 @@ class CustomDataWrapper:
|
||||||
data_entry.cd_value = value_db
|
data_entry.cd_value = value_db
|
||||||
data_entry.updated_at = dt_now()
|
data_entry.updated_at = dt_now()
|
||||||
else:
|
else:
|
||||||
data_entry = CustomData(
|
data_entry = _CustomData(
|
||||||
ft_trade_id=trade_id,
|
ft_trade_id=trade_id,
|
||||||
cd_key=key,
|
cd_key=key,
|
||||||
cd_type=value_type,
|
cd_type=value_type,
|
||||||
cd_value=value_db,
|
cd_value=value_db,
|
||||||
created_at=dt_now()
|
created_at=dt_now(),
|
||||||
)
|
)
|
||||||
|
data_entry.value = value
|
||||||
|
|
||||||
if CustomDataWrapper.use_db and value_db is not None:
|
if CustomDataWrapper.use_db and value_db is not None:
|
||||||
data_entry.cd_value = value_db
|
_CustomData.session.add(data_entry)
|
||||||
CustomData.session.add(data_entry)
|
_CustomData.session.commit()
|
||||||
CustomData.session.commit()
|
else:
|
||||||
elif not CustomDataWrapper.use_db:
|
|
||||||
cd_index = -1
|
cd_index = -1
|
||||||
for index, data_entry in enumerate(CustomDataWrapper.custom_data):
|
for index, data_entry in enumerate(CustomDataWrapper.custom_data):
|
||||||
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
|
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
|
||||||
|
@ -161,11 +168,3 @@ class CustomDataWrapper:
|
||||||
CustomDataWrapper.custom_data[cd_index] = data_entry
|
CustomDataWrapper.custom_data[cd_index] = data_entry
|
||||||
else:
|
else:
|
||||||
CustomDataWrapper.custom_data.append(data_entry)
|
CustomDataWrapper.custom_data.append(data_entry)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_all_custom_data() -> List[CustomData]:
|
|
||||||
|
|
||||||
if CustomDataWrapper.use_db:
|
|
||||||
return list(CustomData.query_cd())
|
|
||||||
else:
|
|
||||||
return CustomDataWrapper.custom_data
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
from freqtrade.exceptions import OperationalException
|
from freqtrade.exceptions import OperationalException
|
||||||
from freqtrade.persistence.base import ModelBase
|
from freqtrade.persistence.base import ModelBase
|
||||||
from freqtrade.persistence.custom_data import CustomData
|
from freqtrade.persistence.custom_data import _CustomData
|
||||||
from freqtrade.persistence.key_value_store import _KeyValueStoreModel
|
from freqtrade.persistence.key_value_store import _KeyValueStoreModel
|
||||||
from freqtrade.persistence.migrations import check_migrate
|
from freqtrade.persistence.migrations import check_migrate
|
||||||
from freqtrade.persistence.pairlock import PairLock
|
from freqtrade.persistence.pairlock import PairLock
|
||||||
|
@ -79,7 +79,7 @@ def init_db(db_url: str) -> None:
|
||||||
Order.session = Trade.session
|
Order.session = Trade.session
|
||||||
PairLock.session = Trade.session
|
PairLock.session = Trade.session
|
||||||
_KeyValueStoreModel.session = Trade.session
|
_KeyValueStoreModel.session = Trade.session
|
||||||
CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True),
|
_CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True),
|
||||||
scopefunc=get_request_or_thread_id)
|
scopefunc=get_request_or_thread_id)
|
||||||
|
|
||||||
previous_tables = inspect(engine).get_table_names()
|
previous_tables = inspect(engine).get_table_names()
|
||||||
|
|
|
@ -23,7 +23,7 @@ from freqtrade.exchange import (ROUND_DOWN, ROUND_UP, amount_to_contract_precisi
|
||||||
from freqtrade.leverage import interest
|
from freqtrade.leverage import interest
|
||||||
from freqtrade.misc import safe_value_fallback
|
from freqtrade.misc import safe_value_fallback
|
||||||
from freqtrade.persistence.base import ModelBase, SessionType
|
from freqtrade.persistence.base import ModelBase, SessionType
|
||||||
from freqtrade.persistence.custom_data import CustomData, CustomDataWrapper
|
from freqtrade.persistence.custom_data import CustomDataWrapper, _CustomData
|
||||||
from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts
|
from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts
|
||||||
|
|
||||||
|
|
||||||
|
@ -1206,10 +1206,28 @@ class LocalTrade:
|
||||||
]
|
]
|
||||||
|
|
||||||
def set_custom_data(self, key: str, value: Any) -> None:
|
def set_custom_data(self, key: str, value: Any) -> None:
|
||||||
|
"""
|
||||||
|
Set custom data for this trade
|
||||||
|
:param key: key of the custom data
|
||||||
|
:param value: value of the custom data (must be JSON serializable)
|
||||||
|
"""
|
||||||
CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id)
|
CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id)
|
||||||
|
|
||||||
def get_custom_data(self, key: Optional[str]) -> List[_CustomData]:
|
def get_custom_data(self, key: str) -> Optional[_CustomData]:
|
||||||
return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id)
|
"""
|
||||||
|
Get custom data for this trade
|
||||||
|
:param key: key of the custom data
|
||||||
|
"""
|
||||||
|
data = CustomDataWrapper.get_custom_data(key=key, trade_id=self.id)
|
||||||
|
if data:
|
||||||
|
return data[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all_custom_data(self) -> List[_CustomData]:
|
||||||
|
"""
|
||||||
|
Get all custom data for this trade
|
||||||
|
"""
|
||||||
|
return CustomDataWrapper.get_custom_data(trade_id=self.id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nr_of_successful_entries(self) -> int:
|
def nr_of_successful_entries(self) -> int:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user