freqtrade_origin/freqtrade/freqai/torch/datasets.py
2024-05-13 07:10:25 +02:00

20 lines
668 B
Python

import torch
class WindowDataset(torch.utils.data.Dataset):
    """Serve fixed-length sliding windows over ``xs`` with one target row each.

    Sample ``index`` is a window of ``window_size`` consecutive rows of ``xs``
    together with the row of ``ys`` that sits at the window's last position.
    Samples are served in reverse order: index 0 is the window closest to the
    end of the data, and the highest index is the window starting at row 0.

    Note that the dataset length is ``len(xs) - window_size``, so the very
    last row of ``xs``/``ys`` is never part of any sample. This is kept as-is
    because callers may depend on the resulting sample count.

    :param xs: Feature rows. Indexed as ``xs[a:b, :]``, so it is assumed to be
        a 2-D tensor of shape (rows, features) - confirm against the caller.
    :param ys: Target rows. Indexed as ``ys[i, :]`` followed by
        ``.unsqueeze(0)``, so it is assumed to be a 2-D torch tensor with the
        same number of rows as ``xs`` - confirm against the caller.
    :param window_size: Number of consecutive rows in every window.
    """

    def __init__(self, xs, ys, window_size: int):
        self.xs = xs
        self.ys = ys
        self.window_size = window_size

    def __len__(self) -> int:
        """Return the number of available windows (never negative)."""
        # Clamp to zero: a negative result would make ``len()`` raise
        # ValueError when the data holds fewer rows than one window.
        return max(0, len(self.xs) - self.window_size)

    def __getitem__(self, index):
        """Return ``(window_x, window_y)`` for sample ``index``.

        :param index: Sample position, ``0 <= index < len(self)``.
        :returns: ``window_x`` with ``window_size`` rows of ``xs`` and
            ``window_y``, the matching target row with a leading axis of
            size 1 added.
        :raises IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        # Without this guard an out-of-range index produces a negative start
        # offset, and slicing then silently returns a truncated or empty
        # window. Raising IndexError also lets plain iteration over the
        # dataset stop at the right place.
        if not 0 <= index < len(self):
            raise IndexError(
                f"index {index} is out of range for dataset of length {len(self)}"
            )

        # Reverse the ordering: index 0 maps to the window nearest the end.
        idx_rev = len(self.xs) - self.window_size - index - 1
        window_end = idx_rev + self.window_size

        window_x = self.xs[idx_rev:window_end, :]
        # The slice end is exclusive, so the last row held by window_x is
        # ``window_end - 1``; the target is read from that very same row.
        window_y = self.ys[window_end - 1, :].unsqueeze(0)
        return window_x, window_y