Source code for deep_time_series.dataset

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from .chunk import BaseChunkSpec, ChunkExtractor
from .plotting import plot_chunks


[docs]class TimeSeriesDataset(Dataset): def __init__( self, data_frames: pd.DataFrame | list[pd.DataFrame], chunk_specs: list[BaseChunkSpec], return_time_index: bool = True, ): if isinstance(data_frames, pd.DataFrame): data_frames = [data_frames] self.data_frames = data_frames self.chunk_specs = chunk_specs self.return_time_index = return_time_index self._preprocess() def _preprocess(self): self.chunk_extractors = [ ChunkExtractor(df, self.chunk_specs) for df in self.data_frames ] self.lengths = [ len(df) - self.chunk_extractors[0].chunk_length + 1 for df in self.data_frames ] self.min_start_time_index = max( 0, -self.chunk_extractors[0].chunk_min_t ) def __len__(self): return sum(self.lengths) def __getitem__(self, i): cumsum = np.cumsum([0] + self.lengths) df_index = np.argmax(cumsum - i > 0) - 1 chunk_extractor = self.chunk_extractors[df_index] start_time_index = i - cumsum[df_index] + self.min_start_time_index chunk_dict = chunk_extractor.extract( start_time_index, self.return_time_index ) return chunk_dict def plot_chunks(self): plot_chunks(self.chunk_specs)