2023-05-01 13:18:03 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class WindowDataset(torch.utils.data.Dataset):
    """Sliding-window dataset that serves windows in reverse order.

    Item ``0`` is the latest window that the dataset exposes and the last
    item is the window starting at row 0. Each item is a pair
    ``(window_x, window_y)`` where ``window_x`` holds ``window_size``
    consecutive rows of ``xs`` and ``window_y`` is the single row of ``ys``
    aligned with the *last* row of ``window_x``.

    Args:
        xs: Input rows, indexed as ``xs[start:stop, :]`` (so at least 2-D).
        ys: Target rows, indexed as ``ys[row, :]``; must support
            ``.unsqueeze`` (i.e. a ``torch.Tensor``). Presumably has the
            same number of rows as ``xs`` -- this is not checked here.
        window_size: Number of consecutive rows in every window.
    """

    def __init__(self, xs, ys, window_size):
        super().__init__()
        self.xs = xs
        self.ys = ys
        self.window_size = window_size

    def __len__(self):
        """Return the number of windows, never less than zero."""
        # Clamp so that ``len()`` does not raise ValueError when the series
        # is shorter than one window.
        # NOTE(review): the window starting at ``len(xs) - window_size`` would
        # also fit but is not counted; kept as-is so the dataset length seen
        # by existing callers does not change -- confirm whether intended.
        return max(0, len(self.xs) - self.window_size)

    def __getitem__(self, index):
        """Return ``(window_x, window_y)`` for the ``index``-th window.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        count = len(self)
        if index < 0:
            # Rebind rather than ``+=`` so a tensor index is not mutated.
            index = index + count
        if not 0 <= index < count:
            # Without this check an out-of-range index yields a negative
            # start row, and slicing silently returns empty/truncated data.
            raise IndexError("WindowDataset index out of range")

        # Reverse the order: index 0 maps to the highest start row.
        start = count - index - 1
        stop = start + self.window_size

        window_x = self.xs[start:stop, :]
        # The slice end is exclusive, so ``stop - 1`` is the last row that is
        # inside ``window_x``: the target comes from that same row.
        window_y = self.ys[stop - 1, :].unsqueeze(0)
        return window_x, window_y
|