Skip to content

Modules

Neural network modules and model components.

Core Modules

core

SRRep(key_out='e_rep', cutoff_fn='none', rc=5.2, reduce_sum=True)

Bases: Module

GFN1-type short-range repulsion function

Source code in aimnet/modules/core.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def __init__(self, key_out="e_rep", cutoff_fn="none", rc=5.2, reduce_sum=True):
    """Set up the repulsion module and its frozen per-element parameter table.

    Parameters
    ----------
    key_out : str
        Stored as ``self.key_out``. Default is "e_rep".
    cutoff_fn : str
        Stored as ``self.cutoff_fn``. Default is "none".
    rc : float
        Value registered as the ``rc`` buffer. Default is 5.2.
    reduce_sum : bool
        Stored as ``self.reduce_sum``. Default is True.
    """
    super().__init__()
    from aimnet.constants import get_gfn1_rep

    self.key_out, self.cutoff_fn, self.reduce_sum = key_out, cutoff_fn, reduce_sum

    self.rc: Tensor
    self.register_buffer("rc", torch.tensor(rc))

    # The two tensors returned by get_gfn1_rep() become the two columns of an
    # 87-row embedding table (row 0 is the padding index). Gradients are
    # switched off on the table weights.
    rep_a, rep_b = get_gfn1_rep()
    table = nn.Embedding(87, 2, padding_idx=0, _weight=torch.stack((rep_a, rep_b), dim=-1))
    table.weight.requires_grad_(False)
    self.params = table

MLP(n_in, n_out, hidden=None, activation_fn='torch.nn.GELU', activation_kwargs=None, weight_init_fn='torch.nn.init.xavier_normal_', bias=True, last_linear=True)

Convenience function to build MLP from config

Source code in aimnet/modules/core.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def MLP(
    n_in: int,
    n_out: int,
    hidden: list[int] | None = None,
    activation_fn: Callable | str = "torch.nn.GELU",
    activation_kwargs: dict[str, Any] | None = None,
    weight_init_fn: Callable | str = "torch.nn.init.xavier_normal_",
    bias: bool = True,
    last_linear: bool = True,
):
    """Convenience function to build MLP from config"""
    if hidden is None:
        hidden = []
    if activation_kwargs is None:
        activation_kwargs = {}
    # hp search hack
    hidden = [x for x in hidden if x > 0]
    if isinstance(activation_fn, str):
        activation_fn = get_init_module(activation_fn, kwargs=activation_kwargs)
    activation = cast(nn.Module, activation_fn)
    if isinstance(weight_init_fn, str):
        weight_init_fn = get_module(weight_init_fn)
    weight_init = cast(Callable[[Tensor], Any], weight_init_fn)
    sizes = [n_in, *hidden, n_out]
    layers: list[nn.Module] = []
    for i in range(1, len(sizes)):
        n_in, n_out = sizes[i - 1], sizes[i]
        layer = nn.Linear(n_in, n_out, bias=bias)
        with torch.no_grad():
            weight_init(layer.weight)
            if bias:
                nn.init.zeros_(layer.bias)
        layers.append(layer)
        if not (last_linear and i == len(sizes) - 1):
            layers.append(activation)
    return nn.Sequential(*layers)

AEV and Convolution Modules

aev

AEVSV(rmin=0.8, rc_s=5.0, nshifts_s=16, eta_s=None, rc_v=None, nshifts_v=None, eta_v=None, shifts_s=None, shifts_v=None)

Bases: Module

AEV module to expand distances and vectors to neighbors over shifted Gaussian basis functions.

Parameters:

rmin : float, optional Minimum distance for the Gaussian basis functions. Default is 0.8. rc_s : float, optional Cutoff radius for scalar features. Default is 5.0. nshifts_s : int, optional Number of shifts for scalar features. Default is 16. eta_s : Optional[float], optional Width of the Gaussian basis functions for scalar features. Will estimate reasonable default. rc_v : Optional[float], optional Cutoff radius for vector features. Default is same as rc_s. nshifts_v : Optional[int], optional Number of shifts for vector features. Default is same as nshifts_s eta_v : Optional[float], optional Width of the Gaussian basis functions for vector features. Will estimate reasonable default. shifts_s : Optional[List[float]], optional List of shifts for scalar features. Default equidistant between rmin and rc_s shifts_v : Optional[List[float]], optional List of shifts for vector features. Default equidistant between rmin and rc_v

Source code in aimnet/modules/aev.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def __init__(
    self,
    rmin: float = 0.8,
    rc_s: float = 5.0,
    nshifts_s: int = 16,
    eta_s: float | None = None,
    rc_v: float | None = None,
    nshifts_v: int | None = None,
    eta_v: float | None = None,
    shifts_s: list[float] | None = None,
    shifts_v: list[float] | None = None,
):
    """Initialize the scalar ("_s") and vector ("_v") bases.

    The scalar basis is always built from the ``*_s`` arguments. The vector
    basis is built from the ``*_v`` arguments only when ``rc_v`` is given;
    otherwise it is initialized from the scalar arguments and
    ``_dual_basis`` is set to False.

    Raises
    ------
    ValueError
        If ``rc_v`` is given and exceeds ``rc_s``, or if ``rc_v`` is given
        while ``nshifts_v`` is None.
    """
    super().__init__()

    self._init_basis(rc_s, eta_s, nshifts_s, shifts_s, rmin, mod="_s")

    if rc_v is None:
        # dummy init: the vector basis mirrors the scalar settings
        vector_args = (rc_s, eta_s, nshifts_s, shifts_s)
    else:
        if rc_v > rc_s:
            raise ValueError("rc_v must be less than or equal to rc_s")
        if nshifts_v is None:
            raise ValueError("nshifts_v must not be None")
        vector_args = (rc_v, eta_v, nshifts_v, shifts_v)
    self._init_basis(*vector_args, rmin, mod="_v")
    self._dual_basis = rc_v is not None

    self.dmat_fill = rc_s

