PyTorch data convertor: create tensors directly on the target device and simplify the code

This commit is contained in:
yinon 2023-07-13 15:38:58 +00:00
parent 9cb45a3810
commit 0c9aa86885

View File

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Optional
import pandas as pd
import torch
@ -12,14 +11,14 @@ class PyTorchDataConvertor(ABC):
"""
@abstractmethod
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
def convert_x(self, df: pd.DataFrame, device: str) -> torch.Tensor:
"""
:param df: "*_features" dataframe.
:param device: The device to use for training (e.g. 'cpu', 'cuda').
"""
@abstractmethod
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
def convert_y(self, df: pd.DataFrame, device: str) -> torch.Tensor:
"""
:param df: "*_labels" dataframe.
:param device: The device to use for training (e.g. 'cpu', 'cuda').
@ -33,8 +32,8 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
def __init__(
self,
target_tensor_type: Optional[torch.dtype] = None,
squeeze_target_tensor: bool = False
target_tensor_type: torch.dtype = torch.float32,
squeeze_target_tensor: bool = False,
):
"""
:param target_tensor_type: type of target tensor, for classification use
@ -45,23 +44,14 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
self._target_tensor_type = target_tensor_type
self._squeeze_target_tensor = squeeze_target_tensor
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
x = torch.from_numpy(df.values).float()
if device:
x = x.to(device)
def convert_x(self, df: pd.DataFrame, device: str) -> torch.Tensor:
numpy_arrays = df.values
x = torch.tensor(numpy_arrays, device=device, dtype=torch.float32)
return x
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> torch.Tensor:
y = torch.from_numpy(df.values)
if self._target_tensor_type:
y = y.to(self._target_tensor_type)
def convert_y(self, df: pd.DataFrame, device: str) -> torch.Tensor:
numpy_arrays = df.values
y = torch.tensor(numpy_arrays, device=device, dtype=self._target_tensor_type)
if self._squeeze_target_tensor:
y = y.squeeze()
if device:
y = y.to(device)
return y