fix bug of op l2loss on gpu

pull/9713/head
zhouyuanshen 4 years ago
parent d77e8c39d6
commit 37ac7ebcf6

@ -22,15 +22,22 @@ template <typename T>
// Accumulates sum(input[i]^2) / 2 into output[0] (scalar L2 loss).
// Grid-stride loop: correct for any <<<grid, block>>> configuration.
// Precondition: output[0] must be zeroed before launch (see L2Loss).
// Each thread folds its elements into a local partial sum and issues a
// single MsAtomicAdd, instead of one atomic per element — far less
// contention on the single accumulator address.
template <typename T>
__global__ void L2LossKernel(const size_t input_size, const T *input, T *output) {
  T partial = static_cast<T>(0);
  for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < input_size; id += blockDim.x * gridDim.x) {
    partial += input[id] * input[id];
  }
  MsAtomicAdd(output, partial / static_cast<T>(2));
  return;
}
// Zeroes the scalar accumulator output[0] before the reduction kernel runs.
// Guarded so exactly one thread performs the store: the original wrote from
// every launched thread, which is a formal data race (same-value stores to
// one address without atomics) and redundant global-memory traffic.
template <typename T>
__global__ void ClearOutputMem(T *output) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    output[0] = static_cast<T>(0);
  }
  return;
}
// Host launcher: computes the scalar L2 loss sum(input^2)/2 into output[0].
// output must point to a single device element of type T.
// Both operations are enqueued on `stream`, so the zeroing is guaranteed to
// complete before the reduction kernel reads/updates output[0].
template <typename T>
void L2Loss(const size_t input_size, const T *input, T *output, cudaStream_t stream) {
  // Stream-ordered memset replaces a dedicated clear kernel: an all-zero
  // byte pattern is 0.0 for the floating-point T instantiations used here,
  // and it avoids an extra kernel launch.
  cudaMemsetAsync(output, 0, sizeof(T), stream);
  L2LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input, output);
}

@ -445,7 +445,7 @@ class ReduceAll(_Reduce):
the shape of output is :math:`(x_1, x_4, ..., x_R)`. the shape of output is :math:`(x_1, x_4, ..., x_R)`.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU``
Examples: Examples:
>>> input_x = Tensor(np.array([[True, False], [True, True]])) >>> input_x = Tensor(np.array([[True, False], [True, True]]))
@ -487,7 +487,7 @@ class ReduceAny(_Reduce):
the shape of output is :math:`(x_1, x_4, ..., x_R)`. the shape of output is :math:`(x_1, x_4, ..., x_R)`.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU``
Examples: Examples:
>>> input_x = Tensor(np.array([[True, False], [True, True]])) >>> input_x = Tensor(np.array([[True, False], [True, True]]))

Loading…
Cancel
Save