ConvSV(nshifts_s, nchannel, d2features=False, nshifts_v=None, ncomb_v=None)

Bases: Module

AIMNet2 type convolution: encoding of local environment which combines geometry of local environment and atomic features.

Parameters:

nshifts_s : int Number of shifts (gaussian basis functions) for scalar convolution. nchannel : int Number of feature channels for atomic features. d2features : bool, optional Flag indicating whether to use 2D features. Default is False. nshifts_v : Optional[int], optional Number of shifts for vector convolution. If not provided, defaults to the value of nshifts_s. ncomb_v : Optional[int], optional Number of linear combinations for vector features. If not provided, defaults to the value of nshifts_v.

Source code in aimnet/modules/aev.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def __init__(
    self,
    nshifts_s: int,
    nchannel: int,
    d2features: bool = False,
    nshifts_v: int | None = None,
    ncomb_v: int | None = None,
):
    """Store the convolution sizes and create the trainable ``agh`` parameter.

    Parameters
    ----------
    nshifts_s : int
        Stored as ``self.nshifts_s``.
    nchannel : int
        Stored as ``self.nchannel`` and passed to ``_init_ahg``.
    d2features : bool
        Stored as ``self.d2features``. Default is False.
    nshifts_v : int | None
        Falls back to ``nshifts_s`` when falsy (None or 0).
    ncomb_v : int | None
        Falls back to the resolved ``nshifts_v`` when falsy (None or 0).
    """
    super().__init__()

    if not nshifts_v:
        nshifts_v = nshifts_s
    if not ncomb_v:
        ncomb_v = nshifts_v

    # Initial values come from _init_ahg; the tensor is registered as a
    # parameter with gradients enabled.
    initial_agh = _init_ahg(nchannel, nshifts_v, ncomb_v)
    self.register_parameter("agh", nn.Parameter(initial_agh, requires_grad=True))

    self.nshifts_s, self.nshifts_v, self.ncomb_v = nshifts_s, nshifts_v, ncomb_v
    self.nchannel = nchannel
    self.d2features = d2features
    self.do_vector = True

Long-Range Modules

lr

D3TS(a1, a2, s8, s6=1.0, key_in='disp_param', key_out='energy')

Bases: Module

DFT-D3-like pairwise dispersion with TS combination rule

Source code in aimnet/modules/lr.py
903
904
905
906
907
908
909
910
911
912
def __init__(self, a1: float, a2: float, s8: float, s6: float = 1.0, key_in="disp_param", key_out="energy"):
    """Store the scalar settings and data keys, and register the ``r4r2`` buffer.

    Parameters
    ----------
    a1, a2, s8 : float
        Stored unchanged as attributes of the same name.
    s6 : float
        Stored as ``self.s6``. Default is 1.0.
    key_in : str
        Stored as ``self.key_in``. Default is "disp_param".
    key_out : str
        Stored as ``self.key_out``. Default is "energy".
    """
    super().__init__()
    self.key_in, self.key_out = key_in, key_out
    self.a1, self.a2 = a1, a2
    self.s6, self.s8 = s6, s8

    # Table supplied by constants.get_r4r2(), kept as a buffer on the module.
    self.r4r2: Tensor
    self.register_buffer("r4r2", constants.get_r4r2())

DFTD3(s8, a1, a2, s6=1.0, cutoff=15.0, smoothing_fraction=0.2, key_out='energy')

Bases: Module

DFT-D3 implementation using nvalchemiops GPU-accelerated kernels.

BJ damping, C6 and C8 terms, without 3-body term.

This implementation uses nvalchemiops.torch.interactions.dispersion.dftd3 for GPU-accelerated computation of dispersion energies, forces, and virial. The embedded model path injects explicit forces/virial into autograd only when coordinate or strain gradients are requested; the external calculator path returns detached derivative terms.

Parameters

s8 : float Scaling factor for C8 term. a1 : float BJ damping parameter 1. a2 : float BJ damping parameter 2. s6 : float, optional Scaling factor for C6 term. Default is 1.0. cutoff : float, optional Cutoff distance in Angstroms for smoothing. Default is 15.0. smoothing_fraction : float, optional Fraction of cutoff distance used for smoothing window width. Smoothing starts at cutoff * (1 - smoothing_fraction) and ends at cutoff. Example: With cutoff=15.0 and smoothing_fraction=0.2: - Smoothing starts at 12.0 Å (15.0 * 0.8) - Smoothing ends at 15.0 Å Default is 0.2 (20% of cutoff as smoothing window). key_out : str, optional Key for output energy in data dict. Default is "energy". Attributes


