support roi_align & affine_channel for kunlun (#29561)

* support roi_align & affine_channel for kunlun

* minor
revert-31562-mean
QingshuChen 4 years ago committed by GitHub
parent 0cad1152f4
commit 79a41a9ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@ endif()
INCLUDE(ExternalProject)
SET(XPU_PROJECT "extern_xpu")
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_07_cdfbf0c.tar.gz" CACHE STRING "" FORCE)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_11.tar.gz" CACHE STRING "" FORCE)
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}")
SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu")

@ -0,0 +1,186 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AffineChannelXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* y = ctx.Output<framework::Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto dims = x->dims();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = x->numel() / N / C;
auto* scale_d = scale->data<T>();
auto* bias_d = bias->data<T>();
auto* x_d = x->data<T>();
auto* y_d = y->data<T>();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<int> x_shape;
std::vector<int> b_shape;
if (layout == framework::DataLayout::kNCHW) {
x_shape.push_back(N);
x_shape.push_back(C);
x_shape.push_back(HxW);
b_shape.push_back(1);
b_shape.push_back(C);
b_shape.push_back(1);
} else {
x_shape.push_back(N * HxW);
x_shape.push_back(C);
b_shape.push_back(1);
b_shape.push_back(C);
}
int r = 0;
r = xpu::broadcast_mul(dev_ctx.x_context(), x_d, scale_d, y_d, x_shape,
b_shape);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The broadcast_mul XPU OP return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::broadcast_add(dev_ctx.x_context(), y_d, bias_d, y_d, x_shape,
b_shape);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The broadcast_add XPU OP return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
template <typename DeviceContext, typename T>
class AffineChannelGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dscale =
ctx.Output<framework::Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto dims = x->dims();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = x->numel() / N / C;
auto* dy_d = dy->data<T>();
auto* scale_d = scale->data<T>();
T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* dscale_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* dbias_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<int> x_shape;
std::vector<int> b_shape;
std::vector<int> rdims;
if (layout == framework::DataLayout::kNCHW) {
x_shape.push_back(N);
x_shape.push_back(C);
x_shape.push_back(HxW);
b_shape.push_back(1);
b_shape.push_back(C);
b_shape.push_back(1);
rdims.push_back(0);
rdims.push_back(2);
} else {
x_shape.push_back(N * HxW);
x_shape.push_back(C);
b_shape.push_back(1);
b_shape.push_back(C);
rdims.push_back(0);
}
int r = 0;
if (dscale_d && dbias_d) {
r = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_d, dbias_d, x_shape,
rdims);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The reduce_sum XPU OP return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
T* tmp = nullptr;
r = xpu_malloc(reinterpret_cast<void**>(&tmp), dy->numel() * sizeof(T));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("no enough memory in xpu"));
r = xpu::mul<T>(dev_ctx.x_context(), dy_d, x->data<T>(), tmp,
dy->numel());
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("The mul XPU OP return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::reduce_sum<T>(dev_ctx.x_context(), tmp, dscale_d, x_shape,
rdims);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The reduce_sum XPU OP return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
xpu_free(tmp);
}
if (dx_d) {
r = xpu::broadcast_mul(dev_ctx.x_context(), dy_d, scale_d, dx_d, x_shape,
b_shape);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The broadcast_mul XPU OP return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using XPU = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(affine_channel, ops::AffineChannelXPUKernel<XPU, float>);
REGISTER_OP_XPU_KERNEL(affine_channel_grad,
ops::AffineChannelGradXPUKernel<XPU, float>);
#endif

File diff suppressed because it is too large Load Diff

@ -0,0 +1,148 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit testing for affine_channel_op
"""
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
def affine_channel(x, scale, bias, layout):
C = x.shape[1] if layout == 'NCHW' else x.shape[-1]
if len(x.shape) == 4:
new_shape = (1, C, 1, 1) if layout == 'NCHW' else (1, 1, 1, C)
else:
new_shape = (1, C)
scale = scale.reshape(new_shape)
bias = bias.reshape(new_shape)
return x * scale + bias
class TestAffineChannelOp(XPUOpTest):
def setUp(self):
self.op_type = "affine_channel"
self.init_test_case()
x = np.random.random(self.shape).astype("float32")
scale = np.random.random(self.C).astype("float32")
bias = np.random.random(self.C).astype("float32")
y = affine_channel(x, scale, bias, self.layout)
self.inputs = {'X': x, 'Scale': scale, 'Bias': bias}
self.attrs = {'data_layout': self.layout}
self.outputs = {'Out': y}
def test_check_output(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X', 'Scale', 'Bias'], 'Out')
def test_check_grad_stopgrad_dx(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Scale', 'Bias'], 'Out', no_grad_set=set('X'))
def test_check_grad_stopgrad_dscale_dbias(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', no_grad_set=set(['Scale', 'Bias']))
def init_test_case(self):
self.shape = [2, 100, 3, 3]
self.C = 100
self.layout = 'NCHW'
class TestAffineChannelOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program()):
def test_x_type():
input_data = np.random.random(2, 1, 2, 2).astype("float32")
fluid.layers.affine_channel(input_data)
self.assertRaises(TypeError, test_x_type)
def test_x_dtype():
x2 = fluid.layers.data(
name='x2', shape=[None, 1, 2, 2], dtype='int32')
fluid.layers.affine_channel(x2)
self.assertRaises(TypeError, test_x_dtype)
def test_scale_type():
x3 = fluid.layers.data(
name='x3', shape=[None, 1, 2, 2], dtype='float32')
fluid.layers.affine_channel(x3, scale=1)
self.assertRaises(TypeError, test_scale_type)
def test_bias_type():
x4 = fluid.layers.data(
name='x4', shape=[None, 1, 2, 2], dtype='float32')
fluid.layers.affine_channel(x4, bias=1)
self.assertRaises(TypeError, test_bias_type)
class TestAffineChannelNHWC(TestAffineChannelOp):
def init_test_case(self):
self.shape = [2, 3, 3, 100]
self.C = 100
self.layout = 'NHWC'
def test_check_grad_stopgrad_dx(self):
return
def test_check_grad_stopgrad_dscale_dbias(self):
return
class TestAffineChannel2D(TestAffineChannelOp):
def init_test_case(self):
self.shape = [2, 100]
self.C = 100
self.layout = 'NCHW'
def test_check_grad_stopgrad_dx(self):
return
def test_check_grad_stopgrad_dscale_dbias(self):
return
if __name__ == '__main__':
unittest.main()

@ -20,13 +20,13 @@ import math
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
@skip_check_grad_ci(reason="There is no grad kernel for roi_align_xpu kernel.")
class TestROIAlignOp(OpTest):
class TestROIAlignOp(XPUOpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
@ -59,16 +59,16 @@ class TestROIAlignOp(OpTest):
self.pooled_width = 2
self.sampling_ratio = -1
self.x = np.random.random(self.x_dim).astype('float64')
self.x = np.random.random(self.x_dim).astype('float32')
def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w,
bin_size_h, bin_size_w):
count = roi_bin_grid_h * roi_bin_grid_w
bilinear_pos = np.zeros(
[self.channels, self.pooled_height, self.pooled_width, count, 4],
np.float64)
np.float32)
bilinear_w = np.zeros(
[self.pooled_height, self.pooled_width, count, 4], np.float64)
[self.pooled_height, self.pooled_width, count, 4], np.float32)
for ph in range(self.pooled_width):
for pw in range(self.pooled_height):
c = 0
@ -118,7 +118,7 @@ class TestROIAlignOp(OpTest):
def calc_roi_align(self):
self.out_data = np.zeros(
(self.rois_num, self.channels, self.pooled_height,
self.pooled_width)).astype('float64')
self.pooled_width)).astype('float32')
for i in range(self.rois_num):
roi = self.rois[i]
@ -166,7 +166,7 @@ class TestROIAlignOp(OpTest):
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype("float64")
self.rois = np.array(rois).astype("float32")
def setUp(self):
self.op_type = "roi_align"
@ -178,6 +178,12 @@ class TestROIAlignOp(OpTest):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, {'X'}, 'Out')
class TestROIAlignInLodOp(TestROIAlignOp):
def set_data(self):

Loading…
Cancel
Save