Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,25 @@
from typing import (
Any,
ClassVar,
Final,
)

import numpy as np
import torch
import torch.nn as nn

from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.dpmodel.utils.seed import (
child_seed,
)
from deepmd.pt.model.descriptor import (
DescriptorBlock,
prod_env_mat,
)
from deepmd.pt.model.network.mlp import (
EmbeddingNet,
NetworkCollection,
)
from deepmd.pt.utils import (
env,
)
Expand All @@ -29,9 +35,18 @@
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand All @@ -45,28 +60,6 @@
check_version_compatibility,
)

try:
from typing import (
Final,
)
except ImportError:
from torch.jit import Final

from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.pt.model.network.mlp import (
EmbeddingNet,
NetworkCollection,
)
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)

from .base_descriptor import (
BaseDescriptor,
)
Expand Down Expand Up @@ -776,7 +769,6 @@ def forward(

dmatrix = dmatrix.view(-1, self.nnei, 4)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
[nfnl, 4, self.filter_neuron[-1]],
dtype=self.prec,
Expand All @@ -790,9 +782,6 @@ def forward(
if self.type_one_side:
ii = embedding_idx
ti = -1
# torch.jit is not happy with slice(None)
# ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
# applying a mask seems to cause performance degradation
ti_mask = None
else:
# ti: center atom type, ii: neighbor type...
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,6 @@ def forward(
self.filter_neuron[-1],
self.is_sorted,
)[0]
# to make torchscript happy
gg = torch.empty(
nframes,
nloc,
Expand Down
7 changes: 2 additions & 5 deletions deepmd/pt/model/network/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,9 @@ def forward(
yy: torch.Tensor
The output.
"""
# mean = xx.mean(dim=-1, keepdim=True)
# variance = xx.var(dim=-1, unbiased=False, keepdim=True)
# The following operation is the same as above, but will not raise error when using jit model to inference.
# See https://github.com/pytorch/pytorch/issues/85792
if xx.numel() > 0:
variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
mean = xx.mean(dim=-1, keepdim=True)
variance = xx.var(dim=-1, unbiased=False, keepdim=True)
yy = (xx - mean) / torch.sqrt(variance + self.eps)
else:
yy = xx
Expand Down
19 changes: 13 additions & 6 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,20 @@ def nlist_distinguish_types(
snsel = tnlist.shape[2]
for ii, ss in enumerate(sel):
# nloc x s(nsel)
# to int because bool cannot be sort on GPU
pick_mask = (tnlist == ii).to(torch.int32)
# nloc x s(nsel), stable sort, nearer neighbors first
pick_mask, imap = torch.sort(pick_mask, dim=-1, descending=True, stable=True)
# nloc x s(nsel)
mask = tnlist == ii
order = (
snsel - torch.arange(snsel, device=mask.device, dtype=torch.int64)
).view(1, 1, -1)
key = torch.where(
mask,
order.to(torch.float32),
torch.full_like(order, -1, dtype=torch.float32),
)
topk_vals, imap = torch.topk(key, ss, dim=-1, largest=True)
# nloc x nsel[ii]
inlist = torch.gather(nlist, 2, imap)
inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1)
valid = topk_vals > 0
inlist = inlist.masked_fill(~valid, -1)
# nloc x nsel[ii]
ret_nlist.append(inlist[..., :ss])
return torch.concat(ret_nlist, dim=-1)
Expand Down
13 changes: 7 additions & 6 deletions source/tests/pt/model/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_consistency(
err_msg=err_msg,
)

def test_jit(
def test_export(
self,
) -> None:
rng = np.random.default_rng(GLOBAL_SEED)
Expand All @@ -132,8 +132,6 @@ def test_jit(
[False, True], # use_econf_tebd
):
dtype = PRECISION_DICT[prec]
rtol, atol = get_tols(prec)
err_msg = f"idt={idt} prec={prec}"
# dpa1 new impl
dd0 = DescrptDPA1(
self.rcut,
Expand All @@ -151,6 +149,9 @@ def test_jit(
)
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
# dd1 = DescrptDPA1.deserialize(dd0.serialize())
model = torch.jit.script(dd0)
# model = torch.jit.script(dd1)

coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE)
atype_ext = torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE)
nlist = torch.tensor(self.nlist, dtype=int, device=env.DEVICE)

_ = torch.export.export(dd0, (coord_ext, atype_ext, nlist))
Loading
Loading