smoothing_on : float Distance where smoothing starts (Angstroms). smoothing_off : float Distance where smoothing ends / cutoff (Angstroms). s6, s8, a1, a2 : float BJ damping parameters.

Notes

Neighbor list keys follow a suffix resolution pattern: methods first look for module-specific keys (e.g., nbmat_dftd3, shifts_dftd3), falling back to shared _lr suffix (nbmat_lr, shifts_lr) if not found.

Source code in aimnet/modules/lr.py
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
def __init__(
    self,
    s8: float,
    a1: float,
    a2: float,
    s6: float = 1.0,
    cutoff: float = 15.0,
    smoothing_fraction: float = 0.2,
    key_out: str = "energy",
):
    """Store damping/smoothing settings and load the D3 reference tables.

    Parameters
    ----------
    s8, a1, a2 : float
        Stored unchanged as attributes of the same name.
    s6 : float
        Stored as ``self.s6``. Default is 1.0.
    cutoff : float
        Stored as ``self.smoothing_off``. Default is 15.0.
    smoothing_fraction : float
        ``self.smoothing_on`` is set to ``cutoff * (1 - smoothing_fraction)``.
        Default is 0.2.
    key_out : str
        Stored as ``self.key_out``. Default is "energy".
    """
    super().__init__()
    self.key_out = key_out
    # BJ damping parameters
    self.s6, self.s8, self.a1, self.a2 = s6, s8, a1, a2

    # Smoothing window, kept as plain float attributes.
    self.smoothing_on: float = cutoff * (1 - smoothing_fraction)
    self.smoothing_off: float = cutoff

    # The reference file "dftd3_data.pt" is looked up one directory above
    # the directory containing this source file and loaded onto the CPU.
    package_dir = os.path.dirname(os.path.dirname(__file__))
    reference = torch.load(
        os.path.join(package_dir, "dftd3_data.pt"),
        map_location="cpu",
        weights_only=True,
    )

    # The last axis of the stored "c6ab" entry packs two fields: index 0
    # feeds the ``c6ab`` buffer and index 1 feeds the ``cn_ref`` buffer.
    packed = reference["c6ab"]
    tables = {
        "rcov": reference["rcov"],
        "r4r2": reference["r4r2"],
        "c6ab": packed[..., 0].contiguous(),
        "cn_ref": packed[..., 1].contiguous(),
    }

    # Register every table as a float32 buffer (dict order = registration order).
    self.rcov: Tensor
    self.r4r2: Tensor
    self.c6ab: Tensor
    self.cn_ref: Tensor
    for buffer_name, table in tables.items():
        self.register_buffer(buffer_name, table.float())

forward(data, *, compute_forces=False, compute_virial=False, hessian=False, scaling=None, coord_unstrained=None, cell_unstrained=None)

Compute DFT-D3 energy and optional explicit derivative terms.

The embedded path returns an autograd-capable energy only when the coordinate or explicit calculator strain inputs require it. Explicit derivative requests return detached energy and derivative terms. The strain-wrapper kwargs are for direct differentiable callers; the calculator uses explicit derivative terms for stress because DFT-D3 has no trainable parameters.

The returned virial follows the calculator-side external-derivative convention: get_derivatives subtracts terms.virial.mT from the strain gradient (the same path DSF uses). FD-validated against dE/dscaling in :class:tests.test_dftd3.TestDFTD3ForwardTerms.

