|
|
|
@ -50,11 +50,17 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
|
|
|
|
|
max_input->data<T>()) == xpu::Error_t::SUCCESS,
|
|
|
|
|
true, platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
true, platform::errors::InvalidArgument(
|
|
|
|
|
"XPU conv kernel error,can not finde max_input,please "
|
|
|
|
|
"check whether Baidu Kunlun "
|
|
|
|
|
"Card is properly installed."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
|
|
|
|
|
max_filter->data<T>()) == xpu::Error_t::SUCCESS,
|
|
|
|
|
true, platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
true, platform::errors::InvalidArgument(
|
|
|
|
|
"XPU conv kernel error,can not find max_filter,please "
|
|
|
|
|
"check whether Baidu Kunlun "
|
|
|
|
|
"Card is properly installed."));
|
|
|
|
|
if (groups == 1) {
|
|
|
|
|
int r = xpu::conv2d_forward_int16<float, float, float, float>(
|
|
|
|
|
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
|
|
|
|
@ -63,8 +69,12 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
|
|
|
|
|
// nullptr, nullptr);
|
|
|
|
|
max_input->data<float>(), max_filter->data<float>());
|
|
|
|
|
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
|
|
|
|
|
platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
r, XPU_SUCCESS,
|
|
|
|
|
platform::errors::External("XPU conv kernel return wrong value[%d], "
|
|
|
|
|
"please check whether Baidu Kunlun Card "
|
|
|
|
|
"is properly installed.",
|
|
|
|
|
r));
|
|
|
|
|
} else {
|
|
|
|
|
int r = xpu::conv2d_int16_with_group<float, float, float>(
|
|
|
|
|
dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
|
|
|
|
@ -72,8 +82,12 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
|
|
|
|
|
// nullptr, nullptr);
|
|
|
|
|
max_input->data<float>(), max_filter->data<float>());
|
|
|
|
|
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
|
|
|
|
|
platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
r, XPU_SUCCESS,
|
|
|
|
|
platform::errors::External("XPU conv kernel return wrong value[%d], "
|
|
|
|
|
"please check whether Baidu Kunlun Card "
|
|
|
|
|
"is properly installed.",
|
|
|
|
|
r));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -125,7 +139,11 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
|
|
|
|
|
output_grad->numel(),
|
|
|
|
|
max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
|
|
|
|
|
true, platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
true,
|
|
|
|
|
platform::errors::External(
|
|
|
|
|
"XPU conv kernel error, can not find max_output_grad, please check "
|
|
|
|
|
"whether Baidu Kunlun Card is "
|
|
|
|
|
"properly installed."));
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
int r = xpu::conv2d_backward_int16(
|
|
|
|
|
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
|
|
|
|
@ -134,8 +152,12 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
filter.data<float>(), input_grad->data<float>(),
|
|
|
|
|
// nullptr, nullptr,
|
|
|
|
|
max_output_grad->data<float>(), max_filter->data<float>());
|
|
|
|
|
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
|
|
|
|
|
platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
r, XPU_SUCCESS,
|
|
|
|
|
platform::errors::External("XPU conv kernel return wrong value[%d], "
|
|
|
|
|
"please check whether Baidu Kunlun Card "
|
|
|
|
|
"is properly installed.",
|
|
|
|
|
r));
|
|
|
|
|
}
|
|
|
|
|
if (filter_grad) {
|
|
|
|
|
int r = xpu::conv2d_backward_weight_int16(
|
|
|
|
@ -145,8 +167,12 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
input->data<float>(), filter_grad->data<float>(),
|
|
|
|
|
// nullptr, nullptr,
|
|
|
|
|
max_output_grad->data<float>(), max_input->data<float>());
|
|
|
|
|
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
|
|
|
|
|
platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
r, XPU_SUCCESS,
|
|
|
|
|
platform::errors::External("XPU conv kernel return wrong value[%d], "
|
|
|
|
|
"please check whether Baidu Kunlun Card "
|
|
|
|
|
"is properly installed.",
|
|
|
|
|
r));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|