fix gpu ops bug.

pull/10275/head
linqingke 4 years ago
parent 9663cc7235
commit f40f991ea9

@ -27,5 +27,8 @@ MS_REG_GPU_KERNEL_ONE(Split,
MS_REG_GPU_KERNEL_ONE(
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SplitGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
SplitGpuFwdKernel, uint32_t)
} // namespace kernel
} // namespace mindspore

@ -48,3 +48,6 @@ template void SplitKernel(const size_t size, const int axis_step, const int all_
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
const int all_size_axis, const half* input, half** outputs,
cudaStream_t cuda_stream);
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
const int all_size_axis, const uint32_t* input, uint32_t** outputs,
cudaStream_t cuda_stream);

@ -42,12 +42,18 @@ class SGDGpuKernel : public GpuKernel {
T *accum = GetDeviceAddress<T>(inputs, 3);
T *momentum = GetDeviceAddress<T>(inputs, 4);
T *stat = GetDeviceAddress<T>(inputs, 5);
T *output_param = GetDeviceAddress<T>(outputs, 0);
SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat,
reinterpret_cast<cudaStream_t>(stream));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_param, param, sizeof(T) * size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream)),
"SGD cudaMemcpyAsync params to outputs failed");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
dampening_ = GetAttr<float>(kernel_node, "dampening");
weight_decay_ = GetAttr<float>(kernel_node, "weight_decay");
nesterov_ = GetAttr<bool>(kernel_node, "nesterov");

Loading…
Cancel
Save