|
|
|
@ -134,6 +134,15 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
EnforceLayouts(multi_input);
|
|
|
|
|
Tensor* output = ctx.Output<Tensor>("Out");
|
|
|
|
|
int concat_axis = ctx.Attr<int>("axis");
|
|
|
|
|
const int rank = multi_input[0]->dims().size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
concat_axis >= -rank && concat_axis < rank, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The axis is expected to be in range of [%d, %d), but got %d",
|
|
|
|
|
-rank, rank, concat_axis));
|
|
|
|
|
if (concat_axis < 0) {
|
|
|
|
|
concat_axis = concat_axis + rank;
|
|
|
|
|
}
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
|
|
|
|
|
auto place = GetCpuPlace(ctx);
|
|
|
|
|