Source code in aimnet/modules/lr.py
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
def forward(
    self,
    data: dict[str, Tensor],
    *,
    compute_forces: bool = False,
    compute_virial: bool = False,
    hessian: bool = False,
    scaling: Tensor | None = None,
    coord_unstrained: Tensor | None = None,
    cell_unstrained: Tensor | None = None,
) -> dict[str, Tensor] | tuple[dict[str, Tensor], ExternalDerivativeTerms | None]:
    """Compute DFT-D3 energy and optional explicit derivative terms.

    The embedded path returns an autograd-capable energy only when the
    coordinate or explicit calculator strain inputs require it. Explicit
    derivative requests return detached energy and derivative terms.
    The strain-wrapper kwargs are for direct differentiable callers; the
    calculator uses explicit derivative terms for stress because DFT-D3 has
    no trainable parameters.

    The returned virial follows the calculator-side external-derivative
    convention: ``get_derivatives`` subtracts ``terms.virial.mT`` from the
    strain gradient (the same path DSF uses). FD-validated against
    ``dE/dscaling`` in :class:`tests.test_dftd3.TestDFTD3ForwardTerms`.

    Parameters
    ----------
    data : dict[str, Tensor]
        Input dictionary. The energy is added to ``data[self.key_out]`` if
        that key exists, otherwise the key is created. The dictionary is
        modified in place and returned.
    compute_forces : bool
        Request explicit forces in the returned derivative terms.
    compute_virial : bool
        Request an explicit virial in the returned derivative terms.
    hessian : bool
        Use ``_compute_energy_torch`` for the energy instead of the kernel.
        Cannot be combined with ``compute_forces`` / ``compute_virial``.
    scaling : Tensor | None
        When it is a tensor that requires grad (and no explicit terms are
        requested), the strain-aware autograd path is used.
    coord_unstrained, cell_unstrained : Tensor | None
        Must both be tensors when the strain-aware path is used; they
        replace ``"coord"`` and ``"cell"`` in the kernel input.

    Returns
    -------
    dict[str, Tensor] or tuple
        ``data`` alone when neither ``compute_forces`` nor ``compute_virial``
        is set; otherwise ``(data, ExternalDerivativeTerms)`` where ``forces``
        is None unless ``compute_forces`` is set, and ``virial`` is None
        unless ``compute_virial`` is set and the kernel returned a
        non-empty virial.

    Raises
    ------
    ValueError
        If ``hessian`` is combined with explicit derivative requests, or if
        the strain-aware path is selected without tensor
        ``coord_unstrained`` and ``cell_unstrained``.
    """
    # Explicit (non-autograd) derivative terms are requested if either flag is set.
    derivative_terms = compute_forces or compute_virial
    if hessian:
        # Hessian path: energy from the torch implementation, accumulated in
        # float64; only the data dict is returned.
        if derivative_terms:
            raise ValueError("hessian=True uses differentiable DFTD3 energy; do not request explicit terms")
        energy_ev = self._compute_energy_torch(data).double()
        if self.key_out in data:
            data[self.key_out] = data[self.key_out].double() + energy_ev
        else:
            data[self.key_out] = energy_ev
        return data

    # Strain-aware autograd path: only when no explicit terms are requested
    # and ``scaling`` is a tensor that requires grad.
    use_strain_wrapper = False
    if not derivative_terms and isinstance(scaling, Tensor) and scaling.requires_grad:
        if not isinstance(coord_unstrained, Tensor) or not isinstance(cell_unstrained, Tensor):
            raise ValueError("strain-aware DFTD3 requires coord_unstrained and cell_unstrained")
        use_strain_wrapper = True

    kernel_data = data
    if use_strain_wrapper:
        # Type narrowing only; the ValueError above already guarantees both are tensors.
        assert isinstance(coord_unstrained, Tensor)
        assert isinstance(cell_unstrained, Tensor)
        # New dict with coord/cell swapped for the unstrained versions; the
        # caller's ``data`` is not modified by this substitution.
        kernel_data = {**data, "coord": coord_unstrained, "cell": cell_unstrained}
    kernel_inputs = self._prepare_dftd3_inputs(kernel_data)

    # Positional arguments shared by the two ``_DFTD3EnergyFunction.apply``
    # calls below (everything after coord, cell and scaling).
    common_args = (
        kernel_inputs.numbers_flat,
        kernel_inputs.batch_idx,
        kernel_inputs.neighbor_matrix,
        kernel_inputs.neighbor_matrix_shifts,
        int(kernel_inputs.fill_value),
        int(kernel_inputs.num_systems),
        self.rcov,
        self.r4r2,
        self.c6ab,
        self.cn_ref,
        float(self.a1),
        float(self.a2),
        float(self.s8),
        float(self.s6),
        float(self.smoothing_on),
        float(self.smoothing_off),
    )

    if derivative_terms:
        # Explicit-terms path: the kernel runs on detached inputs under
        # no_grad; the virial flag is forwarded as requested.
        with torch.no_grad():
            energy_ev, forces_ev_flat, virial_kernel = _call_dftd3_kernel(
                coord=kernel_inputs.coord_flat.detach(),
                numbers=kernel_inputs.numbers_flat,
                batch_idx=kernel_inputs.batch_idx,
                neighbor_matrix=kernel_inputs.neighbor_matrix,
                neighbor_matrix_shifts=kernel_inputs.neighbor_matrix_shifts,
                fill_value=int(kernel_inputs.fill_value),
                num_systems=int(kernel_inputs.num_systems),
                cell=kernel_inputs.cell_for_kernel.detach() if kernel_inputs.cell_for_kernel is not None else None,
                rcov=self.rcov,
                r4r2=self.r4r2,
                c6_reference=self.c6ab,
                coord_num_ref=self.cn_ref,
                a1=float(self.a1),
                a2=float(self.a2),
                s8=float(self.s8),
                s6=float(self.s6),
                smoothing_on=float(self.smoothing_on),
                smoothing_off=float(self.smoothing_off),
                compute_virial=compute_virial,
            )
        energy_ev = energy_ev.detach().double()
    elif use_strain_wrapper:
        # Differentiable energy with ``scaling`` passed to the autograd function.
        energy_ev = _DFTD3EnergyFunction.apply(
            kernel_inputs.coord_flat,
            kernel_inputs.cell_for_kernel,
            scaling,
            *common_args,
        )
        # Empty placeholders: no explicit forces/virial are produced on this path.
        forces_ev_flat = kernel_inputs.coord_flat.new_empty(0)
        virial_kernel = kernel_inputs.coord_flat.new_empty(0)
    elif kernel_inputs.coord_flat.requires_grad:
        # Differentiable energy without a scaling tensor (None is passed in its place).
        energy_ev = _DFTD3EnergyFunction.apply(
            kernel_inputs.coord_flat,
            kernel_inputs.cell_for_kernel,
            None,
            *common_args,
        )
        # Empty placeholders: no explicit forces/virial are produced on this path.
        forces_ev_flat = kernel_inputs.coord_flat.new_empty(0)
        virial_kernel = kernel_inputs.coord_flat.new_empty(0)
    else:
        # Plain inference: nothing requires grad and no explicit terms were
        # requested, so only the detached energy from this call is used below.
        with torch.no_grad():
            energy_ev, forces_ev_flat, virial_kernel = _call_dftd3_kernel(
                coord=kernel_inputs.coord_flat.detach(),
                numbers=kernel_inputs.numbers_flat,
                batch_idx=kernel_inputs.batch_idx,
                neighbor_matrix=kernel_inputs.neighbor_matrix,
                neighbor_matrix_shifts=kernel_inputs.neighbor_matrix_shifts,
                fill_value=int(kernel_inputs.fill_value),
                num_systems=int(kernel_inputs.num_systems),
                cell=kernel_inputs.cell_for_kernel.detach() if kernel_inputs.cell_for_kernel is not None else None,
                rcov=self.rcov,
                r4r2=self.r4r2,
                c6_reference=self.c6ab,
                coord_num_ref=self.cn_ref,
                a1=float(self.a1),
                a2=float(self.a2),
                s8=float(self.s8),
                s6=float(self.s6),
                smoothing_on=float(self.smoothing_on),
                smoothing_off=float(self.smoothing_off),
                compute_virial=False,
            )
        energy_ev = energy_ev.detach().double()

    # Accumulate into an existing energy entry (promoted with .double()) or
    # create the entry from this contribution alone.
    if self.key_out in data:
        data[self.key_out] = data[self.key_out].double() + energy_ev
    else:
        data[self.key_out] = energy_ev

    forces_ev: Tensor | None = None
    if compute_forces:
        # Flat kernel forces are reshaped using the neighbor mode and the
        # shape of the prepared coordinates.
        forces_ev = self._restore_dftd3_forces_shape(
            forces_ev_flat.detach(),
            kernel_inputs.nb_mode,
            kernel_inputs.coord.shape,
        )

    # The nvalchemi DFTD3 kernel already returns the strain virial in the
    # external-derivative-term convention used by the calculator
    # (``dedc -= terms.virial.mT``). Verified by FD against
    # ``dE/dscaling`` under row-vector strain - see
    # ``tests/test_dftd3.py::TestDFTD3ForwardTerms``.
    external_virial: Tensor | None = None
    if compute_virial and virial_kernel.numel() > 0:
        external_virial = virial_kernel.detach().contiguous()

    # The return shape depends on the request: (data, terms) when explicit
    # terms were asked for, otherwise just data.
    terms = None
    if derivative_terms:
        terms = ExternalDerivativeTerms(forces=forces_ev, virial=external_virial)
        return data, terms
    return data

