Modules

Neural network modules and model components.

Core Modules

core

SRRep(key_out='e_rep', cutoff_fn='none', rc=5.2, reduce_sum=True)

Bases: Module

GFN1-style short-range repulsion function

Source code in aimnet/modules/core.py
205
206
207
208
209
210
211
212
213
214
215
216
217
def __init__(self, key_out="e_rep", cutoff_fn="none", rc=5.2, reduce_sum=True):
    """Set up the short-range repulsion term with tabulated GFN1 parameters.

    Stores the output key, cutoff-function name and reduction flag, registers
    the cutoff radius ``rc`` as a buffer, and loads two per-element repulsion
    coefficients into a frozen embedding table (87 rows: row 0 is padding,
    rows 1-86 cover the supported elements).
    """
    super().__init__()
    from aimnet.constants import get_gfn1_rep

    self.key_out = key_out
    self.cutoff_fn = cutoff_fn
    self.reduce_sum = reduce_sum
    self.register_buffer("rc", torch.tensor(rc))

    # Two GFN1 repulsion coefficients per element, stacked into an (87, 2)
    # table; kept as a non-trainable nn.Embedding so atomic numbers can be
    # used directly as lookup indices (index 0 = padding).
    coeff_a, coeff_b = get_gfn1_rep()
    table = torch.stack([coeff_a, coeff_b], dim=-1)
    self.params = nn.Embedding(87, 2, padding_idx=0, _weight=table)
    self.params.weight.requires_grad_(False)

MLP(n_in, n_out, hidden=None, activation_fn='torch.nn.GELU', activation_kwargs=None, weight_init_fn='torch.nn.init.xavier_normal_', bias=True, last_linear=True)

Convenience function to build MLP from config

Source code in aimnet/modules/core.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def MLP(
    n_in: int,
    n_out: int,
    hidden: list[int] | None = None,
    activation_fn: Callable | str = "torch.nn.GELU",
    activation_kwargs: dict[str, Any] | None = None,
    weight_init_fn: Callable | str = "torch.nn.init.xavier_normal_",
    bias: bool = True,
    last_linear: bool = True,
):
    """Convenience function to build MLP from config"""
    if hidden is None:
        hidden = []
    if activation_kwargs is None:
        activation_kwargs = {}
    # hp search hack
    hidden = [x for x in hidden if x > 0]
    if isinstance(activation_fn, str):
        activation_fn = get_init_module(activation_fn, kwargs=activation_kwargs)
    if isinstance(weight_init_fn, str):
        weight_init_fn = get_module(weight_init_fn)
    sizes = [n_in, *hidden, n_out]
    layers = []
    for i in range(1, len(sizes)):
        n_in, n_out = sizes[i - 1], sizes[i]
        layer = nn.Linear(n_in, n_out, bias=bias)
        with torch.no_grad():
            weight_init_fn(layer.weight)
            if bias:
                nn.init.zeros_(layer.bias)
        layers.append(layer)
        if not (last_linear and i == len(sizes) - 1):
            layers.append(activation_fn)
    return nn.Sequential(*layers)

AIMNet2 Model

aimnet2

Base Classes

base

AIMNet2Base()

Bases: Module

Base class for AIMNet2 models. Implements input pre-processing: converting data to the right dtype and device, setting the neighbor (nb) mode, and calculating masks.

Source code in aimnet/models/base.py
191
192
193
194
def __init__(self) -> None:
    """Initialize the base module with an empty metadata slot."""
    super().__init__()
    # Use object.__setattr__ to avoid TorchScript tracing this attribute.
    # Bypassing the normal attribute machinery keeps `_metadata` a plain
    # Python attribute; it is populated later (e.g. by load_model, which
    # assigns the ModelMetadata dict after loading weights).
    object.__setattr__(self, "_metadata", None)

metadata property

Return model metadata if available.

prepare_input(data)

Common operations for input preparation.

Source code in aimnet/models/base.py
210
211
212
213
214
215
216
217
218
219
def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
    """Common operations for input preparation.

    Runs the standard pre-processing pipeline (dtype conversion, neighbor
    mode selection, mask calculation) and validates that ``charge`` — and
    ``mult``, when present — are 1D tensors.
    """
    for transform in (self._prepare_dtype, nbops.set_nb_mode, nbops.calc_masks):
        data = transform(data)

    assert data["charge"].ndim == 1, "Charge should be 1D tensor."
    if "mult" in data:
        assert data["mult"].ndim == 1, "Mult should be 1D tensor."
    return data

ModelMetadata

Bases: TypedDict

Metadata returned by load_model().

This TypedDict documents the structure of the metadata dictionary.

load_model(path, device='cpu')

Load model from file, supporting both new and legacy formats.

Automatically detects the format. New format: a state dict with embedded YAML config and metadata. Legacy format: a JIT-compiled TorchScript model.

Parameters

path : str Path to the model file (.pt or .jpt). device : str Device to load the model on. Default is "cpu".

