Data
Dataset handling and data loading utilities.
SizeGroupedDataset
SizeGroupedDataset(data=None, keys=None, shard=None)
Source code in aimnet/data/sgdataset.py (lines 186–208):
def __init__(
    self,
    data: str | os.PathLike | list[str] | dict[int, dict[str, np.ndarray]] | dict[int, DataGroup] | None = None,
    keys: list[str] | None = None,
    shard: tuple[int, int] | None = None,
):
    """Create a dataset, optionally populating it from ``data``.

    Args:
        data: Source to load from. The loader is chosen by type:

            * ``str`` / ``os.PathLike`` naming an existing directory ->
              ``self.load_datadir``;
            * any other ``str`` / ``os.PathLike`` -> ``self.load_h5``;
            * ``list`` or ``tuple`` (of file paths, per the annotation) ->
              ``self.load_files``;
            * ``dict`` -> ``self.load_dict``;
            * ``None`` -> nothing is loaded and the dataset starts empty.
        keys: Forwarded to ``load_datadir`` / ``load_h5`` only. It is not
            passed to ``load_files`` or ``load_dict``.
        shard: Forwarded to ``load_datadir``, ``load_h5`` and ``load_files``.
            Presumably ``(shard_index, num_shards)``. TODO confirm against
            those loaders.

    Raises:
        TypeError: If ``data`` is not one of the types listed above.
    """
    # main containers
    # int keys; presumably the group size (molecule size). Confirm in the loaders.
    self._data: dict[int, DataGroup] = {}
    self._meta: dict[str, str] = {}
    # Normalize pathlib.Path and other path-like objects to str. Without this
    # they match none of the branches below and would be silently ignored.
    if isinstance(data, os.PathLike):
        data = os.fspath(data)
    # load data
    if isinstance(data, str):
        if os.path.isdir(data):
            self.load_datadir(data, keys=keys, shard=shard)
        else:
            self.load_h5(data, keys=keys, shard=shard)
    elif isinstance(data, (list, tuple)):
        # NOTE(review): ``keys`` is not forwarded here, unlike the str branch.
        # Confirm whether load_files accepts or needs it.
        self.load_files(data, shard=shard)
    elif isinstance(data, dict):
        self.load_dict(data)
    elif data is not None:
        # Fail loudly instead of silently returning an empty dataset.
        raise TypeError(
            f"Unsupported data type {type(data).__name__!r}; expected a path, "
            "a list/tuple of paths, a dict, or None"
        )
    self.loader_mode = False
    # Lists of key names; presumably model input (x) and target (y) keys. TODO confirm.
    self.x: list[str] = []
    self.y: list[str] = []
|
SizeGroupedSampler
SizeGroupedSampler(ds, batch_size, batch_mode='molecules', shuffle=False, batches_per_epoch=-1)
Source code in aimnet/data/sgdataset.py (lines 463–477):
def __init__(
    self,
    ds: SizeGroupedDataset,
    batch_size: int,
    batch_mode: str = "molecules",
    shuffle: bool = False,
    batches_per_epoch: int = -1,
):
    """Store the sampling configuration for a ``SizeGroupedDataset``.

    Args:
        ds: Dataset to draw batches from.
        batch_size: Batch size. Presumably counted in molecules or atoms
            according to ``batch_mode``. TODO confirm in the iteration logic.
        batch_mode: Either ``"molecules"`` or ``"atoms"``.
        shuffle: Stored as-is; presumably toggles shuffling during iteration.
        batches_per_epoch: Stored as-is. ``-1`` presumably means "no cap".
            TODO confirm.

    Raises:
        ValueError: If ``batch_mode`` is not ``"molecules"`` or ``"atoms"``.
    """
    self.ds, self.batch_size = ds, batch_size
    # Validate the mode before storing it or any of the remaining settings.
    supported_modes = ("molecules", "atoms")
    if batch_mode not in supported_modes:
        raise ValueError(f"Unknown batch_mode {batch_mode}")
    self.batch_mode, self.shuffle, self.batches_per_epoch = (
        batch_mode,
        shuffle,
        batches_per_epoch,
    )
|