set_smoothing(cutoff, smoothing_fraction=0.2)

Update smoothing parameters based on new cutoff and fraction.

Parameters

cutoff : float Cutoff distance in Angstroms. smoothing_fraction : float Fraction of cutoff used as smoothing window width. Smoothing occurs from cutoff * (1 - smoothing_fraction) to cutoff. Example: smoothing_fraction=0.2 means smoothing over last 20% of cutoff distance (from 0.8*cutoff to cutoff). Default is 0.2.

Source code in aimnet/modules/lr.py
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
def set_smoothing(self, cutoff: float, smoothing_fraction: float = 0.2) -> None:
    """Recompute the smoothing window from a cutoff and a window fraction.

    Parameters
    ----------
    cutoff : float
        Cutoff distance in Angstroms; becomes ``self.smoothing_off``.
    smoothing_fraction : float
        Fraction of the cutoff covered by the smoothing window. The window
        starts at ``cutoff * (1 - smoothing_fraction)``, which becomes
        ``self.smoothing_on``. For example, 0.2 smooths over the last 20%
        of the cutoff (from 0.8*cutoff to cutoff). Default is 0.2.
    """
    window_start = cutoff * (1 - smoothing_fraction)
    self.smoothing_on, self.smoothing_off = window_start, cutoff

ExternalDerivativeTerms(forces=None, virial=None) dataclass

Explicit derivative terms returned by external nvalchemiops backends.

LRCoulomb(key_in='charges', key_out='e_h', rc=4.6, method='simple', dsf_alpha=0.2, dsf_rc=15.0, ewald_accuracy=1e-06, subtract_sr=True, envelope='exp')

Bases: Module

Long-range Coulomb energy module.

Computes electrostatic energy using one of several methods: simple (all pairs), DSF (damped shifted force), Ewald summation, or Particle Mesh Ewald (PME). DSF, Ewald, and PME are backed by nvalchemiops; Ewald and PME require periodic systems with a cell.

Parameters

key_in : str Key for input charges in data dict. Default is "charges". key_out : str Key for output energy in data dict. Default is "e_h". rc : float Short-range cutoff radius. Default is 4.6 Angstrom. method : str Coulomb method: "simple", "dsf", "ewald", or "pme". Default is "simple". dsf_alpha : float Alpha parameter for DSF method. Default is 0.2. dsf_rc : float Cutoff for DSF method. Default is 15.0. ewald_accuracy : float Target accuracy for Ewald and PME summation. Controls real-space and reciprocal-space cutoffs (and PME mesh dimensions). Lower values give higher accuracy at higher cost. Default is 1e-6. subtract_sr : bool Whether to subtract short-range contribution. Default is True. envelope : str Envelope function for SR cutoff: "exp" or "cosine". Default is "exp".

