First commit - scatter_nd_int64_update

lint fix - no lint required for previous code

fix int64 templates

removed extra files
pull/9858/head
danishnxt 4 years ago
parent d68708960e
commit d1fcfaf0f2

@ -22,18 +22,35 @@ MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ScatterNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ScatterNdGpuFwdKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ScatterNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ScatterNdGpuFwdKernel, half, int64_t)
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ScatterNdGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ScatterNdGpuFwdKernel, int, int64_t)
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ScatterNdGpuFwdKernel, short, int) // NOLINT
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ScatterNdGpuFwdKernel, short, int64_t) // NOLINT
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ScatterNdGpuFwdKernel, uchar, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ScatterNdGpuFwdKernel, uchar, int64_t)
} // namespace kernel
} // namespace mindspore

@ -23,9 +23,9 @@ __global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t b
const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
S *indices_stride, S *work_shape) {
int i, j;
for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
int write_index = 0;
size_t write_index = 0;
bool out_bound = false;
i = read_index / block_size;
@ -51,8 +51,8 @@ void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
S *work_shape, cudaStream_t stream) {
ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
output_size, indices_dim_0, indices_dim_1,
indices_stride, work_shape);
output_size, indices_dim_0, indices_dim_1,
indices_stride, work_shape);
return;
}
@ -60,21 +60,43 @@ template void ScatterNd<float, int>(int *indices, float *update, float *output,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<float, int64_t>(int64_t *indices, float *update, float *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1,
int64_t *indices_stride, int64_t *work_shape, cudaStream_t stream);
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<half, int64_t>(int64_t *indices, half *update, half *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
cudaStream_t stream);
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<int, int64_t>(int64_t *indices, int *update, int *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
cudaStream_t stream);
// NOLINTNEXTLINE
template void ScatterNd<short, int>(int *indices, short *update, short *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
// NOLINTNEXTLINE
template void ScatterNd<short, int64_t>(int64_t *indices, short *update, short *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1,
int64_t *indices_stride, int64_t *work_shape, cudaStream_t stream);
template void ScatterNd<unsigned char, int>(int *indices, unsigned char *update, unsigned char *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<unsigned char, int64_t>(int64_t *indices, unsigned char *update, unsigned char *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, cudaStream_t stream);

@ -2984,7 +2984,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape):
shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp):
validator.check_positive_int(x, f'shape[{i}]', self.name)

@ -48,6 +48,13 @@ def scatternd_positive(nptype):
[0., 1.1]]).astype(nptype)
scatternd_net(arr_indices, arr_update, shape, expect)
arr_indices = np.array([[0, 1], [1, 1], [0, 1], [0, 1], [0, 1]]).astype(np.int64)
arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(nptype)
shape = (2, 2)
expect = np.array([[0., 5.3],
[0., 1.1]]).astype(nptype)
scatternd_net(arr_indices, arr_update, shape, expect)
def scatternd_negative(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -58,6 +65,13 @@ def scatternd_negative(nptype):
[-21.4, -3.1]]).astype(nptype)
scatternd_net(arr_indices, arr_update, shape, expect)
arr_indices = np.array([[1, 0], [1, 1], [1, 0], [1, 0], [1, 0]]).astype(np.int64)
arr_update = np.array([-13.4, -3.1, 5.1, -12.1, -1.0]).astype(nptype)
shape = (2, 2)
expect = np.array([[0., 0.],
[-21.4, -3.1]]).astype(nptype)
scatternd_net(arr_indices, arr_update, shape, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard

Loading…
Cancel
Save