add MD simulation in hpc

pull/9458/head
zhangxinfeng3 4 years ago
parent 6b9e402790
commit 9cfaba8983

@ -0,0 +1,110 @@
# Contents
- [Description](#description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
    - [Result](#result)
- [ModelZoo Homepage](#modelzoo-homepage)
## Description
Molecular dynamics (MD) plays an increasingly important role in research in biology, pharmacy, chemistry, and materials science. This model is based on DeePMD, which uses a neural-network (NN) scheme for MD simulations and overcomes the limitations associated with auxiliary quantities such as symmetry functions or the Coulomb matrix. Each environment contains a number of atoms whose local coordinates are arranged in a symmetry-preserving way, following the prescription of the Deep Potential method. The network constructs the energy, force, and virial from the atomic positions, atomic types, and box tensor.
Many thanks to the DeePMD team for their help.
[1] Paper: L Zhang, J Han, H Wang, R Car, W E. Deep potential molecular dynamics: a scalable model with the accuracy of quantum mechanics. Physical Review Letters 120 (14), 143001 (2018).
[2] Paper: H Wang, L Zhang, J Han, W E. DeePMD-kit: A deep learning package for many-body potential energy representation and molecular dynamics. Computer Physics Communications 228, 178-184 (2018).
## Model Architecture
The overall network architecture of the MD simulation is described in the paper linked below.
[Link](https://arxiv.org/abs/1707.09571)
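To make the descriptor stage of this architecture more concrete, here is a minimal NumPy sketch of the formulas that `src/descriptor.py` applies to each neighbour: a smooth switching weight that falls from 1 at `r_min` to 0 at `r_max`, multiplied into the row `(1/r, x/r^2, y/r^2, z/r^2)`. The function names `switch` and `descriptor_row` are illustrative and not part of this repository; the defaults `r_min=5.8` and `r_max=6.0` are copied from `ComputeDescriptor.construct`. Neighbour-list padding and the avg/std normalisation are left out.

```python
import numpy as np


def switch(d, r_min=5.8, r_max=6.0):
    """Smooth switching weight: 1 below r_min, 0 at or beyond r_max."""
    uu = (d - r_min) / (r_max - r_min)
    # Same polynomial as `vv` in src/descriptor.py.
    vv = uu ** 3 * (-6 * uu ** 2 + 15 * uu - 10) + 1
    return np.where(d < r_min, 1.0, np.where(d < r_max, vv, 0.0))


def descriptor_row(rij, r_min=5.8, r_max=6.0):
    """Map relative coordinates (..., 3) to descriptor rows (..., 4).

    Each row is sw(r) * (1/r, x/r^2, y/r^2, z/r^2). This is an
    illustrative helper: no padding mask and no normalisation.
    """
    d2 = np.sum(rij ** 2, axis=-1, keepdims=True)
    d = np.sqrt(d2)
    row = np.concatenate([1.0 / d, rij / d2], axis=-1)
    return row * switch(d, r_min, r_max)
```

For a neighbour at relative position (3, 0, 4), the distance is 5, which is below `r_min`, so the weight is 1 and the row is simply (0.2, 0.12, 0, 0.16).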
## Dataset
Dataset used: deepmodeling/deepmd-kit/examples/water/data
The data is generated by Quantum Espresso, and the Quantum Espresso input is set manually.
The directory structure of the data is as follows:
```text
└─data
    ├─type.raw
    ├─set.000
    │   ├──box.npy
    │   ├──coord.npy
    │   ├──energy.npy
    │   └──force.npy
    ├─set.001
    ├─set.002
    └─set.003
```
In `deepmodeling/deepmd-kit/source`:
- Use `train/DataSystem.py` to get `coord` and `atype`.
- Use the function `compute_input_stats` in `train/DataSystem.py` to get `avg` and `std`.
- Use `op/descrpt_se_a.cc` to get `nlist`.
- Save `coord`, `atype`, `avg`, `std` and `nlist` as an `.npz` file for inference.
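The last step can be sketched with `numpy.savez`. The key names below (`d_coord`, `d_nlist`, `avg`, `std`, `atype`) are the ones `eval.py` reads from the loaded file; the helper names `save_inference_inputs` and `check_inference_inputs` are hypothetical, and how the arrays are obtained from DeePMD-kit is not shown here.

```python
import numpy as np

# Keys that eval.py looks up in the loaded archive.
REQUIRED_KEYS = ("d_coord", "d_nlist", "avg", "std", "atype")


def save_inference_inputs(path, d_coord, d_nlist, avg, std, atype):
    """Bundle the five arrays into a single .npz archive."""
    np.savez(path, d_coord=d_coord, d_nlist=d_nlist,
             avg=avg, std=std, atype=atype)


def check_inference_inputs(path):
    """Return the sorted list of required keys missing from the archive."""
    with np.load(path) as archive:
        return sorted(set(REQUIRED_KEYS) - set(archive.files))
```

Running `check_inference_inputs` on the saved file before starting evaluation is a cheap way to catch a misspelled key early.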
## Environment Requirements
- Hardware (Ascend)
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## Script Description
### Script and Sample Code
```shell
├── md
    ├── README.md                 # descriptions about MD
    ├── script
    │   ├── eval.sh               # evaluation script
    ├── src
    │   ├── descriptor.py         # descriptor function
    │   ├── virial.py             # calculating virial function
    │   └── network.py            # MD simulation architecture
    └── eval.py                   # evaluation interface
```
### Training Process
To Be Done
### Evaluation Process
After installing MindSpore via the official website, you can start evaluation as follows:
```shell
python eval.py --dataset_path [DATASET_PATH] --checkpoint_path [CHECKPOINT_PATH]
```
> The checkpoint can be trained with DeePMD-kit and then converted into a MindSpore checkpoint (ckpt).
### Result
The inference results are as follows:
```text
atom_ener: -94.38766 -94.294426 -94.39194 -94.70758 -94.51311 -94.457954 ...
force: 1.64911175 -1.09822524 0.46055657 -1.34915102 -0.33827361 -0.97184098 ...
virial: -11.736662 -4.286214 2.8852937 -4.286209 -10.408775 -5.6738234 ...
```
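As read from `src/virial.py`, the virial contraction multiplies the network gradient into the descriptor derivative, masks padded neighbours (entries of `nlist` equal to -1), and sums the outer product with `rij` over atoms, neighbours and descriptor components, which leaves a 3 x 3 tensor per frame. The NumPy sketch below follows that reading; the name `virial_sketch` is hypothetical, and generic sizes replace the hard-coded 192 atoms and 138 neighbours.

```python
import numpy as np


def virial_sketch(net_deriv, descrpt_deriv, rij, nlist):
    """Contract gradients with relative coordinates into a virial tensor.

    net_deriv:     (F, N, M, 4)    network gradient w.r.t. the descriptor
    descrpt_deriv: (F, N, M, 4, 3) descriptor derivative w.r.t. coordinates
    rij:           (F, N, M, 3)    relative coordinates
    nlist:         (F, N, M)       neighbour indices, -1 marks padding
    returns:       (F, 3, 3)
    """
    mask = (nlist > -1).astype(net_deriv.dtype)
    weighted = descrpt_deriv * net_deriv[..., None] * mask[..., None, None]
    # Sum over atoms (n), neighbours (m) and descriptor components (k).
    return np.einsum("fnmka,fnmb->fab", weighted, rij)
```

Using `einsum` avoids materialising the six-dimensional broadcast tensors that the MindSpore cell builds explicitly, while computing the same sum.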
## ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval."""
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.network import Network
parser = argparse.ArgumentParser(description='MD Simulation')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
if __name__ == '__main__':
    # get input data
    r = np.load(args_opt.dataset_path)
    d_coord, d_nlist, avg, std, atype = r['d_coord'], r['d_nlist'], r['avg'], r['std'], r['atype']
    batch_size = 1
    atype_tensor = Tensor(atype)
    avg_tensor = Tensor(avg)
    std_tensor = Tensor(std)
    d_coord_tensor = Tensor(np.reshape(d_coord, (1, -1, 3)))
    d_nlist_tensor = Tensor(d_nlist)
    frames = []
    for i in range(batch_size):
        frames.append(i * 1536)
    frames = Tensor(frames)
    # evaluation
    net = Network()
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.to_float(mstype.float32)
    energy, atom_ener, virial = \
        net(d_coord_tensor, d_nlist_tensor, frames, avg_tensor, std_tensor, atype_tensor)
    print(energy)
    print(atom_ener)
    print(virial)

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# eval script
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s "${self_path}/../eval.py" --dataset_path="./$DATA_PATH" --checkpoint_path="./$CKPT_PATH" > log.txt 2>&1 &

@ -0,0 +1,207 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The construction of the descriptor."""
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
class ComputeRij(nn.Cell):
    """compute rij."""
    def __init__(self):
        super(ComputeRij, self).__init__()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.cast = P.Cast()
        self.rsum = P.ReduceSum()
        self.broadcastto = P.BroadcastTo((1, 192 * 138))
        self.broadcastto1 = P.BroadcastTo((1, 192, 138, 3))
        self.expdims = P.ExpandDims()
        self.concat = P.Concat(axis=1)
        self.gather = P.GatherV2()
        self.mul = P.Mul()
        self.slice = P.Slice()

    def construct(self, d_coord_tensor, nlist_tensor, frames):
        """construct function."""
        d_coord_tensor = self.cast(d_coord_tensor, mstype.float32)
        d_coord_tensor = self.reshape(d_coord_tensor, (1, -1, 3))
        coord_tensor = self.slice(d_coord_tensor, (0, 0, 0), (1, 192, 3))
        nlist_tensor = self.cast(nlist_tensor, mstype.int32)
        nlist_tensor = self.reshape(nlist_tensor, (1, 192, 138))
        b_nlist = nlist_tensor > -1
        b_nlist = self.cast(b_nlist, mstype.int32)
        nlist_tensor_r = b_nlist * nlist_tensor
        nlist_tensor_r = self.reshape(nlist_tensor_r, (-1,))
        frames = self.cast(frames, mstype.int32)
        frames = self.expdims(frames, 1)
        frames = self.broadcastto(frames)
        frames = self.reshape(frames, (-1,))
        nlist_tensor_r = nlist_tensor_r + frames
        nlist_tensor_r = self.reshape(nlist_tensor_r, (-1,))
        d_coord_tensor = self.reshape(d_coord_tensor, (-1, 3))
        selected_coord = self.gather(d_coord_tensor, nlist_tensor_r, 0)
        selected_coord = self.reshape(selected_coord, (1, 192, 138, 3))
        coord_tensor_expanded = self.expdims(coord_tensor, 2)
        coord_tensor_expanded = self.broadcastto1(coord_tensor_expanded)
        result_rij_m = selected_coord - coord_tensor_expanded
        b_nlist_expanded = self.expdims(b_nlist, 3)
        b_nlist_expanded = self.broadcastto1(b_nlist_expanded)
        result_rij = result_rij_m * b_nlist_expanded
        return result_rij
class ComputeDescriptor(nn.Cell):
    """compute descriptor."""
    def __init__(self):
        super(ComputeDescriptor, self).__init__()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.cast = P.Cast()
        self.rsum = P.ReduceSum()
        self.broadcastto = P.BroadcastTo((1, 192 * 138))
        self.broadcastto1 = P.BroadcastTo((1, 192, 138, 3))
        self.broadcastto2 = P.BroadcastTo((1, 192, 138, 3, 3))
        self.broadcastto3 = P.BroadcastTo((1, 192, 138, 4))
        self.broadcastto4 = P.BroadcastTo((1, 192, 138, 4, 3))
        self.expdims = P.ExpandDims()
        self.concat = P.Concat(axis=3)
        self.gather = P.GatherV2()
        self.mul = P.Mul()
        self.slice = P.Slice()
        self.square = P.Square()
        self.inv = P.Inv()
        self.sqrt = P.Sqrt()
        self.ones = P.OnesLike()
        self.eye = P.Eye()

    def construct(self, rij_tensor, avg_tensor, std_tensor, nlist_tensor, atype_tensor, r_min=5.8, r_max=6.0):
        """construct function."""
        nlist_tensor = self.reshape(nlist_tensor, (1, 192, 138))
        b_nlist = nlist_tensor > -1
        b_nlist = self.cast(b_nlist, mstype.int32)
        b_nlist_expanded = self.expdims(b_nlist, 3)
        b_nlist_4 = self.broadcastto3(b_nlist_expanded)
        b_nlist_3 = self.broadcastto1(b_nlist_expanded)
        b_nlist_expanded = self.expdims(b_nlist_expanded, 4)
        b_nlist_33 = self.broadcastto2(b_nlist_expanded)
        rij_tensor = rij_tensor + self.cast(1 - b_nlist_3, mstype.float32)
        r_2 = self.square(rij_tensor)
        d_2 = self.rsum(r_2, 3)
        invd_2 = self.inv(d_2)
        invd = self.sqrt(invd_2)
        invd_4 = self.square(invd_2)
        d = invd * d_2
        invd_3 = invd_4 * d
        b_d_1 = self.cast(d < r_max, mstype.int32)
        b_d_2 = self.cast(d < r_min, mstype.int32)
        b_d_3 = self.cast(d >= r_min, mstype.int32)
        du = 1.0 / (r_max - r_min)
        uu = (d - r_min) * du
        vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
        dd = (3 * uu * uu * (-6 * uu * uu + 15 * uu - 10) + uu * uu * uu * (-12 * uu + 15)) * du
        sw = vv * b_d_3 * b_d_1 + b_d_2
        dsw = dd * b_d_3 * b_d_1
        invd_2_e = self.expdims(invd_2, 3)
        invd_2_e = self.broadcastto1(invd_2_e)
        descrpt_1 = rij_tensor * invd_2_e
        factor0 = invd_3 * sw - invd_2 * dsw
        factor0 = self.expdims(factor0, 3)
        factor0 = self.broadcastto1(factor0)
        descrpt_deriv_0 = rij_tensor * factor0
        descrpt_deriv_0 = descrpt_deriv_0 * b_nlist_3
        descrpt_deriv_0 = self.expdims(descrpt_deriv_0, 3)
        factor1_0 = self.eye(3, 3, mstype.float32)
        factor1_0 = self.expdims(factor1_0, 0)
        factor1_0 = self.expdims(factor1_0, 0)
        factor1_0 = self.expdims(factor1_0, 0)
        factor1_1 = self.expdims(invd_2 * sw, 3)
        factor1_1 = self.expdims(factor1_1, 4)
        descrpt_deriv_1_0 = factor1_0 * factor1_1
        rij_tensor_e1 = self.expdims(rij_tensor, 4)
        rij_tensor_e2 = self.expdims(rij_tensor, 3)
        rij_tensor_e1 = self.broadcastto2(rij_tensor_e1)
        rij_tensor_e2 = self.broadcastto2(rij_tensor_e2)
        factor1_3 = self.expdims(2.0 * invd_4 * sw, 3)
        factor1_3 = self.expdims(factor1_3, 4)
        factor1_3 = self.broadcastto2(factor1_3)
        descrpt_deriv_1_1 = factor1_3 * rij_tensor_e1 * rij_tensor_e2
        factor1_4 = self.expdims(invd * dsw, 3)
        factor1_4 = self.expdims(factor1_4, 3)
        factor1_4 = self.broadcastto2(factor1_4)
        descrpt_1_e = self.expdims(descrpt_1, 4)
        descrpt_1_e = self.broadcastto2(descrpt_1_e)
        descrpt_deriv_1_2 = descrpt_1_e * rij_tensor_e2 * factor1_4
        descrpt_deriv_1 = (descrpt_deriv_1_1 - descrpt_deriv_1_0 - descrpt_deriv_1_2) * b_nlist_33
        descrpt_deriv = self.concat((descrpt_deriv_0, descrpt_deriv_1))
        invd_e = self.expdims(invd, 3)
        descrpt = self.concat((invd_e, descrpt_1))
        sw = self.broadcastto3(self.expdims(sw, 3))
        descrpt = descrpt * sw * b_nlist_4
        avg_tensor = self.cast(avg_tensor, mstype.float32)
        std_tensor = self.cast(std_tensor, mstype.float32)
        atype_tensor = self.reshape(atype_tensor, (-1,))
        atype_tensor = self.cast(atype_tensor, mstype.int32)
        avg_tensor = self.gather(avg_tensor, atype_tensor, 0)
        std_tensor = self.gather(std_tensor, atype_tensor, 0)
        avg_tensor = self.reshape(avg_tensor, (1, 192, 138, 4))
        std_tensor = self.reshape(std_tensor, (1, 192, 138, 4))
        std_tensor_2 = self.expdims(std_tensor, 4)
        std_tensor_2 = self.broadcastto4(std_tensor_2)
        descrpt = (descrpt - avg_tensor) / std_tensor
        descrpt_deriv = descrpt_deriv / std_tensor_2
        return descrpt, descrpt_deriv
class DescriptorSeA(nn.Cell):
    """chain ComputeRij and ComputeDescriptor."""
    def __init__(self):
        super(DescriptorSeA, self).__init__()
        self.compute_rij = ComputeRij()
        self.compute_descriptor = ComputeDescriptor()

    def construct(self, coord, nlist, frames, avg, std, atype):
        """construct function."""
        rij = self.compute_rij(coord, nlist, frames)
        descrpt, descrpt_deriv = self.compute_descriptor(rij, avg, std, nlist, atype)
        return rij, descrpt, descrpt_deriv

File diff suppressed because it is too large

@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Calculate virial of atoms."""
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore.ops import operations as P
class ProdVirialSeA(nn.Cell):
    """calculate virial."""
    def __init__(self):
        super(ProdVirialSeA, self).__init__()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.cast = P.Cast()
        self.rsum = P.ReduceSum()
        self.rksum = P.ReduceSum(keep_dims=True)
        self.broadcastto1 = P.BroadcastTo((1, 192, 138, 4, 3, 3))
        self.broadcastto2 = P.BroadcastTo((1, 192, 138, 4, 3))
        self.broadcastto3 = P.BroadcastTo((1, 192, 138, 3))
        self.expdims = P.ExpandDims()

    def construct(self, net_deriv_reshape, descrpt_deriv, rij, nlist):
        """construct function."""
        descrpt_deriv = self.cast(descrpt_deriv, mstype.float32)
        descrpt_deriv = self.reshape(descrpt_deriv, (1, 192, 138, 4, 3))
        net_deriv_reshape = self.cast(net_deriv_reshape, mstype.float32)
        net_deriv_reshape = self.reshape(net_deriv_reshape, (1, 192, 138, 4))
        net_deriv_reshape = self.expdims(net_deriv_reshape, 4)
        net_deriv_reshape = self.broadcastto2(net_deriv_reshape)
        rij = self.cast(rij, mstype.float32)
        rij = self.reshape(rij, (1, 192, 138, 3))
        rij = self.expdims(rij, 3)
        rij = self.expdims(rij, 4)
        rij = self.broadcastto1(rij)
        nlist = self.cast(nlist, mstype.int32)
        nlist = self.reshape(nlist, (1, 192, 138))
        nlist = self.expdims(nlist, 3)
        nlist = self.broadcastto3(nlist)
        tmp = descrpt_deriv * net_deriv_reshape
        b_blist = self.cast(nlist > -1, mstype.int32)
        b_blist = self.expdims(b_blist, 3)
        b_blist = self.broadcastto2(b_blist)
        tmp_1 = tmp * b_blist
        tmp_1 = self.expdims(tmp_1, 5)
        tmp_1 = self.broadcastto1(tmp_1)
        out = tmp_1 * rij
        out = self.rsum(out, (1, 2, 3))
        return out