fix mkldnn concat bug. test=develop (#24722)

v1.8
Wilber 5 years ago committed by GitHub
parent b9260b365a
commit dbe2497768
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -134,6 +134,15 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out");
int concat_axis = ctx.Attr<int>("axis");
const int rank = multi_input[0]->dims().size();
PADDLE_ENFORCE_EQ(
concat_axis >= -rank && concat_axis < rank, true,
platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank, rank, concat_axis));
if (concat_axis < 0) {
concat_axis = concat_axis + rank;
}
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
auto place = GetCpuPlace(ctx);

Loading…
Cancel
Save