From 6b204c97ed44c4e1c8258a3cebfda8ea2694f44d Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 3 Apr 2023 19:02:07 +0300 Subject: [PATCH] fix pytorch data convertor type hints --- freqtrade/freqai/torch/PyTorchDataConvertor.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/freqtrade/freqai/torch/PyTorchDataConvertor.py b/freqtrade/freqai/torch/PyTorchDataConvertor.py index e7d5c3ffe..a31ccdc79 100644 --- a/freqtrade/freqai/torch/PyTorchDataConvertor.py +++ b/freqtrade/freqai/torch/PyTorchDataConvertor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import List, Optional import pandas as pd import torch @@ -12,19 +12,17 @@ class PyTorchDataConvertor(ABC): """ @abstractmethod - def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: + def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: """ :param df: "*_features" dataframe. :param device: The device to use for training (e.g. 'cpu', 'cuda'). - :returns: tuple of tensors. """ @abstractmethod - def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: + def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: """ :param df: "*_labels" dataframe. :param device: The device to use for training (e.g. 'cpu', 'cuda'). - :returns: tuple of tensors. """ @@ -47,14 +45,14 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor): self._target_tensor_type = target_tensor_type self._squeeze_target_tensor = squeeze_target_tensor - def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: + def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: x = torch.from_numpy(df.values).float() if device: x = x.to(device) - return x, + return [x] - def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: + def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: y = torch.from_numpy(df.values) if self._target_tensor_type: @@ -66,4 +64,4 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor): if device: y = y.to(device) - return y, + return [y]