Skip to content

_base

Base functions for universal machine-learned interatomic potentials.

logger module-attribute

logger = getLogger(__name__)

pick_calculator cached

pick_calculator(
    method: Literal["mace-mp-0", "m3gnet", "chgnet"],
    **kwargs
) -> Calculator

Adapted from matcalc.util.get_universal_calculator.

.. deprecated:: 0.7.6 method mace will be removed in a later version, it is replaced by 'mace-mp-0' which more accurately reflects the nature of the model and allows for versioning in the future.

Parameters:

  • method (Literal['mace-mp-0', 'm3gnet', 'chgnet']) –

    Name of the calculator to use

  • **kwargs

    Custom kwargs for the underlying calculator. Set a value to quacc.Remove to remove a pre-existing key entirely. For a list of available keys, refer to the mace.calculators.mace_mp, chgnet.model.dynamics.CHGNetCalculator, or matgl.ext.ase.M3GNetCalculator calculators.

Returns:

Source code in quacc/recipes/mlp/_base.py
@lru_cache
def pick_calculator(
    method: Literal["mace-mp-0", "m3gnet", "chgnet"], **kwargs
) -> Calculator:
    """
    Adapted from `matcalc.util.get_universal_calculator`.

    .. deprecated:: 0.7.6
            method `mace` will be removed in a later version, it is replaced by 'mace-mp-0'
            which more accurately reflects the nature of the model and allows for versioning
            in the future.

    Parameters
    ----------
    method
        Name of the calculator to use
    **kwargs
        Custom kwargs for the underlying calculator. Set a value to
        `quacc.Remove` to remove a pre-existing key entirely. For a list of available
        keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
        or `matgl.ext.ase.M3GNetCalculator` calculators.

    Returns
    -------
    Calculator
        The chosen calculator
    """
    import torch

    if not torch.cuda.is_available():
        logger.warning("CUDA is not available to PyTorch. Calculations will be slow.")

    if method.lower() == "m3gnet":
        import matgl
        from matgl import __version__
        from matgl.ext.ase import PESCalculator

        model = matgl.load_model("M3GNet-MP-2021.2.8-DIRECT-PES")
        kwargs.setdefault("stress_weight", 1.0 / 160.21766208)
        calc = PESCalculator(potential=model, **kwargs)

    elif method.lower() == "chgnet":
        from chgnet import __version__
        from chgnet.model.dynamics import CHGNetCalculator

        calc = CHGNetCalculator(**kwargs)

    elif method.lower() in ["mace-mp-0", "mace"]:
        from mace import __version__
        from mace.calculators import mace_mp

        if method.lower() == "mace":
            warn(
                "'mace' is deprecated and support will be removed. Use 'mace-mp-0' instead!",
                DeprecationWarning,
                stacklevel=3,
            )

        if "default_dtype" not in kwargs:
            kwargs["default_dtype"] = "float64"
        calc = mace_mp(**kwargs)

    else:
        raise ValueError(f"Unrecognized {method=}.")

    calc.parameters["version"] = __version__

    return calc