Notes

Energy accumulation uses float64 for numerical precision, particularly important for large systems where many small contributions can suffer from floating-point error accumulation.

Neighbor list keys follow a suffix resolution pattern: methods first look for module-specific keys (e.g., nbmat_coulomb, shifts_coulomb), falling back to shared _lr suffix (nbmat_lr, shifts_lr) if not found.

DSF uses nvalchemiops.torch.interactions.electrostatics.dsf_coulomb. Its energy is differentiable through charges, but not through positions or cell; the calculator consumes explicit DSF forces/virial for inference and rejects DSF force/stress training and Hessian requests.

Ewald/PME call nvalchemiops directly. Inference uses hybrid_forces=True so energy remains differentiable through charges and fixed-charge geometry derivatives are returned as explicit terms. Training derivative paths use a small local autograd.Function wrapper because the installed nvalchemiops coordinate backward kernels do not currently provide a registered backward-of-backward. Calculator Hessian requests are rejected for Ewald/PME because complete Hessians require true second coordinate derivatives.

Source code in aimnet/modules/lr.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def __init__(
    self,
    key_in: str = "charges",
    key_out: str = "e_h",
    rc: float = 4.6,
    method: str = "simple",
    dsf_alpha: float = 0.2,
    dsf_rc: float = 15.0,
    ewald_accuracy: float = 1e-6,
    subtract_sr: bool = True,
    envelope: str = "exp",
):
    """Store the Coulomb settings and validate ``envelope`` and ``method``.

    Parameters
    ----------
    key_in, key_out : str
        Data-dict keys stored as ``self.key_in`` / ``self.key_out``.
    rc : float
        Registered as the ``rc`` buffer.
    method : str
        One of "simple", "dsf", "ewald", "pme".
    dsf_alpha, dsf_rc, ewald_accuracy : float
        Stored unchanged as attributes of the same name.
    subtract_sr : bool
        Stored as ``self.subtract_sr``.
    envelope : str
        One of "exp", "cosine".

    Raises
    ------
    ValueError
        If ``envelope`` or ``method`` is not one of the accepted values.
    """
    super().__init__()
    self.key_in, self.key_out = key_in, key_out

    # Pairwise convention factor used by simple/dsf (sums over ordered pairs).
    # Ewald/PME nvalchemiops outputs are converted with k_e = Hartree * Bohr.
    self._factor = constants.half_Hartree * constants.Bohr

    self.rc: Tensor
    self.register_buffer("rc", torch.tensor(rc))

    self.dsf_alpha, self.dsf_rc = dsf_alpha, dsf_rc
    self.ewald_accuracy = ewald_accuracy
    self.subtract_sr = subtract_sr

    if envelope not in ("exp", "cosine"):
        raise ValueError(f"Unknown envelope {envelope}, must be 'exp' or 'cosine'")
    self.envelope = envelope

    if method not in ("simple", "dsf", "ewald", "pme"):
        raise ValueError(f"Unknown method {method}")
    self.method = method

coul_ewald(data)

Per-system Ewald energy in eV. Requires cell and nbmat_lr/shifts_lr.

Source code in aimnet/modules/lr.py
706
707
708
709
def coul_ewald(self, data: dict[str, Tensor]) -> Tensor:
    """Per-system Ewald energy in eV. Requires ``cell`` and ``nbmat_lr``/``shifts_lr``.

    Delegates to ``_coul_nvalchemi`` with ``backend="ewald"`` and returns
    only the energy; the second element of the result is discarded.
    """
    result = self._coul_nvalchemi(data, backend="ewald")
    ewald_energy, _ = result
    return ewald_energy

coul_pme(data)

Per-system PME energy in eV. Requires cell and nbmat_lr/shifts_lr.

Source code in aimnet/modules/lr.py
711
712
713
714
def coul_pme(self, data: dict[str, Tensor]) -> Tensor:
    """Per-system PME energy in eV. Requires ``cell`` and ``nbmat_lr``/``shifts_lr``.

    Delegates to ``_coul_nvalchemi`` with ``backend="pme"`` and returns
    only the energy; the second element of the result is discarded.
    """
    result = self._coul_nvalchemi(data, backend="pme")
    pme_energy, _ = result
    return pme_energy

coul_simple(data)

Compute pairwise Coulomb energy.

With subtract_sr=True (default): Returns LR only (FULL - SR) With subtract_sr=False: Returns FULL pairwise Coulomb

Source code in aimnet/modules/lr.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
def coul_simple(self, data: dict[str, Tensor]) -> Tensor:
    """Pairwise Coulomb energy from the charges stored under ``self.key_in``.

    With subtract_sr=True (default): returns the full pairwise sum minus
    ``self.coul_simple_sr(data)``, i.e. LR only (FULL - SR).
    With subtract_sr=False: returns the FULL pairwise Coulomb sum.
    """
    # Use the module-specific neighbor-list suffix if present, else the shared one.
    sfx = nbops.resolve_suffix(data, ["_coulomb", "_lr"])
    data = ops.lazy_calc_dij(data, sfx)
    dist = data[f"d_ij{sfx}"]

    charges = data[self.key_in]
    charge_i, charge_j = nbops.get_ij(charges, data, suffix=sfx)

    # FULL pairwise term q_i * q_j / d_ij (no exp_cutoff weighting);
    # mask_ij_ is given 0.0 as the fill value.
    pair_energy = nbops.mask_ij_(charge_i * charge_j / dist, data, 0.0, suffix=sfx)

    # Sum over neighbors in float64, then reduce per system.
    atom_energy = pair_energy.sum(-1, dtype=torch.float64)
    total = self._factor * nbops.mol_sum(atom_energy, data)

    if not self.subtract_sr:
        return total
    # Same pattern as dsf/ewald - subtract SR to get LR
    return total - self.coul_simple_sr(data)

