mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
20 lines
666 B
Python
20 lines
666 B
Python
import torch
|
|
|
|
|
|
class WindowDataset(torch.utils.data.Dataset):
    """Map-style dataset of fixed-length sliding windows over ``xs``.

    Each item pairs a window of ``window_size`` consecutive rows of ``xs``
    with the single row of ``ys`` aligned to the LAST row of that window.
    Items are served in reverse row order: index 0 is the window nearest the
    end of the data, and the highest index is the window starting at row 0.

    NOTE(review): the window ending on the final row of ``xs`` is never
    produced (there are ``len(xs) - window_size + 1`` possible windows but
    only ``len(xs) - window_size`` items). This matches the original
    implementation — confirm whether it is intentional.

    :param xs: feature rows, sliced as ``xs[start:stop, :]``; presumably a
        ``(n_rows, n_features)`` tensor — TODO confirm against callers.
    :param ys: target rows, indexed as ``ys[row, :]`` and expanded with
        ``.unsqueeze(0)``, so it must be a torch tensor with at least two
        dimensions; assumed (not checked) to have the same row count as ``xs``.
    :param window_size: number of consecutive rows per window; must be >= 1.
    :raises ValueError: if ``window_size`` is less than 1.
    """

    def __init__(self, xs, ys, window_size):
        if window_size < 1:
            # A zero/negative window would yield empty x-slices paired with a
            # target row that is not part of any window.
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.xs = xs
        self.ys = ys
        self.window_size = window_size

    def __len__(self):
        """Return the number of windows served; 0 if ``xs`` is shorter than a window."""
        # Clamp at 0: ``len()`` raises ValueError when ``__len__`` is negative.
        return max(0, len(self.xs) - self.window_size)

    def __getitem__(self, index):
        """Return ``(window_x, window_y)`` for item ``index``.

        With ``start = len(xs) - window_size - index - 1``, ``window_x`` is
        ``xs[start:start + window_size, :]`` and ``window_y`` is
        ``ys[start + window_size - 1, :]`` with a leading axis of size 1 added.
        Negative indices count from the end, as for Python sequences.

        :raises IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        pos = index + n_items if index < 0 else index
        # Without this check an out-of-range index produces a negative start,
        # which slicing silently wraps into a bogus (often empty) window.
        # Raising IndexError also lets plain ``for item in dataset`` stop.
        if not 0 <= pos < n_items:
            raise IndexError(f"WindowDataset index {index} out of range (length {n_items})")

        start = len(self.xs) - self.window_size - pos - 1
        stop = start + self.window_size
        window_x = self.xs[start:stop, :]
        # The slice above excludes ``stop``, so the window's last row is
        # ``stop - 1``; the target is taken from that same row.
        window_y = self.ys[stop - 1, :].unsqueeze(0)
        return window_x, window_y
|