Skip to content

Data

Dataset handling and data loading utilities.

DataGroup

DataGroup(data=None, keys=None, shard=None)

Dict-like container for data arrays with consistent shape.

Args:

data (str | Dict[str, np.ndarray] | h5py.Group | None): The data to load. It can be a path to a data file in NPZ format, a dictionary mapping string keys to numpy arrays, or an h5py.Group. If None, an empty DataGroup is created.

keys (List[str] | None): The keys to load from the data. If None, all keys found in the data are used.

shard (Tuple[int, int] | None): A tuple (shard index, total number of shards). When given, only every total-th sample, starting at the shard index, is kept from each array.

Source code in aimnet/data/sgdataset.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def __init__(
    self,
    data: str | dict[str, np.ndarray] | h5py.Group | None = None,
    keys: list[str] | None = None,
    shard: tuple[int, int] | None = None,
):
    """Populate the group from an NPZ file path, a mapping of arrays, or nothing.

    Args:
        data: Path to an NPZ file, a mapping from names to arrays (a dict or
            an ``h5py.Group``), or ``None`` for an empty group.
        keys: Names to keep. ``None`` keeps every entry found in ``data``;
            names that are not present in ``data`` are silently ignored.
        shard: ``(index, total)``. When given, only rows ``index::total`` of
            every array are kept.

    Raises:
        FileNotFoundError: ``data`` is a string that is not an existing file.
        ValueError: the file does not expose named arrays, or the selected
            arrays disagree on their first dimension.
        TypeError: a selected key is not a string.
    """
    # main container for data
    self._data: dict[str, np.ndarray] = {}

    if data is None:
        data = {}

    row_slice = slice(shard[0], None, shard[1]) if shard is not None else slice(None)

    # Handle of an opened archive, if any, so it can be closed once read.
    archive = None
    if isinstance(data, str):
        # Keep the path in its own name: `data` is rebound to the loaded
        # object below and the error message must still show the path.
        path = data
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{path} does not exist or not a file.")
        loaded = np.load(path, mmap_mode="r")
        if not hasattr(loaded, "files"):
            raise ValueError(f"Data file {path} does not contain named arrays.")
        data = archive = loaded

    try:
        # take only the requested keys and apply the shard slice
        selected_keys = set(keys) if keys is not None else set(data.keys())  # type: ignore[union-attr]
        data = {k: v[row_slice] for k, v in data.items() if k in selected_keys}  # type: ignore[union-attr]
    finally:
        # Every selected array has been read by now, so the archive handle
        # is no longer needed.
        if archive is not None:
            archive.close()

    # check that keys are strings and all arrays share the first dimension
    n_rows = None
    for k, v in data.items():
        if not isinstance(k, str):
            raise TypeError(f"Expected key to be of type str, but got {type(k).__name__}")
        if n_rows is None:
            n_rows = len(v)
        if len(v) != n_rows:
            raise ValueError(f"Inconsistent data shape for key {k}. Expected first dimension {n_rows}, got {len(v)}.")
        self._data[k] = v

cv_split(cv=5, seed=None)

Return list of cv tuples containing train and val DataGroups

Source code in aimnet/data/sgdataset.py
122
123
124
125
126
127
128
129
130
131
132
133
def cv_split(self, cv: int = 5, seed=None):
    """Return list of `cv` tuples containing train and val `DataGroup`s.

    The data is divided into `cv` equally sized random parts with
    `random_split`. For fold `i`, part `i` is the validation group and the
    remaining parts are concatenated into the training group.

    Args:
        cv: Number of folds; must be at least 2.
        seed: Passed through to `random_split`.

    Returns:
        A list of `cv` `(train, val)` tuples.

    Raises:
        ValueError: if `cv` is smaller than 2.
    """
    if cv < 2:
        raise ValueError(f"cv must be at least 2, got {cv}")
    fractions = [1 / cv] * cv
    parts = self.random_split(*fractions, seed=seed)
    splits = []
    for icv in range(cv):
        val = parts[icv]
        rest = [parts[i] for i in range(cv) if i != icv]
        # Concatenate into a fresh group. `cat` is relied on to extend its
        # receiver in place (its return value was never used), so calling it
        # directly on one of `parts` would grow that part and corrupt every
        # later fold that reuses it.
        train = rest[0].sample(slice(None))
        train.cat(*rest[1:])
        splits.append((train, val))
    return splits