SRCoulomb(rc=4.6, key_in='charges', key_out='energy', envelope='exp')

Bases: Module

Subtract short-range Coulomb contribution from energy.

For models trained with "simple" Coulomb mode, the NN has implicitly learned the short-range Coulomb interaction. When using DSF or Ewald summation for the full Coulomb energy, we need to subtract this short-range contribution to avoid double-counting.

Parameters

rc : float Cutoff radius for short-range Coulomb. Default is 4.6 Angstrom. key_in : str Key for input charges in data dict. Default is "charges". key_out : str Key for output energy in data dict. Default is "energy". envelope : str Envelope function for cutoff: "exp" (mollifier) or "cosine". Default is "exp".

Source code in aimnet/modules/lr.py
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
def __init__(
    self,
    rc: float = 4.6,
    key_in: str = "charges",
    key_out: str = "energy",
    envelope: str = "exp",
):
    """Store the data keys, cutoff buffer and envelope choice.

    Parameters
    ----------
    rc : float
        Registered as the ``rc`` buffer. Default is 4.6.
    key_in : str
        Stored as ``self.key_in``. Default is "charges".
    key_out : str
        Stored as ``self.key_out``. Default is "energy".
    envelope : str
        One of "exp", "cosine". Default is "exp".

    Raises
    ------
    ValueError
        If ``envelope`` is neither "exp" nor "cosine".
    """
    super().__init__()
    self.key_in, self.key_out = key_in, key_out
    self._factor = constants.half_Hartree * constants.Bohr

    self.rc: Tensor
    self.register_buffer("rc", torch.tensor(rc))

    envelope_is_known = envelope in ("exp", "cosine")
    if not envelope_is_known:
        raise ValueError(f"Unknown envelope {envelope}, must be 'exp' or 'cosine'")
    self.envelope = envelope

forward(data)

Subtract short-range Coulomb from energy.

Source code in aimnet/modules/lr.py
809
810
811
812
813
814
815
816
817
818
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
    """Subtract short-range Coulomb from the energy stored under ``self.key_out``.

    If the key already exists, its value is converted with ``.double()``
    before the subtraction; otherwise the entry is created as the negated
    short-range term. ``data`` is modified in place and returned.
    """
    sr_energy = _calc_coulomb_sr(data, self.rc, self.envelope, self.key_in, self._factor)

    target = self.key_out
    data[target] = (data[target].double() - sr_energy) if target in data else -sr_energy
    return data

AIMNet2 Model

aimnet2

Base Classes

base

AIMNet2Base()

Bases: Module

Base class for AIMNet2 models. Implements pre-processing data: converting to right dtype and device, setting nb mode, calculating masks.

Source code in aimnet/models/base.py
211
212
213
214
def __init__(self):
    """Initialize the module with ``_metadata`` set to None."""
    super().__init__()
    # Use object.__setattr__ to avoid TorchScript tracing this attribute.
    # Calling it on ``object`` bypasses the class's own ``__setattr__`` and
    # stores the value directly on the instance.
    object.__setattr__(self, "_metadata", None)

metadata property

Return model metadata if available.

prepare_input(data)

Common operations for input preparation.

Source code in aimnet/models/base.py
230
231
232
233
234
235
236
237
238
239
def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
    """Run the shared input pipeline and sanity-check ``charge`` / ``mult``.

    The input passes through ``self._prepare_dtype``, then
    ``nbops.set_nb_mode``, then ``nbops.calc_masks``. Afterwards
    ``data["charge"]`` must be one-dimensional, and so must ``data["mult"]``
    when that key is present (checked with ``assert``).
    """
    prepared = self._prepare_dtype(data)
    prepared = nbops.calc_masks(nbops.set_nb_mode(prepared))

    charge_rank = prepared["charge"].ndim
    assert charge_rank == 1, "Charge should be 1D tensor."
    if "mult" in prepared:
        assert prepared["mult"].ndim == 1, "Mult should be 1D tensor."
    return prepared

ModelMetadata

Bases: TypedDict

Metadata returned by load_model().

This TypedDict documents the structure of the metadata dictionary.

load_model(path, device='cpu')

Load model from file, supporting both new and legacy formats.

Automatically detects format: - New format: state dict with embedded YAML config and metadata - Legacy format: JIT-compiled TorchScript model

Parameters

path : str Path to the model file (.pt or .jpt). device : str Device to load the model on. Default is "cpu".

Returns

model : nn.Module The loaded model with weights. metadata : ModelMetadata Dictionary containing model metadata. See ModelMetadata TypedDict for fields.

Notes

For legacy JIT models (format_version=1), needs_coulomb and needs_dispersion are False because LR modules are already embedded in the TorchScript model.

