@@ -27,10 +27,10 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
     Tensor* output = context.Output<Tensor>("Output");
-    Tensor* max_input = context.Output<Tensor>("MaxInput");
-    Tensor* max_filter = context.Output<Tensor>("MaxFilter");
-    max_input->mutable_data<T>(context.GetPlace());
-    max_filter->mutable_data<T>(context.GetPlace());
+    // Tensor* max_input = context.Output<Tensor>("MaxInput");
+    // Tensor* max_filter = context.Output<Tensor>("MaxFilter");
+    // max_input->mutable_data<T>(context.GetPlace());
+    // max_filter->mutable_data<T>(context.GetPlace());
     output->mutable_data<T>(context.GetPlace());
     int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@@ -47,28 +47,28 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
         dilations[0] == 1 && dilations[1] == 1, true,
         platform::errors::InvalidArgument("XPU only supports dilation == 1."));
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    PADDLE_ENFORCE_EQ(
-        xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
-                     max_input->data<T>()) == xpu::Error_t::SUCCESS,
-        true, platform::errors::InvalidArgument(
-                  "XPU conv kernel error, cannot find max_input, please "
-                  "check whether Baidu Kunlun "
-                  "Card is properly installed."));
-    PADDLE_ENFORCE_EQ(
-        xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
-                     max_filter->data<T>()) == xpu::Error_t::SUCCESS,
-        true, platform::errors::InvalidArgument(
-                  "XPU conv kernel error, cannot find max_filter, please "
-                  "check whether Baidu Kunlun "
-                  "Card is properly installed."));
+    // PADDLE_ENFORCE_EQ(
+    //     xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
+    //                  max_input->data<T>()) == xpu::Error_t::SUCCESS,
+    //     true, platform::errors::InvalidArgument(
+    //               "XPU conv kernel error, cannot find max_input, please "
+    //               "check whether Baidu Kunlun "
+    //               "Card is properly installed."));
+    // PADDLE_ENFORCE_EQ(
+    //     xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
+    //                  max_filter->data<T>()) == xpu::Error_t::SUCCESS,
+    //     true, platform::errors::InvalidArgument(
+    //               "XPU conv kernel error, cannot find max_filter, please "
+    //               "check whether Baidu Kunlun "
+    //               "Card is properly installed."));
     if (groups == 1) {
       int r = xpu::conv2d_forward_int16<float, float, float, float>(
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
           strides[0], strides[1], paddings[0], paddings[1], dilations[0],
           dilations[1], groups, input->data<float>(), filter.data<float>(),
           output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
-          // nullptr, nullptr);
-          max_input->data<float>(), max_filter->data<float>());
+          nullptr, nullptr);
+          // max_input->data<float>(), max_filter->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel returned wrong value[%d], "
@@ -80,8 +80,8 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
           dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
           output->data<float>(), batch_size, img_c, img_h, img_w, f, win_h,
           win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
-          // nullptr, nullptr);
-          max_input->data<float>(), max_filter->data<float>());
+          nullptr, nullptr);
+          // max_input->data<float>(), max_filter->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel returned wrong value[%d], "
@@ -96,9 +96,9 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* max_input = context.Input<Tensor>("MaxInput");
-    const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
-    Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
+    // const Tensor* max_input = context.Input<Tensor>("MaxInput");
+    // const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
+    // Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
     const Tensor* output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
@@ -133,25 +133,25 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
       filter_grad->mutable_data<T>(context.GetPlace());
     }
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    max_output_grad->Resize({4});
-    max_output_grad->mutable_data<T>(context.GetPlace());
-    PADDLE_ENFORCE_EQ(
-        xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
-                     output_grad->numel(),
-                     max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
-        true,
-        platform::errors::External(
-            "XPU conv kernel error, cannot find max_output_grad, please check "
-            "whether Baidu Kunlun Card is "
-            "properly installed."));
+    // max_output_grad->Resize({4});
+    // max_output_grad->mutable_data<T>(context.GetPlace());
+    // PADDLE_ENFORCE_EQ(
+    //     xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
+    //                  output_grad->numel(),
+    //                  max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
+    //     true,
+    //     platform::errors::External(
+    //         "XPU conv kernel error, cannot find max_output_grad, please "
+    //         "check "
+    //         "whether Baidu Kunlun Card is "
+    //         "properly installed."));
     if (input_grad) {
       int r = xpu::conv2d_backward_int16(
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
           strides[0], strides[1], paddings[0], paddings[1], dilations[0],
           dilations[1], groups, output_grad->data<float>(),
-          filter.data<float>(), input_grad->data<float>(),
-          // nullptr, nullptr,
-          max_output_grad->data<float>(), max_filter->data<float>());
+          filter.data<float>(), input_grad->data<float>(), nullptr, nullptr);
+          // max_output_grad->data<float>(), max_filter->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel returned wrong value[%d], "
@@ -164,9 +164,8 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
           strides[0], strides[1], paddings[0], paddings[1], dilations[0],
           dilations[1], groups, output_grad->data<float>(),
-          input->data<float>(), filter_grad->data<float>(),
-          // nullptr, nullptr,
-          max_output_grad->data<float>(), max_input->data<float>());
+          input->data<float>(), filter_grad->data<float>(), nullptr, nullptr);
+          // max_output_grad->data<float>(), max_input->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel returned wrong value[%d], "
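
Note: taken together, these hunks back out the explicit findmax-based quantization path. The MaxInput/MaxFilter/MaxOutputGrad tensors and the xpu::findmax status checks are commented out, and conv2d_forward_int16/conv2d_backward_int16 are again called with nullptr for the two max pointers. The same enforce-on-status idiom recurs in every block above; below is a minimal sketch of how it could be factored into a helper. The CHECK_XPU_FINDMAX name and the message wording are illustrative assumptions, not part of this change.

    // Hypothetical helper wrapping the PADDLE_ENFORCE_EQ idiom used above.
    // `call` is any xpu routine returning xpu::Error_t; `what` is a string
    // literal naming the tensor whose maximum could not be computed.
    #define CHECK_XPU_FINDMAX(call, what)                                  \
      PADDLE_ENFORCE_EQ(                                                   \
          (call) == xpu::Error_t::SUCCESS, true,                           \
          platform::errors::External(                                      \
              "XPU conv kernel error, cannot find " what                   \
              ", please check whether Baidu Kunlun Card is properly "      \
              "installed."))

    // Usage inside Compute(), if the findmax path were re-enabled:
    //   CHECK_XPU_FINDMAX(xpu::findmax(dev_ctx.x_context(), input->data<T>(),
    //                                  input->numel(), max_input->data<T>()),
    //                     "max_input");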