!13032 Add modules of Sponge

From: @zhangxinfeng3
Reviewed-by: @ljl0711,@wang_zi_dong
Signed-off-by: @wang_zi_dong
pull/13032/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit 7583b258df

@@ -399,7 +399,8 @@ if(ENABLE_GPU)
${CUDNN_LIBRARY_PATH}
${CUDA_PATH}/lib64/libcudart.so
${CUDA_PATH}/lib64/stubs/libcuda.so
-${CUDA_PATH}/lib64/libcusolver.so)
+${CUDA_PATH}/lib64/libcusolver.so
+${CUDA_PATH}/lib64/libcufft.so)
if(ENABLE_MPI)
set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
endif()

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/getcenter_impl.cuh"
__global__ void GetCenterOfGeometryKernel(const int center_numbers, float center_numbers_inverse,
const int *center_atoms, const VECTOR *crd, VECTOR *center_of_geometry) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < center_numbers) {
int atom_i = center_atoms[i];
VECTOR temp = center_numbers_inverse * crd[atom_i];
atomicAdd(&center_of_geometry[0].x, temp.x);
atomicAdd(&center_of_geometry[0].y, temp.y);
atomicAdd(&center_of_geometry[0].z, temp.z);
}
}
void GetCenterOfGeometry(const int center_numbers, float center_numbers_inverse, const int *center_atoms,
const float *crd_f, float *center_of_geometry_f, cudaStream_t stream) {
VECTOR *crd = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(crd_f));
VECTOR *center_of_geometry = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(center_of_geometry_f));
GetCenterOfGeometryKernel<<<ceilf(static_cast<float>(center_numbers) / 32), 32, 0, stream>>>(
center_numbers, center_numbers_inverse, center_atoms, crd, center_of_geometry);
cudaStreamSynchronize(stream);
return;
}
void GetCenterOfGeometry(const int center_numbers, float center_numbers_inverse, const int *center_atoms,
const float *crd_f, float *center_of_geometry_f, cudaStream_t stream);
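
A minimal host-side driver for GetCenterOfGeometry, as a hedged sketch (not part of the commit; the buffer names, test data, and explicit zeroing step are assumptions). The kernel only accumulates with atomicAdd, so the three-float output buffer must be zeroed before the call:

#include <cuda_runtime.h>
#include <cstdio>
// Declaration from getcenter_impl.cuh.
void GetCenterOfGeometry(const int center_numbers, float center_numbers_inverse, const int *center_atoms,
const float *crd_f, float *center_of_geometry_f, cudaStream_t stream);
int main() {
const int n = 4;
int h_atoms[n] = {0, 1, 2, 3};                              // indices of the atoms to average over
float h_crd[3 * n] = {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1};  // packed xyz coordinates
float h_center[3] = {0, 0, 0};
int *d_atoms = NULL;
float *d_crd = NULL;
float *d_center = NULL;
cudaMalloc(&d_atoms, sizeof(h_atoms));
cudaMalloc(&d_crd, sizeof(h_crd));
cudaMalloc(&d_center, sizeof(h_center));
cudaMemcpy(d_atoms, h_atoms, sizeof(h_atoms), cudaMemcpyHostToDevice);
cudaMemcpy(d_crd, h_crd, sizeof(h_crd), cudaMemcpyHostToDevice);
cudaMemset(d_center, 0, sizeof(h_center));  // required: the kernel only adds into this buffer
cudaStream_t stream;
cudaStreamCreate(&stream);
GetCenterOfGeometry(n, 1.0f / n, d_atoms, d_crd, d_center, stream);  // synchronizes the stream internally
cudaMemcpy(h_center, d_center, sizeof(h_center), cudaMemcpyDeviceToHost);
printf("center = (%f, %f, %f)\n", h_center[0], h_center[1], h_center[2]);  // expect (0.25, 0.25, 0.25)
cudaFree(d_atoms);
cudaFree(d_crd);
cudaFree(d_center);
cudaStreamDestroy(stream);
return 0;
}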

@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_GETCENTER_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_GETCENTER_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void GetCenterOfGeometry(const int center_numbers, float center_numbers_inverse, const int *center_atoms,
const float *crd_f, float *center_of_geometry_f, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_GETCENTER_IMPL_H_

@@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/mdtemperature_impl.cuh"
__global__ void MDTemperatureKernel(const int residue_numbers, const int *start, const int *end, const VECTOR *atom_vel,
const float *atom_mass, float *ek) {
int residue_i = blockDim.x * blockIdx.x + threadIdx.x;
if (residue_i < residue_numbers) {
VECTOR momentum = {0., 0., 0.};
float res_mass = 0.;
int s = start[residue_i];
int e = end[residue_i];
float mass_lin;
for (int atom_i = s; atom_i < e; atom_i = atom_i + 1) {
mass_lin = atom_mass[atom_i];
momentum.x = momentum.x + mass_lin * atom_vel[atom_i].x;
momentum.y = momentum.y + mass_lin * atom_vel[atom_i].y;
momentum.z = momentum.z + mass_lin * atom_vel[atom_i].z;
res_mass = res_mass + mass_lin;
}
ek[residue_i] = 0.5 * (momentum.x * momentum.x + momentum.y * momentum.y + momentum.z * momentum.z) / res_mass *
2. / 3. / CONSTANT_kB / residue_numbers;
}
}
void MDTemperature(const int residue_numbers, const int *start, const int *end, const float *atom_vel_f,
const float *atom_mass, float *ek, cudaStream_t stream) {
VECTOR *atom_vel = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(atom_vel_f));
MDTemperatureKernel<<<ceilf(static_cast<float>(residue_numbers) / 32), 32, 0, stream>>>(residue_numbers, start, end,
atom_vel, atom_mass, ek);
cudaStreamSynchronize(stream);
return;
}
void MDTemperature(const int residue_numbers, const int *start, const int *end, const float *atom_vel_f,
const float *atom_mass, float *ek, cudaStream_t stream);
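
Each ek[residue_i] written above is the residue's center-of-mass kinetic energy |p|^2 / (2 * res_mass) converted to a temperature contribution through the equipartition factor 2 / (3 * kB), pre-divided by residue_numbers, so the plain sum of the array is already the translational temperature estimate (CONSTANT_kB = 0.00198716 is Boltzmann's constant in kcal/(mol*K)). A hedged reduction sketch for a caller, reusing Sum_Of_List and Cuda_Malloc_Safely from common_sponge.cuh (the d_temperature buffer is an assumption; Sum_Of_List is only correct with a single block):

float *d_temperature = NULL;
Cuda_Malloc_Safely(reinterpret_cast<void **>(&d_temperature), sizeof(float));
Sum_Of_List<<<1, 128, 0, stream>>>(residue_numbers, ek, d_temperature);
float h_temperature = 0.f;
cudaMemcpyAsync(&h_temperature, d_temperature, sizeof(float), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
cudaFree(d_temperature);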

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_MDTEMPERATURE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_MDTEMPERATURE_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void MDTemperature(const int residue_numbers, const int *start, const int *end, const float *atom_vel_f,
const float *atom_mass, float *ek, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_MDTEMPERATURE_IMPL_H_

@@ -14,31 +14,59 @@
* limitations under the License.
*/
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPONGE_COMMONHW_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPONGE_COMMONHW_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_SPONGE_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_SPONGE_H_
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <curand_kernel.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <cufft.h>
#include "runtime/device/gpu/cuda_common.h"
#define CONSTANT_Pi 3.1415926535897932
#define TWO_DIVIDED_BY_SQRT_PI 1.1283791670218446
#define CONSTANT_kB 0.00198716
static dim3 thread_LJ(8, 32);
struct VECTOR {
float x;
float y;
float z;
};
struct INT_VECTOR {
int int_x;
int int_y;
int int_z;
};
struct UNSIGNED_INT_VECTOR {
unsigned int uint_x;
unsigned int uint_y;
unsigned int uint_z;
};
struct NEIGHBOR_LIST {
int atom_numbers;
int *atom_serial;
};
struct UINT_VECTOR_LJ_TYPE {
unsigned int uint_x;
unsigned int uint_y;
unsigned int uint_z;
int LJ_type;
float charge;
};
struct GRID_BUCKET {
int *atom_serial;
};
struct GRID_POINTER {
int *grid_serial;
};
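// Coordinates are stored as 32-bit unsigned fixed-point fractions of the periodic box:
// the unsigned subtraction below wraps modulo 2^32 and the cast back to signed int
// recovers the signed minimum-image offset, so a single multiply by scaler (presumably
// box_length / 2^32 on the callers' side) yields the periodic displacement directly.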
__device__ __host__ static inline VECTOR Get_Periodic_Displacement(const UNSIGNED_INT_VECTOR uvec_a,
const UNSIGNED_INT_VECTOR uvec_b,
const VECTOR scaler) {
@@ -48,6 +76,15 @@ __device__ __host__ static inline VECTOR Get_Periodic_Displacement(const UNSIGNE
dr.z = (static_cast<int>(uvec_a.uint_z - uvec_b.uint_z)) * scaler.z;
return dr;
}
__device__ __host__ static inline VECTOR Get_Periodic_Displacement(const UINT_VECTOR_LJ_TYPE uvec_a,
const UINT_VECTOR_LJ_TYPE uvec_b,
const VECTOR scaler) {
VECTOR dr;
dr.x = (static_cast<int>(uvec_a.uint_x - uvec_b.uint_x)) * scaler.x;
dr.y = (static_cast<int>(uvec_a.uint_y - uvec_b.uint_y)) * scaler.y;
dr.z = (static_cast<int>(uvec_a.uint_z - uvec_b.uint_z)) * scaler.z;
return dr;
}
__device__ __host__ static inline VECTOR operator+(const VECTOR &veca, const VECTOR &vecb) {
VECTOR vec;
@@ -91,4 +128,124 @@ __device__ __host__ static inline VECTOR operator^(const VECTOR &veca, const VEC
return vec;
}
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPONGE_COMMON_H_
__global__ static void construct_neighbor_list_kernel(int atom_numbers, int max_neighbor_numbers, int *nl_atom_numbers,
int *nl_atom_serial, NEIGHBOR_LIST *nl) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < atom_numbers; i += gridDim.x * blockDim.x) {
nl[i].atom_numbers = nl_atom_numbers[i];
nl[i].atom_serial = nl_atom_serial + i * max_neighbor_numbers;
}
}
static inline bool Malloc_Safely(void **address, size_t size) {
address[0] = NULL;
address[0] = reinterpret_cast<void *>(malloc(size));
if (address[0] != NULL) {
return true;
} else {
printf("malloc failed!\n");
getchar();
return false;
}
}
static inline bool Cuda_Malloc_Safely(void **address, size_t size) {
cudaError_t cuda_error = cudaMalloc(&address[0], size);
if (cuda_error == 0) {
return true;
} else {
printf("cudaMalloc failed! error %d\n", cuda_error);
getchar();
return false;
}
}
__global__ static void Copy_Crd_To_New_Crd_Start(const int atom_numbers, const UNSIGNED_INT_VECTOR *crd,
UINT_VECTOR_LJ_TYPE *new_crd, const int *LJ_type,
const float *charge) {
int atom_i = blockDim.x * blockIdx.x + threadIdx.x;
if (atom_i < atom_numbers) {
new_crd[atom_i].uint_x = crd[atom_i].uint_x;
new_crd[atom_i].uint_y = crd[atom_i].uint_y;
new_crd[atom_i].uint_z = crd[atom_i].uint_z;
new_crd[atom_i].LJ_type = LJ_type[atom_i];
new_crd[atom_i].charge = charge[atom_i];
}
}
__global__ static void Rand_Normal(const int float4_numbers, curandStatePhilox4_32_10_t *rand_state,
float4 *rand_float4) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < float4_numbers) {
rand_float4[i] = curand_normal4(&rand_state[i]);
}
}
__global__ static void Setup_Rand_Normal_Kernel(const int float4_numbers, curandStatePhilox4_32_10_t *rand_state,
const int seed) {
int id = threadIdx.x + blockIdx.x * blockDim.x;
/* Each thread gets the same seed, a different sequence
number, and no offset */
if (id < float4_numbers) {
curand_init(seed, id, 0, &rand_state[id]);
}
}
__global__ static void Reset_List(const int element_numbers, int *list, const int replace_element) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
list[i] = replace_element;
}
}
__global__ static void Reset_List(const int element_numbers, float *list, const float replace_element) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
list[i] = replace_element;
}
}
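// Note: thread 0 of every block resets sum[0], so this reduction is only correct
// when launched with a single block.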
__global__ static void Sum_Of_List(const int element_numbers, const float *list, float *sum) {
if (threadIdx.x == 0) {
sum[0] = 0.;
}
__syncthreads();
float lin = 0.;
for (int i = threadIdx.x; i < element_numbers; i = i + blockDim.x) {
lin = lin + list[i];
}
atomicAdd(sum, lin);
}
__global__ static void Scale_List(const int element_numbers, float *list, float scaler) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
list[i] = list[i] * scaler;
}
}
__global__ static void Copy_List(const int element_numbers, const int *origin_list, int *list) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
list[i] = origin_list[i];
}
}
__global__ static void Copy_List(const int element_numbers, const float *origin_list, float *list) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
list[i] = origin_list[i];
}
}
__global__ static void Print(const size_t size, const float *input_x) {
for (size_t i = 0; i < size; i++) {
printf("%f\n", input_x[i]);
}
return;
}
__global__ static void Print(const size_t size, const int *input_x) {
for (size_t i = 0; i < size; i++) {
printf("%d\n", input_x[i]);
}
return;
}
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_SPONGE_H_

@@ -29,8 +29,6 @@ __global__ void DihedralAtomEnergyKernel(int dihedral_numbers, const UNSIGNED_IN
int atom_k = atom_c[dihedral_i];
int atom_l = atom_d[dihedral_i];
int temp_ipn = ipn[dihedral_i];
float temp_pk = pk[dihedral_i];
float temp_pn = pn[dihedral_i];
float temp_gamc = gamc[dihedral_i];

@@ -29,8 +29,6 @@ __global__ void DihedralEnergyKernel(int dihedral_numbers, const UNSIGNED_INT_VE
int atom_k = atom_c[dihedral_i];
int atom_l = atom_d[dihedral_i];
int temp_ipn = ipn[dihedral_i];
float temp_pk = pk[dihedral_i];
float temp_pn = pn[dihedral_i];
float temp_gamc = gamc[dihedral_i];

@@ -31,7 +31,6 @@ __global__ void DihedralForceKernel(int dihedral_numbers, const UNSIGNED_INT_VEC
int temp_ipn = ipn[dihedral_i];
float temp_pk = pk[dihedral_i];
float temp_pn = pn[dihedral_i];
float temp_gamc = gamc[dihedral_i];
float temp_gams = gams[dihedral_i];

@@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/lj/lj_energy_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void LJ_Energy_CUDA(const int atom_numbers, const NEIGHBOR_LIST *nl, const UINT_VECTOR_LJ_TYPE *uint_crd,
const VECTOR *boxlength, const float *LJ_type_A, const float *LJ_type_B,
const float cutoff_square, float *lj_ene) {
int atom_i = blockDim.x * blockIdx.x + threadIdx.x;
if (atom_i < atom_numbers) {
NEIGHBOR_LIST nl_i = nl[atom_i];
int N = nl_i.atom_numbers;
int atom_j;
int int_x;
int int_y;
int int_z;
UINT_VECTOR_LJ_TYPE r1 = uint_crd[atom_i], r2;
VECTOR dr;
float dr2;
float dr_2;
float dr_4;
float dr_6;
float ene_lin = 0.;
int x, y;
int atom_pair_LJ_type;
for (int j = threadIdx.y; j < N; j = j + blockDim.y) {
atom_j = nl_i.atom_serial[j];
r2 = uint_crd[atom_j];
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
dr2 = dr.x * dr.x + dr.y * dr.y + dr.z * dr.z;
if (dr2 < cutoff_square) {
dr_2 = 1. / dr2;
dr_4 = dr_2 * dr_2;
dr_6 = dr_4 * dr_2;
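// Branch-free pair lookup: x = y >> 31 is the sign mask and (y ^ x) - x is |y|,
// so the lines below compute max and min of the two LJ type indices and then the
// triangular index max * (max + 1) / 2 + min into the pair coefficient tables A/B.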
y = (r2.LJ_type - r1.LJ_type);
x = y >> 31;
y = (y ^ x) - x;
x = r2.LJ_type + r1.LJ_type;
r2.LJ_type = (x + y) >> 1;
x = (x - y) >> 1;
atom_pair_LJ_type = (r2.LJ_type * (r2.LJ_type + 1) >> 1) + x;
dr_2 = (0.083333333 * LJ_type_A[atom_pair_LJ_type] * dr_6 - 0.166666666 * LJ_type_B[atom_pair_LJ_type]) * dr_6;
ene_lin = ene_lin + dr_2;
}
}
atomicAdd(&lj_ene[atom_i], ene_lin);
}
}
void LJEnergy(const int atom_numbers, const float cutoff_square, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *scaler_f, float *uint_crd_with_LJ, int *nl_atom_numbers,
int *nl_atom_serial, int *nl, const float *d_LJ_A, const float *d_LJ_B, float *d_LJ_energy_atom,
cudaStream_t stream) {
VECTOR *scaler = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(scaler_f));
int max_neighbor_numbers = 800;
NEIGHBOR_LIST *nl_a = reinterpret_cast<NEIGHBOR_LIST *>(nl);
construct_neighbor_list_kernel<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, max_neighbor_numbers, nl_atom_numbers, nl_atom_serial, nl_a);
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ_a = reinterpret_cast<UINT_VECTOR_LJ_TYPE *>(uint_crd_with_LJ);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ_a, LJtype, charge);
Reset_List<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(atom_numbers, d_LJ_energy_atom, 0.);
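// thread_LJ = dim3(8, 32): each block covers 8 atoms along x, and the 32 y-threads
// of one atom stride through its neighbor list, accumulating per-thread partial
// energies into lj_ene with atomicAdd.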
LJ_Energy_CUDA<<<ceilf(static_cast<float>(atom_numbers) / 8), thread_LJ, 0, stream>>>(
atom_numbers, nl_a, uint_crd_with_LJ_a, scaler, d_LJ_A, d_LJ_B, cutoff_square, d_LJ_energy_atom);
return;
}
void LJEnergy(const int atom_numbers, const float cutoff_square, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *scaler_f, float *uint_crd_with_LJ, int *nl_atom_numbers,
int *nl_atom_serial, int *nl, const float *d_LJ_A, const float *d_LJ_B, float *d_LJ_energy_atom,
cudaStream_t stream);

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_ENERGY_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_ENERGY_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void LJEnergy(const int atom_numbers, const float cutoff_square, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *scaler_f, float *uint_crd_with_LJ, int *nl_atom_numbers,
int *nl_atom_serial, int *nl, const float *d_LJ_A, const float *d_LJ_B, float *d_LJ_energy_atom,
cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_ENERGY_IMPL_H_

@@ -0,0 +1,116 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/lj/lj_force_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void LJ_Force_CUDA(const int atom_numbers, const NEIGHBOR_LIST *nl, const UINT_VECTOR_LJ_TYPE *uint_crd,
const VECTOR *boxlength, const float *LJ_type_A, const float *LJ_type_B,
const float cutoff_square, VECTOR *frc) {
int atom_i = blockDim.x * blockIdx.x + threadIdx.x;
if (atom_i < atom_numbers) {
NEIGHBOR_LIST nl_i = nl[atom_i];
int N = nl_i.atom_numbers;
int B = ceilf(static_cast<float>(N) / blockDim.y);
int atom_j;
int int_x;
int int_y;
int int_z;
UINT_VECTOR_LJ_TYPE r1 = uint_crd[atom_i], r2;
VECTOR dr;
float dr2;
float dr_2;
float dr_4;
float dr_8;
float dr_14;
float frc_abs = 0.;
VECTOR frc_lin;
VECTOR frc_record = {0., 0., 0.};
int x, y;
int atom_pair_LJ_type;
for (int j = threadIdx.y * B; j < (threadIdx.y + 1) * B; j = j + 1) {
if (j < N) {
atom_j = nl_i.atom_serial[j];
r2 = uint_crd[atom_j];
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
dr2 = dr.x * dr.x + dr.y * dr.y + dr.z * dr.z;
if (dr2 < cutoff_square) {
dr_2 = 1. / dr2;
dr_4 = dr_2 * dr_2;
dr_8 = dr_4 * dr_4;
dr_14 = dr_8 * dr_4 * dr_2;
y = (r2.LJ_type - r1.LJ_type);
x = y >> 31;
y = (y ^ x) - x;
x = r2.LJ_type + r1.LJ_type;
r2.LJ_type = (x + y) >> 1;
x = (x - y) >> 1;
atom_pair_LJ_type = (r2.LJ_type * (r2.LJ_type + 1) >> 1) + x;
frc_abs = -LJ_type_A[atom_pair_LJ_type] * dr_14 + LJ_type_B[atom_pair_LJ_type] * dr_8;
frc_lin.x = frc_abs * dr.x;
frc_lin.y = frc_abs * dr.y;
frc_lin.z = frc_abs * dr.z;
frc_record.x = frc_record.x + frc_lin.x;
frc_record.y = frc_record.y + frc_lin.y;
frc_record.z = frc_record.z + frc_lin.z;
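// Newton's third law: the neighbor receives the opposite force via atomicAdd,
// while atom_i's share is accumulated locally in frc_record and flushed with
// one atomicAdd per component after the loop.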
atomicAdd(&frc[atom_j].x, -frc_lin.x);
atomicAdd(&frc[atom_j].y, -frc_lin.y);
atomicAdd(&frc[atom_j].z, -frc_lin.z);
}
}
}
atomicAdd(&frc[atom_i].x, frc_record.x);
atomicAdd(&frc[atom_i].y, frc_record.y);
atomicAdd(&frc[atom_i].z, frc_record.z);
}
}
void LJForce(const int atom_numbers, const float cutoff_square, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *scaler_f, float *uint_crd_with_LJ, int *nl_atom_numbers,
int *nl_atom_serial, int *nl, const float *d_LJ_A, const float *d_LJ_B, float *frc_f,
cudaStream_t stream) {
VECTOR *frc = reinterpret_cast<VECTOR *>(frc_f);
VECTOR *scaler = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(scaler_f));
int max_neighbor_numbers = 800;
NEIGHBOR_LIST *nl_a = reinterpret_cast<NEIGHBOR_LIST *>(nl);
construct_neighbor_list_kernel<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, max_neighbor_numbers, nl_atom_numbers, nl_atom_serial, nl_a);
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ_a = reinterpret_cast<UINT_VECTOR_LJ_TYPE *>(uint_crd_with_LJ);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ_a, LJtype, charge);
LJ_Force_CUDA<<<ceilf(static_cast<float>(atom_numbers) / 8), thread_LJ, 0, stream>>>(
atom_numbers, nl_a, uint_crd_with_LJ_a, scaler, d_LJ_A, d_LJ_B, cutoff_square, frc);
return;
}
void LJForce(const int atom_numbers, const float cutoff_square, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *scaler_f, float *uint_crd_with_LJ, int *nl_atom_numbers,
int *nl_atom_serial, int *nl, const float *d_LJ_A, const float *d_LJ_B, float *frc_f, cudaStream_t stream);

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_FORCE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_FORCE_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void LJForce(const int atom_numbers, const float cutoff_square, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *scaler_f, float *uint_crd_with_LJ, int *nl_atom_numbers,
int *nl_atom_serial, int *nl, const float *d_LJ_A, const float *d_LJ_B, float *frc_f, cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_FORCE_IMPL_H_

@@ -0,0 +1,132 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/lj/lj_force_with_pme_direct_force_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void LJ_Force_With_Direct_CF_CUDA(const int atom_numbers, const NEIGHBOR_LIST *nl,
const UINT_VECTOR_LJ_TYPE *uint_crd, const VECTOR *boxlength,
const float *LJ_type_A, const float *LJ_type_B, const float cutoff,
VECTOR *frc, const float pme_beta, const float sqrt_pi) {
int atom_i = blockDim.x * blockIdx.x + threadIdx.x;
if (atom_i < atom_numbers) {
NEIGHBOR_LIST nl_i = nl[atom_i];
int N = nl_i.atom_numbers;
int atom_j;
int int_x;
int int_y;
int int_z;
UINT_VECTOR_LJ_TYPE r1 = uint_crd[atom_i], r2;
VECTOR dr;
float dr_2;
float dr_4;
float dr_8;
float dr_6;
float frc_abs = 0.;
VECTOR frc_lin;
VECTOR frc_record = {0., 0., 0.};
float charge_i = r1.charge;
float charge_j;
float dr_abs;
float dr_1;
float beta_dr;
float frc_cf_abs;
int x, y;
int atom_pair_LJ_type;
for (int j = threadIdx.y; j < N; j = j + blockDim.y) {
atom_j = nl_i.atom_serial[j];
r2 = uint_crd[atom_j];
charge_j = r2.charge;
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
dr_abs = norm3df(dr.x, dr.y, dr.z);
if (dr_abs < cutoff) {
dr_1 = 1. / dr_abs;
dr_2 = dr_1 * dr_1;
dr_4 = dr_2 * dr_2;
dr_8 = dr_4 * dr_4;
dr_6 = dr_4 * dr_2;
y = (r2.LJ_type - r1.LJ_type);
x = y >> 31;
y = (y ^ x) - x;
x = r2.LJ_type + r1.LJ_type;
r2.LJ_type = (x + y) >> 1;
x = (x - y) >> 1;
atom_pair_LJ_type = (r2.LJ_type * (r2.LJ_type + 1) >> 1) + x;
frc_abs = (-LJ_type_A[atom_pair_LJ_type] * dr_6 + LJ_type_B[atom_pair_LJ_type]) * dr_8;
beta_dr = pme_beta * dr_abs;
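// Direct-space PME Coulomb term: -d/dr[erfc(beta * r) / r] expands to
// (erfc(beta * r) + (2 / sqrt(pi)) * beta * r * exp(-beta^2 * r^2)) / r^2;
// the extra dr_1 below folds in the 1/r that projects the magnitude onto dr.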
frc_cf_abs = beta_dr * sqrt_pi * expf(-beta_dr * beta_dr) + erfcf(beta_dr);
frc_cf_abs = frc_cf_abs * dr_2 * dr_1;
frc_cf_abs = charge_i * charge_j * frc_cf_abs;
frc_abs = frc_abs - frc_cf_abs;
frc_lin.x = frc_abs * dr.x;
frc_lin.y = frc_abs * dr.y;
frc_lin.z = frc_abs * dr.z;
frc_record.x = frc_record.x + frc_lin.x;
frc_record.y = frc_record.y + frc_lin.y;
frc_record.z = frc_record.z + frc_lin.z;
atomicAdd(&frc[atom_j].x, -frc_lin.x);
atomicAdd(&frc[atom_j].y, -frc_lin.y);
atomicAdd(&frc[atom_j].z, -frc_lin.z);
}
}
atomicAdd(&frc[atom_i].x, frc_record.x);
atomicAdd(&frc[atom_i].y, frc_record.y);
atomicAdd(&frc[atom_i].z, frc_record.z);
}
}
void LJForceWithPMEDirectForce(const int atom_numbers, const float cutoff, const float pme_beta, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *scaler_f, float *uint_crd_with_LJ,
int *nl_atom_numbers, int *nl_atom_serial, int *nl, const float *d_LJ_A,
const float *d_LJ_B, float *frc_f, cudaStream_t stream) {
VECTOR *frc = reinterpret_cast<VECTOR *>(frc_f);
VECTOR *scaler = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(scaler_f));
int max_neighbor_numbers = 800;
NEIGHBOR_LIST *nl_a = reinterpret_cast<NEIGHBOR_LIST *>(nl);
construct_neighbor_list_kernel<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, max_neighbor_numbers, nl_atom_numbers, nl_atom_serial, nl_a);
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ_a = reinterpret_cast<UINT_VECTOR_LJ_TYPE *>(uint_crd_with_LJ);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ_a, LJtype, charge);
LJ_Force_With_Direct_CF_CUDA<<<ceilf(static_cast<float>(atom_numbers) / 8), thread_LJ, 0, stream>>>(
atom_numbers, nl_a, uint_crd_with_LJ_a, scaler, d_LJ_A, d_LJ_B, cutoff, frc, pme_beta, TWO_DIVIDED_BY_SQRT_PI);
return;
}
void LJForceWithPMEDirectForce(const int atom_numbers, const float cutoff, const float pme_beta, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *scaler_f, float *uint_crd_with_LJ,
int *nl_atom_numbers, int *nl_atom_serial, int *nl, const float *d_LJ_A,
const float *d_LJ_B, float *frc_f, cudaStream_t stream);

@@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_FORCE_WITH_PME_DIRECT_FORCE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_FORCE_WITH_PME_DIRECT_FORCE_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void LJForceWithPMEDirectForce(const int atom_numbers, const float cutoff, const float pme_beta, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *scaler_f, float *uint_crd_with_LJ,
int *nl_atom_numbers, int *nl_atom_serial, int *nl, const float *d_LJ_A,
const float *d_LJ_B, float *frc_f, cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_LJ_LJ_FORCE_WITH_PME_DIRECT_FORCE_IMPL_H_

@@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nb14/dihedral_14_cf_atom_energy_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void Dihedral14CFAtomEnergyKernel(const int dihedral_14_numbers, const UINT_VECTOR_LJ_TYPE *uint_crd,
const VECTOR *boxlength, const int *a_14, const int *b_14,
const float *cf_scale_factor, float *ene) {
int dihedral_14_i = blockDim.x * blockIdx.x + threadIdx.x;
if (dihedral_14_i < dihedral_14_numbers) {
int atom_i = a_14[dihedral_14_i];
int atom_j = b_14[dihedral_14_i];
UINT_VECTOR_LJ_TYPE r1 = uint_crd[atom_i];
UINT_VECTOR_LJ_TYPE r2 = uint_crd[atom_j];
int int_x;
int int_y;
int int_z;
VECTOR dr;
float r_1;
float ene_lin = 0.;
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
r_1 = rnorm3df(dr.x, dr.y, dr.z);
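// 1-4 Coulomb energy q_i * q_j / r, scaled below by the per-pair 1-4 factor;
// the charges are assumed to already absorb Coulomb's constant in internal units.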
ene_lin = r1.charge * r2.charge * r_1;
ene_lin *= cf_scale_factor[dihedral_14_i];
atomicAdd(&ene[atom_i], ene_lin);
}
}
void Dihedral14CFAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f, const int *a_14,
const int *b_14, const float *cf_scale_factor, float *ene, cudaStream_t stream) {
size_t thread_per_block = 128;
size_t block_per_grid = ceilf(static_cast<float>(dihedral_14_numbers) / 128);  // cover all 1-4 pairs, which may outnumber atoms
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ = NULL;
Cuda_Malloc_Safely(reinterpret_cast<void **>(&uint_crd_with_LJ), sizeof(UINT_VECTOR_LJ_TYPE) * atom_numbers);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ, LJtype, charge);
VECTOR *boxlength = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(boxlength_f));
Reset_List<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, ene, 0.);
Dihedral14CFAtomEnergyKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
dihedral_14_numbers, uint_crd_with_LJ, boxlength, a_14, b_14, cf_scale_factor, ene);
cudaStreamSynchronize(stream);
cudaFree(uint_crd_with_LJ);  // release the temporary packed-coordinate buffer allocated above
return;
}
void Dihedral14CFAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f, const int *a_14,
const int *b_14, const float *cf_scale_factor, float *ene, cudaStream_t stream);

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_CF_ATOM_ENERGY_IMPL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_CF_ATOM_ENERGY_IMPL_H
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void Dihedral14CFAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f, const int *a_14,
const int *b_14, const float *cf_scale_factor, float *ene, cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_CF_ATOM_ENERGY_IMPL_H

@@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nb14/dihedral_14_cf_energy_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void Dihedral14CFEnergyKernel(const int dihedral_14_numbers, const UINT_VECTOR_LJ_TYPE *uint_crd,
const VECTOR *boxlength, const int *a_14, const int *b_14,
const float *cf_scale_factor, float *ene) {
int dihedral_14_i = blockDim.x * blockIdx.x + threadIdx.x;
if (dihedral_14_i < dihedral_14_numbers) {
int atom_i = a_14[dihedral_14_i];
int atom_j = b_14[dihedral_14_i];
UINT_VECTOR_LJ_TYPE r1 = uint_crd[atom_i];
UINT_VECTOR_LJ_TYPE r2 = uint_crd[atom_j];
int int_x;
int int_y;
int int_z;
VECTOR dr;
float r_1;
float ene_lin = 0.;
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
r_1 = rnorm3df(dr.x, dr.y, dr.z);
ene_lin = r1.charge * r2.charge * r_1;
ene_lin *= cf_scale_factor[dihedral_14_i];
ene[dihedral_14_i] = ene_lin;
}
}
void Dihedral14CFEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *boxlength_f, const int *a_14, const int *b_14,
const float *cf_scale_factor, float *ene, cudaStream_t stream) {
size_t thread_per_block = 128;
size_t block_per_grid = ceilf(static_cast<float>(dihedral_14_numbers) / 128);  // cover all 1-4 pairs, which may outnumber atoms
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ = NULL;
Cuda_Malloc_Safely(reinterpret_cast<void **>(&uint_crd_with_LJ), sizeof(UINT_VECTOR_LJ_TYPE) * atom_numbers);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ, LJtype, charge);
VECTOR *boxlength = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(boxlength_f));
Reset_List<<<ceilf(static_cast<float>(dihedral_14_numbers) / 128), 128, 0, stream>>>(dihedral_14_numbers, ene, 0.);
Dihedral14CFEnergyKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
dihedral_14_numbers, uint_crd_with_LJ, boxlength, a_14, b_14, cf_scale_factor, ene);
cudaStreamSynchronize(stream);
cudaFree(uint_crd_with_LJ);  // release the temporary packed-coordinate buffer allocated above
return;
}
void Dihedral14CFEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *boxlength_f, const int *a_14, const int *b_14,
const float *cf_scale_factor, float *ene, cudaStream_t stream);

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_CF_ENERGY_IMPL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_CF_ENERGY_IMPL_H
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void Dihedral14CFEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *boxlength_f, const int *a_14, const int *b_14,
const float *cf_scale_factor, float *ene, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_CF_ENERGY_IMPL_H

@@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nb14/dihedral_14_lj_atom_energy_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void Dihedral14LJAtomEnergyKernel(const int dihedral_14_numbers, const UINT_VECTOR_LJ_TYPE *uint_crd,
const VECTOR *boxlength, const int *a_14, const int *b_14,
const float *lj_scale_factor, const float *LJ_type_A,
const float *LJ_type_B, float *ene) {
int dihedral_14_i = blockDim.x * blockIdx.x + threadIdx.x;
if (dihedral_14_i < dihedral_14_numbers) {
int atom_i = a_14[dihedral_14_i];
int atom_j = b_14[dihedral_14_i];
UINT_VECTOR_LJ_TYPE r1 = uint_crd[atom_i];
UINT_VECTOR_LJ_TYPE r2 = uint_crd[atom_j];
int int_x;
int int_y;
int int_z;
VECTOR dr;
float dr2;
float dr_2;
float dr_4;
float dr_6;
float dr_12;
float ene_lin = 0.;
int x, y;
int atom_pair_LJ_type;
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
dr2 = dr.x * dr.x + dr.y * dr.y + dr.z * dr.z;
dr_2 = 1. / dr2;
dr_4 = dr_2 * dr_2;
dr_6 = dr_4 * dr_2;
dr_12 = dr_6 * dr_6;
y = (r2.LJ_type - r1.LJ_type);
x = y >> 31;
y = (y ^ x) - x;
x = r2.LJ_type + r1.LJ_type;
r2.LJ_type = (x + y) >> 1;
x = (x - y) >> 1;
atom_pair_LJ_type = (r2.LJ_type * (r2.LJ_type + 1) >> 1) + x;
ene_lin = 0.08333333 * LJ_type_A[atom_pair_LJ_type] * dr_12 -
0.1666666 * LJ_type_B[atom_pair_LJ_type] * dr_6;  // the LJ A/B coefficients are pre-scaled by 12 and 6, so multiply by 1/12 and 1/6 here
ene_lin *= lj_scale_factor[dihedral_14_i];
atomicAdd(&ene[atom_i], ene_lin);
}
}
void Dihedral14LJAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f, const int *a_14,
const int *b_14, const float *lj_scale_factor, const float *LJ_type_A,
const float *LJ_type_B, float *ene, cudaStream_t stream) {
size_t thread_per_block = 128;
size_t block_per_grid = ceilf(static_cast<float>(dihedral_14_numbers) / 128);  // cover all 1-4 pairs, which may outnumber atoms
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ = NULL;
Cuda_Malloc_Safely(reinterpret_cast<void **>(&uint_crd_with_LJ), sizeof(UINT_VECTOR_LJ_TYPE) * atom_numbers);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ, LJtype, charge);
VECTOR *boxlength = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(boxlength_f));
Reset_List<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, ene, 0.);
Dihedral14LJAtomEnergyKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
dihedral_14_numbers, uint_crd_with_LJ, boxlength, a_14, b_14, lj_scale_factor, LJ_type_A, LJ_type_B, ene);
cudaStreamSynchronize(stream);
cudaFree(uint_crd_with_LJ);  // release the temporary packed-coordinate buffer allocated above
return;
}
void Dihedral14LJAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f, const int *a_14,
const int *b_14, const float *lj_scale_factor, const float *LJ_type_A,
const float *LJ_type_B, float *ene, cudaStream_t stream);

@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_ATOM_ENERGY_IMPL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_ATOM_ENERGY_IMPL_H
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void Dihedral14LJAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f, const int *a_14,
const int *b_14, const float *lj_scale_factor, const float *LJ_type_A,
const float *LJ_type_B, float *ene, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_ATOM_ENERGY_IMPL_H

@@ -0,0 +1,140 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nb14/dihedral_14_lj_cf_force_with_atom_energy_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void Dihedral14LJCFForceWithAtomEnergyKernel(const int dihedral_14_numbers,
const UINT_VECTOR_LJ_TYPE *uint_crd, const VECTOR *boxlength,
const int *a_14, const int *b_14, const float *lj_scale_factor,
const float *cf_scale_factor, const float *LJ_type_A,
const float *LJ_type_B, VECTOR *frc, float *atom_energy) {
int dihedral_14_i = blockDim.x * blockIdx.x + threadIdx.x;
if (dihedral_14_i < dihedral_14_numbers) {
int int_x;
int int_y;
int int_z;
UINT_VECTOR_LJ_TYPE r1, r2;
VECTOR dr;
float dr_abs;
float dr2;
float dr_1;
float dr_2;
float dr_4;
float dr_8;
float dr_14;
float frc_abs = 0.;
VECTOR temp_frc;
float ene_lin;
float ene_lin2;
int x, y;
int atom_pair_LJ_type;
int atom_i = a_14[dihedral_14_i];
int atom_j = b_14[dihedral_14_i];
r1 = uint_crd[atom_i];
r2 = uint_crd[atom_j];
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
dr2 = dr.x * dr.x + dr.y * dr.y + dr.z * dr.z;
dr_2 = 1.0 / dr2;
dr_4 = dr_2 * dr_2;
dr_8 = dr_4 * dr_4;
dr_14 = dr_8 * dr_4 * dr_2;
dr_abs = norm3df(dr.x, dr.y, dr.z);
dr_1 = 1. / dr_abs;
float charge_i = r1.charge;
float charge_j = r2.charge;
float frc_cf_abs;
frc_cf_abs = cf_scale_factor[dihedral_14_i] * dr_2 * dr_1;
frc_cf_abs = -charge_i * charge_j * frc_cf_abs;
y = (r2.LJ_type - r1.LJ_type);
x = y >> 31;
y = (y ^ x) - x;
x = r2.LJ_type + r1.LJ_type;
r2.LJ_type = (x + y) >> 1;
x = (x - y) >> 1;
atom_pair_LJ_type = (r2.LJ_type * (r2.LJ_type + 1) >> 1) + x;
frc_abs = -LJ_type_A[atom_pair_LJ_type] * dr_14 + LJ_type_B[atom_pair_LJ_type] * dr_8;
frc_abs *= lj_scale_factor[dihedral_14_i];
frc_abs += frc_cf_abs;
temp_frc.x = frc_abs * dr.x;
temp_frc.y = frc_abs * dr.y;
temp_frc.z = frc_abs * dr.z;
atomicAdd(&frc[atom_j].x, -temp_frc.x);
atomicAdd(&frc[atom_j].y, -temp_frc.y);
atomicAdd(&frc[atom_j].z, -temp_frc.z);
atomicAdd(&frc[atom_i].x, temp_frc.x);
atomicAdd(&frc[atom_i].y, temp_frc.y);
atomicAdd(&frc[atom_i].z, temp_frc.z);
ene_lin = r1.charge * r2.charge * dr_1;
ene_lin *= cf_scale_factor[dihedral_14_i];
ene_lin2 = 0.08333333 * LJ_type_A[atom_pair_LJ_type] * dr_4 * dr_8 -
0.1666666 * LJ_type_B[atom_pair_LJ_type] * dr_4 * dr_2;  // the LJ A/B coefficients are pre-scaled by 12 and 6, so multiply by 1/12 and 1/6 here
ene_lin2 *= lj_scale_factor[dihedral_14_i];
atomicAdd(&atom_energy[atom_i], ene_lin + ene_lin2);
}
}
void Dihedral14LJCFForceWithAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f,
const int *a_14, const int *b_14, const float *lj_scale_factor,
const float *cf_scale_factor, const float *LJ_type_A, const float *LJ_type_B,
float *frc_f, float *atom_energy, cudaStream_t stream) {
size_t thread_per_block = 128;
size_t block_per_grid = ceilf(static_cast<float>(dihedral_14_numbers) / 128);  // cover all 1-4 pairs, which may outnumber atoms
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ = NULL;
Cuda_Malloc_Safely(reinterpret_cast<void **>(&uint_crd_with_LJ), sizeof(UINT_VECTOR_LJ_TYPE) * atom_numbers);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ, LJtype, charge);
Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(3 * atom_numbers, frc_f, 0.);
Reset_List<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, atom_energy, 0.);
VECTOR *boxlength = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(boxlength_f));
VECTOR *frc = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(frc_f));
Dihedral14LJCFForceWithAtomEnergyKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
dihedral_14_numbers, uint_crd_with_LJ, boxlength, a_14, b_14, lj_scale_factor, cf_scale_factor, LJ_type_A,
LJ_type_B, frc, atom_energy);
cudaStreamSynchronize(stream);
cudaFree(uint_crd_with_LJ);  // release the temporary packed-coordinate buffer allocated above
return;
}
void Dihedral14LJCFForceWithAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f,
const int *a_14, const int *b_14, const float *lj_scale_factor,
const float *cf_scale_factor, const float *LJ_type_A, const float *LJ_type_B,
float *frc_f, float *atom_energy, cudaStream_t stream);

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_CF_FORCE_WITH_ATOM_ENERGY_IMPL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_CF_FORCE_WITH_ATOM_ENERGY_IMPL_H
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void Dihedral14LJCFForceWithAtomEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f,
const int *LJtype, const float *charge, const float *boxlength_f,
const int *a_14, const int *b_14, const float *lj_scale_factor,
const float *cf_scale_factor, const float *LJ_type_A, const float *LJ_type_B,
float *frc_f, float *atom_energy, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_CF_FORCE_WITH_ATOM_ENERGY_IMPL_H

@@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nb14/dihedral_14_lj_energy_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
__global__ void Dihedral14LJEnergyKernel(const int dihedral_14_numbers, const UINT_VECTOR_LJ_TYPE *uint_crd,
const VECTOR *boxlength, const int *a_14, const int *b_14,
const float *lj_scale_factor, const float *LJ_type_A, const float *LJ_type_B,
float *ene) {
int dihedral_14_i = blockDim.x * blockIdx.x + threadIdx.x;
if (dihedral_14_i < dihedral_14_numbers) {
int atom_i = a_14[dihedral_14_i];
int atom_j = b_14[dihedral_14_i];
UINT_VECTOR_LJ_TYPE r1 = uint_crd[atom_i];
UINT_VECTOR_LJ_TYPE r2 = uint_crd[atom_j];
int int_x;
int int_y;
int int_z;
VECTOR dr;
float dr2;
float dr_2;
float dr_4;
float dr_6;
float dr_12;
float ene_lin = 0.;
int x, y;
int atom_pair_LJ_type;
int_x = r2.uint_x - r1.uint_x;
int_y = r2.uint_y - r1.uint_y;
int_z = r2.uint_z - r1.uint_z;
dr.x = boxlength[0].x * int_x;
dr.y = boxlength[0].y * int_y;
dr.z = boxlength[0].z * int_z;
dr2 = dr.x * dr.x + dr.y * dr.y + dr.z * dr.z;
dr_2 = 1. / dr2;
dr_4 = dr_2 * dr_2;
dr_6 = dr_4 * dr_2;
dr_12 = dr_6 * dr_6;
y = (r2.LJ_type - r1.LJ_type);
x = y >> 31;
y = (y ^ x) - x;
x = r2.LJ_type + r1.LJ_type;
r2.LJ_type = (x + y) >> 1;
x = (x - y) >> 1;
atom_pair_LJ_type = (r2.LJ_type * (r2.LJ_type + 1) >> 1) + x;
ene_lin = 0.08333333 * LJ_type_A[atom_pair_LJ_type] * dr_12 -
0.1666666 * LJ_type_B[atom_pair_LJ_type] * dr_6;  // the LJ A/B coefficients are pre-scaled by 12 and 6, so multiply by 1/12 and 1/6 here
ene_lin *= lj_scale_factor[dihedral_14_i];
ene[dihedral_14_i] = ene_lin;
}
}
void Dihedral14LJEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *boxlength_f, const int *a_14, const int *b_14,
const float *lj_scale_factor, const float *LJ_type_A, const float *LJ_type_B, float *ene,
cudaStream_t stream) {
size_t thread_per_block = 128;
size_t block_per_grid = ceilf(static_cast<float>(dihedral_14_numbers) / 128);  // cover all 1-4 pairs, which may outnumber atoms
UINT_VECTOR_LJ_TYPE *uint_crd_with_LJ = NULL;
Cuda_Malloc_Safely(reinterpret_cast<void **>(&uint_crd_with_LJ), sizeof(UINT_VECTOR_LJ_TYPE) * atom_numbers);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
Copy_Crd_To_New_Crd_Start<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, uint_crd, uint_crd_with_LJ, LJtype, charge);
Reset_List<<<ceilf(static_cast<float>(dihedral_14_numbers) / 128), 128, 0, stream>>>(dihedral_14_numbers, ene, 0.);
VECTOR *boxlength = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(boxlength_f));
Dihedral14LJEnergyKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
dihedral_14_numbers, uint_crd_with_LJ, boxlength, a_14, b_14, lj_scale_factor, LJ_type_A, LJ_type_B, ene);
cudaStreamSynchronize(stream);
cudaFree(uint_crd_with_LJ);  // release the temporary packed-coordinate buffer allocated above
return;
}
void Dihedral14LJEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *boxlength_f, const int *a_14, const int *b_14,
const float *lj_scale_factor, const float *LJ_type_A, const float *LJ_type_B, float *ene,
cudaStream_t stream);

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_ENERGY_IMPL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_ENERGY_IMPL_H
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void Dihedral14LJEnergy(const int dihedral_14_numbers, const int atom_numbers, const int *uint_crd_f, const int *LJtype,
const float *charge, const float *boxlength_f, const int *a_14, const int *b_14,
const float *lj_scale_factor, const float *LJ_type_A, const float *LJ_type_B, float *ene,
cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NB14_DIHEDRAL_14_LJ_ENERGY_IMPL_H

Some files were not shown because too many files have changed in this diff.