Source code in aimnet/models/base.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def load_model(path: str, device: str = "cpu") -> tuple[nn.Module, ModelMetadata]:
    """Load model from file, supporting both new and legacy formats.

    Automatically detects format:
    - New format: state dict with embedded YAML config and metadata
    - Legacy format: JIT-compiled TorchScript model

    Parameters
    ----------
    path : str
        Path to the model file (.pt or .jpt).
    device : str
        Device to load the model on. Default is "cpu".

    Returns
    -------
    model : nn.Module
        The loaded model with weights.
    metadata : ModelMetadata
        Dictionary containing model metadata. See ModelMetadata TypedDict for fields.

    Raises
    ------
    OSError
        If the file cannot be opened or read (e.g. it does not exist). The
        error from the first, safe load attempt is propagated directly.
    TypeError
        If the embedded YAML config does not build an ``nn.Module``.
    ValueError
        If the loaded object is neither a dict containing ``"model_yaml"``
        nor a ``torch.jit.ScriptModule``.

    Notes
    -----
    For legacy JIT models (format_version=1), `needs_coulomb` and `needs_dispersion`
    are False because LR modules are already embedded in the TorchScript model.
    """
    import yaml

    # Try weights_only=True first (secure for new .pt format).
    # Falls back to weights_only=False for legacy TorchScript .jpt archives,
    # which require full deserialization to load the frozen computation graph.
    try:
        data = torch.load(path, map_location=device, weights_only=True)
    except OSError:
        # Filesystem-level failure (missing/unreadable file): a retry cannot
        # succeed, so do not reach for the unsafe loader at all.
        raise
    except Exception:
        # SECURITY NOTE(review): this fallback is taken for *any* non-OSError
        # failure of the safe loader, including a file the restricted
        # unpickler rejected. weights_only=False performs full unpickling and
        # can execute arbitrary code, so only load model files from trusted
        # sources.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", ".*looks like a TorchScript archive.*")
            data = torch.load(path, map_location=device, weights_only=False)

    # Check result type to determine format
    if isinstance(data, dict) and "model_yaml" in data:
        # New state dict format
        model_config = yaml.safe_load(data["model_yaml"])
        model = build_module(model_config)
        if not isinstance(model, nn.Module):
            raise TypeError("Built model configuration did not produce an nn.Module.")

        # Atomic shifts store SAE/reference-energy values and may be float64 in
        # the file. Cast before load_state_dict so copy_ does not truncate them
        # into the default float32 embedding.
        if hasattr(model, "outputs") and hasattr(model.outputs, "atomic_shift"):
            model.outputs.atomic_shift.double()

        # Use strict=False because modules may differ between formats
        load_result = model.load_state_dict(data["state_dict"], strict=False)

        # Check for unexpected missing/extra keys
        real_missing, real_unexpected = validate_state_dict_keys(load_result.missing_keys, load_result.unexpected_keys)
        if real_missing or real_unexpected:
            msg_parts = []
            if real_missing:
                msg_parts.append(f"Missing keys: {real_missing}")
            if real_unexpected:
                msg_parts.append(f"Unexpected keys: {real_unexpected}")
            warnings.warn(f"State dict mismatch during model loading. {'; '.join(msg_parts)}", stacklevel=2)

        model = model.to(device)

        # "cutoff" is the only field read without a default; a file lacking
        # it fails here with KeyError.
        metadata: ModelMetadata = {
            "format_version": data.get("format_version", 2),  # Default 2 for early v2 files without version
            "cutoff": data["cutoff"],
            "needs_coulomb": data.get("needs_coulomb", False),
            "needs_dispersion": data.get("needs_dispersion", False),
            "coulomb_mode": data.get("coulomb_mode", "none"),
            "coulomb_sr_rc": data.get("coulomb_sr_rc"),
            "coulomb_sr_envelope": data.get("coulomb_sr_envelope"),
            "d3_params": data.get("d3_params"),
            "has_embedded_lr": data.get("has_embedded_lr", False),
            "implemented_species": data.get("implemented_species", []),
            "family": data.get("family"),
            "supports_charged_systems": data.get("supports_charged_systems"),
            "has_embedded_d3ts": data.get("has_embedded_d3ts", False),
        }

        # Attach metadata to model for easy access
        model._metadata = metadata  # type: ignore[assignment]

        return model, metadata

    elif isinstance(data, torch.jit.ScriptModule):
        # Legacy JIT format - LR modules are already embedded in the TorchScript model
        model = data
        legacy_metadata: ModelMetadata = {
            "format_version": 1,  # Legacy .jpt format is v1
            "cutoff": float(model.cutoff),
            # Legacy models have LR modules embedded - don't add external ones
            "needs_coulomb": False,
            "needs_dispersion": False,
            "coulomb_mode": "full_embedded",
            # No coulomb_sr_rc/envelope for legacy (full Coulomb is embedded)
            "d3_params": extract_d3_params(model) if has_externalizable_dftd3(model) else None,
            "implemented_species": extract_species(model),
        }

        # Attempt metadata assignment; silently fails for JIT models
        with contextlib.suppress(AttributeError, RuntimeError):
            model._metadata = legacy_metadata  # type: ignore[attr-defined]

        return model, legacy_metadata

    else:
        raise ValueError(f"Unknown model format: {type(data)}")