Returns

model : nn.Module — the loaded model with weights. metadata : ModelMetadata — dictionary containing model metadata; see the ModelMetadata TypedDict for fields.

Notes

For legacy JIT models (format_version=1), needs_coulomb and needs_dispersion are False because LR modules are already embedded in the TorchScript model.

Source code in aimnet/models/base.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def load_model(path: str, device: str = "cpu") -> tuple[nn.Module, ModelMetadata]:
    """Load model from file, supporting both new and legacy formats.

    Automatically detects format:
    - New format: state dict with embedded YAML config and metadata
    - Legacy format: JIT-compiled TorchScript model

    Parameters
    ----------
    path : str
        Path to the model file (.pt or .jpt).
    device : str
        Device to load the model on. Default is "cpu".

    Returns
    -------
    model : nn.Module
        The loaded model with weights.
    metadata : ModelMetadata
        Dictionary containing model metadata. See ModelMetadata TypedDict for fields.

    Raises
    ------
    ValueError
        If the loaded object is neither a new-format dict nor a TorchScript module.

    Notes
    -----
    For legacy JIT models (format_version=1), `needs_coulomb` and `needs_dispersion`
    are False because LR modules are already embedded in the TorchScript model.
    """
    # torch.load auto-detects TorchScript and dispatches to torch.jit.load
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".*looks like a TorchScript archive.*")
        data = torch.load(path, map_location=device, weights_only=False)

    # Dispatch on the loaded object's type to the format-specific loader.
    if isinstance(data, dict) and "model_yaml" in data:
        return _load_state_dict_model(data, device)
    elif isinstance(data, torch.jit.ScriptModule):
        return _load_legacy_jit_model(data)
    else:
        raise ValueError(f"Unknown model format: {type(data)}")


def _load_state_dict_model(data: dict, device: str) -> tuple[nn.Module, ModelMetadata]:
    """Build a model from the new state-dict format (embedded YAML config + weights)."""
    import yaml

    model_config = yaml.safe_load(data["model_yaml"])
    model = build_module(model_config)

    # Use strict=False because modules may differ between formats
    load_result = model.load_state_dict(data["state_dict"], strict=False)

    # Check for unexpected missing/extra keys
    real_missing, real_unexpected = validate_state_dict_keys(load_result.missing_keys, load_result.unexpected_keys)
    if real_missing or real_unexpected:
        msg_parts = []
        if real_missing:
            msg_parts.append(f"Missing keys: {real_missing}")
        if real_unexpected:
            msg_parts.append(f"Unexpected keys: {real_unexpected}")
        # stacklevel=3 so the warning points at load_model's caller
        # (helper -> load_model -> caller), matching the pre-refactor behavior.
        warnings.warn(f"State dict mismatch during model loading. {'; '.join(msg_parts)}", stacklevel=3)

    model = model.to(device)

    # Preserve float64 precision for atomic shifts (SAE values) after device transfer
    if hasattr(model, "outputs") and hasattr(model.outputs, "atomic_shift"):
        model.outputs.atomic_shift.shifts = model.outputs.atomic_shift.shifts.double()

    metadata: ModelMetadata = {
        "format_version": data.get("format_version", 2),  # Default 2 for early v2 files without version
        "cutoff": data["cutoff"],
        "needs_coulomb": data.get("needs_coulomb", False),
        "needs_dispersion": data.get("needs_dispersion", False),
        "coulomb_mode": data.get("coulomb_mode", "none"),
        "coulomb_sr_rc": data.get("coulomb_sr_rc"),
        "coulomb_sr_envelope": data.get("coulomb_sr_envelope"),
        "d3_params": data.get("d3_params"),
        "has_embedded_lr": data.get("has_embedded_lr", False),
        "implemented_species": data.get("implemented_species", []),
    }

    # Attach metadata to model for easy access
    model._metadata = metadata

    return model, metadata


def _load_legacy_jit_model(model: torch.jit.ScriptModule) -> tuple[nn.Module, ModelMetadata]:
    """Wrap a legacy TorchScript model (format v1) with reconstructed metadata."""
    # Legacy JIT format - LR modules are already embedded in the TorchScript model
    metadata: ModelMetadata = {
        "format_version": 1,  # Legacy .jpt format is v1
        "cutoff": float(model.cutoff),
        # Legacy models have LR modules embedded - don't add external ones
        "needs_coulomb": False,
        "needs_dispersion": False,
        "coulomb_mode": "full_embedded",
        # No coulomb_sr_rc/envelope for legacy (full Coulomb is embedded)
        "d3_params": extract_d3_params(model) if has_externalizable_dftd3(model) else None,
        "implemented_species": extract_species(model),
    }

    # Attempt metadata assignment; silently fails for JIT models
    with contextlib.suppress(AttributeError, RuntimeError):
        model._metadata = metadata  # type: ignore[attr-defined]

    return model, metadata