|
|
@ -142,9 +142,12 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
|
|
|
|
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
|
|
|
|
|
|
|
|
|
|
|
|
ConcatPrimitiveFactory<T> prim_creator;
|
|
|
|
ConcatPrimitiveFactory<T> prim_creator;
|
|
|
|
|
|
|
|
// If one of the multiple inputs of concat has an input size of 0, the
|
|
|
|
|
|
|
|
// actual size of the multi_input will change
|
|
|
|
std::string key = platform::CreateKey(
|
|
|
|
std::string key = platform::CreateKey(
|
|
|
|
paddle::framework::vectorize<int>(multi_input[0]->dims()),
|
|
|
|
paddle::framework::vectorize<int>(multi_input[0]->dims()),
|
|
|
|
ctx.OutputName("Out"), dt, platform::ThreadIDasStr());
|
|
|
|
multi_input.size(), ctx.OutputName("Out"), dt,
|
|
|
|
|
|
|
|
platform::ThreadIDasStr());
|
|
|
|
|
|
|
|
|
|
|
|
const std::string key_prim = key + "@concat_p";
|
|
|
|
const std::string key_prim = key + "@concat_p";
|
|
|
|
const std::string key_concat_pd = key + "@concat_pd";
|
|
|
|
const std::string key_concat_pd = key + "@concat_pd";
|
|
|
|