sample(idx, keys=None)

Return a new DataGroup with the data indexed by idx.

Source code in aimnet/data/sgdataset.py
103
104
105
106
107
108
109
def sample(self, idx, keys=None) -> "DataGroup":
    """Return a new `DataGroup` with the data indexed by `idx`.

    Args:
        idx: Index applied to the first axis of every selected array. An
            integer (Python or NumPy) selects a single row and is turned
            into a length-one slice so the leading axis is preserved.
        keys: Names to include; defaults to all keys of this group.

    Returns:
        A new instance of the same class holding the indexed arrays.
    """
    if keys is None:
        keys = self.keys()
    if isinstance(idx, (int, np.integer)):
        start = int(idx)
        # For start == -1 the naive stop (start + 1 == 0) would yield an
        # empty slice; an open-ended stop selects the last row instead.
        idx = slice(start, start + 1 or None)
    return self.__class__({k: self[k][idx] for k in keys})

SizeGroupedDataset

SizeGroupedDataset(data=None, keys=None, shard=None)

Source code in aimnet/data/sgdataset.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
def __init__(
    self,
    data: str | list[str] | dict[int, dict[str, np.ndarray]] | dict[int, DataGroup] | None = None,
    keys: list[str] | None = None,
    shard: tuple[int, int] | None = None,
):
    """Create the dataset and optionally load `data` into it.

    Loading is dispatched on the type of `data`:

    * a string naming a directory goes to `load_datadir`;
    * any other string goes to `load_h5`;
    * a list or tuple goes to `load_files` (`keys` is not forwarded);
    * a dict goes to `load_dict` (neither `keys` nor `shard` is forwarded);
    * anything else, including `None`, leaves the dataset empty.
    """
    # Storage filled by the load_* methods.
    self._data: dict[int, DataGroup] = {}
    self._meta: dict[str, str] = {}

    if isinstance(data, str):
        # A directory and a single file need different loaders but take the
        # same arguments.
        load = self.load_datadir if os.path.isdir(data) else self.load_h5
        load(data, keys=keys, shard=shard)
    elif isinstance(data, (list, tuple)):
        self.load_files(data, shard=shard)
    elif isinstance(data, dict):
        self.load_dict(data)

    # Loader-related state starts out disabled/empty; it is assigned only
    # after any loading above has finished.
    self.loader_mode = False
    self.x: list[str] = []
    self.y: list[str] = []

SizeGroupedSampler

SizeGroupedSampler(ds, batch_size, batch_mode='molecules', shuffle=False, batches_per_epoch=-1, seed=None)

Source code in aimnet/data/sgdataset.py
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def __init__(
    self,
    ds: SizeGroupedDataset,
    batch_size: int,
    batch_mode: str = "molecules",
    shuffle: bool = False,
    batches_per_epoch: int = -1,
    seed: int | None = None,
):
    """Store the sampler configuration.

    Args:
        ds: Dataset to draw batches from.
        batch_size: Requested batch size.
        batch_mode: Either "molecules" or "atoms".
        shuffle: Stored as given.
        batches_per_epoch: Stored as given; defaults to -1.
        seed: Stored as given.

    Raises:
        ValueError: if `batch_mode` is not one of the two accepted values.
    """
    self.ds = ds
    self.batch_size = batch_size

    # Reject unsupported modes before storing the value.
    allowed_modes = ("molecules", "atoms")
    if batch_mode not in allowed_modes:
        raise ValueError(f"Unknown batch_mode {batch_mode}")
    self.batch_mode = batch_mode

    self.shuffle = shuffle
    self.batches_per_epoch = batches_per_epoch
    self.seed = seed