mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
pytorch - data convertor - create tensor directly on device, simplify code
This commit is contained in:
parent
9cb45a3810
commit
0c9aa86885
|
@ -1,5 +1,4 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
@ -12,14 +11,14 @@ class PyTorchDataConvertor(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
|
||||
def convert_x(self, df: pd.DataFrame, device: str) -> torch.Tensor:
|
||||
"""
|
||||
:param df: "*_features" dataframe.
|
||||
:param device: The device to use for training (e.g. 'cpu', 'cuda').
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
|
||||
def convert_y(self, df: pd.DataFrame, device: str) -> torch.Tensor:
|
||||
"""
|
||||
:param df: "*_labels" dataframe.
|
||||
:param device: The device to use for training (e.g. 'cpu', 'cuda').
|
||||
|
@ -33,8 +32,8 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
target_tensor_type: Optional[torch.dtype] = None,
|
||||
squeeze_target_tensor: bool = False
|
||||
target_tensor_type: torch.dtype = torch.float32,
|
||||
squeeze_target_tensor: bool = False,
|
||||
):
|
||||
"""
|
||||
:param target_tensor_type: type of target tensor, for classification use
|
||||
|
@ -45,23 +44,14 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
|||
self._target_tensor_type = target_tensor_type
|
||||
self._squeeze_target_tensor = squeeze_target_tensor
|
||||
|
||||
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
|
||||
x = torch.from_numpy(df.values).float()
|
||||
if device:
|
||||
x = x.to(device)
|
||||
|
||||
def convert_x(self, df: pd.DataFrame, device: str) -> torch.Tensor:
|
||||
numpy_arrays = df.values
|
||||
x = torch.tensor(numpy_arrays, device=device, dtype=torch.float32)
|
||||
return x
|
||||
|
||||
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
|
||||
y = torch.from_numpy(df.values)
|
||||
|
||||
if self._target_tensor_type:
|
||||
y = y.to(self._target_tensor_type)
|
||||
|
||||
def convert_y(self, df: pd.DataFrame, device: str) -> torch.Tensor:
|
||||
numpy_arrays = df.values
|
||||
y = torch.tensor(numpy_arrays, device=device, dtype=self._target_tensor_type)
|
||||
if self._squeeze_target_tensor:
|
||||
y = y.squeeze()
|
||||
|
||||
if device:
|
||||
y = y.to(device)
|
||||
|
||||
return y
|
||||
|
|
Loading…
Reference in New Issue
Block a user