# Source code for autogl.data.dataloader

import torch.utils.data
from torch.utils.data.dataloader import default_collate

from .batch import Batch


def _collate_to_batch(data_list):
    """Merge a list of data objects into a single :class:`Batch`.

    Defined at module level (instead of as a lambda) so it can be pickled
    and shipped to worker processes when ``num_workers > 0`` and the
    multiprocessing start method is ``spawn`` (the default on Windows and
    macOS). ``Batch`` is looked up at call time, as the original lambda did.
    """
    return Batch.from_data_list(data_list)


class DataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a dataset into a mini-batch.

    Every mini-batch is assembled with :meth:`Batch.from_data_list`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`True`)
        **kwargs: Forwarded unchanged to
            :class:`torch.utils.data.DataLoader` (e.g. ``num_workers``,
            ``drop_last``). Do not pass ``collate_fn``; this loader supplies
            its own, so passing it raises :class:`TypeError`. When passing a
            custom ``sampler``, also pass ``shuffle=False``, since PyTorch
            rejects ``shuffle=True`` combined with a sampler.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=_collate_to_batch,
            **kwargs
        )
def _identity_collate(data_list):
    """Return the sampled data objects unchanged (PyTorch passes them as a list).

    A named module-level function rather than a lambda, so the loader can be
    pickled for multi-process (``spawn``) data loading.
    """
    return data_list


class DataListLoader(torch.utils.data.DataLoader):
    r"""Data loader which collects data objects from a dataset into a python list.

    Unlike a batching loader, no merging happens: each iteration yields the
    list of sampled data objects as-is.

    .. note::

        This data loader should be used for multi-gpu support via
        :class:`cogdl.nn.DataParallel`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`True`)
        **kwargs: Forwarded unchanged to
            :class:`torch.utils.data.DataLoader`. Do not pass ``collate_fn``;
            this loader supplies its own, so passing it raises
            :class:`TypeError`.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=_identity_collate,
            **kwargs
        )
def _dense_collate(data_list):
    """Collate equally-shaped data objects attribute-by-attribute into a :class:`Batch`.

    The attribute names are taken from ``data_list[0].keys``. For each one,
    the values from all data objects are combined with PyTorch's
    ``default_collate``, which stacks tensors along a new leading dimension.

    The function lives at module level, not as a closure inside ``__init__``,
    so it can be pickled for multi-process (``spawn``) data loading.
    """
    batch = Batch()
    for key in data_list[0].keys:
        batch[key] = default_collate([data[key] for data in data_list])
    return batch


class DenseDataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a dataset into a mini-batch.

    .. note::

        To use this data loader, all graphs in the dataset need to have the
        same shape for each of their attributes. Therefore, this data loader
        should only be used when working with *dense* adjacency matrices.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`True`)
        **kwargs: Forwarded unchanged to
            :class:`torch.utils.data.DataLoader`. Do not pass ``collate_fn``;
            this loader supplies its own, so passing it raises
            :class:`TypeError`.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=_dense_collate,
            **kwargs
        )