|
|
|
@ -36,11 +36,16 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
|
|
|
|
"XPU donot surpport AxisTensor for now"));
|
|
|
|
"XPU donot surpport AxisTensor for now"));
|
|
|
|
axis = ComputeAxis(static_cast<int64_t>(axis),
|
|
|
|
axis = ComputeAxis(static_cast<int64_t>(axis),
|
|
|
|
static_cast<int64_t>(ins[0]->dims().size()));
|
|
|
|
static_cast<int64_t>(ins[0]->dims().size()));
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
|
|
|
|
axis, 0, platform::errors::InvalidArgument("concat: axis shoud >= 0!"));
|
|
|
|
"concat: axis should be larger than or "
|
|
|
|
|
|
|
|
"equal to 0, but received axis is %d.",
|
|
|
|
|
|
|
|
axis));
|
|
|
|
PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(),
|
|
|
|
PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(),
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"concat: axis shoud < ins[0]->dims()!"));
|
|
|
|
"concat: axis should be less than ins[0]->dims()!"
|
|
|
|
|
|
|
|
"But received axis is %d, while ins[0]->dims()"
|
|
|
|
|
|
|
|
"size is %d.",
|
|
|
|
|
|
|
|
axis, ins[0]->dims().size()));
|
|
|
|
|
|
|
|
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
out->mutable_data<T>(place);
|
|
|
|
out->mutable_data<T>(place);
|
|
|
|
@ -151,10 +156,16 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
|
|
|
|
"concat_grad: axis shoud >= 0!"));
|
|
|
|
"concat_grad: axis should be larger than or "
|
|
|
|
PADDLE_ENFORCE_LT(axis, out_grad->dims().size(),
|
|
|
|
"equal to 0, but received axis is %d.",
|
|
|
|
|
|
|
|
axis));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
|
|
|
axis, out_grad->dims().size(),
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"concat_grad: axis shoud < ins[0]->dims()!"));
|
|
|
|
"concat_grad: axis should be less than ins[0]->dims()!"
|
|
|
|
|
|
|
|
"But received axis is %d, while ins[0]->dims()"
|
|
|
|
|
|
|
|
"size is %d.",
|
|
|
|
|
|
|
|
axis, out_grad->dims().size()));
|
|
|
|
|
|
|
|
|
|
|
|
auto input_dims = ins[0]->dims();
|
|
|
|
auto input_dims = ins[0]->dims();
|
|
|
|
std::vector<int> split_list(n);
|
|
|
|
std::vector<int> split_list(n);
|
|
|
|
|