|
|
|
@ -28,50 +28,34 @@ class StackXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* y = ctx.Output<Tensor>("Y");
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
if (axis < 0) {
|
|
|
|
|
axis += (x[0]->dims().size() + 1);
|
|
|
|
|
axis += x[0]->dims().size() + 1;
|
|
|
|
|
}
|
|
|
|
|
int n = static_cast<int>(x.size());
|
|
|
|
|
PADDLE_ENFORCE_LE(n, 24,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"XPU only surpport at most 24 tensors for now"));
|
|
|
|
|
auto* y_data = y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
int pre = 1, post = 1;
|
|
|
|
|
|
|
|
|
|
auto& dim = x[0]->dims();
|
|
|
|
|
for (auto i = 0; i < axis; ++i) {
|
|
|
|
|
pre *= dim[i];
|
|
|
|
|
std::vector<int> xdims;
|
|
|
|
|
for (auto i = 0; i < dim.size(); ++i) {
|
|
|
|
|
xdims.push_back(dim[i]);
|
|
|
|
|
}
|
|
|
|
|
for (auto i = axis; i < dim.size(); ++i) {
|
|
|
|
|
post *= dim[i];
|
|
|
|
|
xdims.push_back(1);
|
|
|
|
|
std::vector<std::vector<int>> xdims_list;
|
|
|
|
|
int n = static_cast<int>(x.size());
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
xdims_list.push_back(xdims);
|
|
|
|
|
}
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
void* x_datas_host = std::malloc(n * sizeof(void*));
|
|
|
|
|
void* x_datas_device = nullptr;
|
|
|
|
|
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&x_datas_device),
|
|
|
|
|
n * sizeof(void*)),
|
|
|
|
|
XPU_SUCCESS,
|
|
|
|
|
platform::errors::ResourceExhausted(
|
|
|
|
|
"\n\nOut of memory error on XPU, Cannot"
|
|
|
|
|
"allocate %s memory on XPU. \n\nPlease "
|
|
|
|
|
"check whether there is any other process "
|
|
|
|
|
"using XPU.\n",
|
|
|
|
|
string::HumanReadableSize(n * sizeof(void*))));
|
|
|
|
|
for (auto i = 0; i < n; ++i) {
|
|
|
|
|
((const void**)x_datas_host)[i] = x[i]->data<T>();
|
|
|
|
|
|
|
|
|
|
std::vector<const T*> x_list;
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
x_list.push_back(x[i]->data<T>());
|
|
|
|
|
}
|
|
|
|
|
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
|
|
|
|
|
x_datas_device, platform::CPUPlace(), x_datas_host,
|
|
|
|
|
n * sizeof(void*));
|
|
|
|
|
int r = xpu::stack_forward<float>(dev_ctx.x_context(), pre, post, n,
|
|
|
|
|
x_datas_device, y_data);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
r, xpu::Error_t::SUCCESS,
|
|
|
|
|
platform::errors::External(
|
|
|
|
|
"The stack XPU API return wrong value[%d], please check "
|
|
|
|
|
"where Baidu Kunlun Card is properly installed.",
|
|
|
|
|
r));
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
std::free(x_datas_host);
|
|
|
|
|
xpu_free(x_datas_device);
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
int r =
|
|
|
|
|
xpu::concat<T>(dev_ctx.x_context(), x_list, y_data, xdims_list, axis);
|
|
|
|
|
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
|
|
|
|
|
platform::errors::External(
|
|
|
|
|
"The stack XPU API return wrong value[%d %s]", r,
|
|
|
|
|
XPUAPIErrorMsg[r]));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|