@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput(
   return reduced;
 }
 
+static const std::vector<int> GetDimsForKey(
+    const std::vector<const Tensor*>& inputs) {
+  auto dims_key = paddle::framework::vectorize<int>(inputs[0]->dims());
+  for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) {
+    dims_key.push_back((*it)->dims()[0]);
+  }
+  return dims_key;
+}
+
 template <typename T>
 class ConcatPrimitiveFactory {
  public:
@@ -134,6 +143,8 @@ template <typename T>
 class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    // If any of the multiple inputs of concat has an input size of 0, the
+    // actual size of the multi_input will change
     auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
     EnforceLayouts(multi_input);
     Tensor* output = ctx.Output<Tensor>("Out");
@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
 
     ConcatPrimitiveFactory<T> prim_creator;
-    // If one of the multiple inputs of concat has an input size of 0, the
-    // actual size of the multi_input will change
-    std::string key = platform::CreateKey(
-        dev_ctx, paddle::framework::vectorize<int>(multi_input[0]->dims()),
-        multi_input.size(), ctx.OutputName("Out"), dt,
-        platform::ThreadIDasStr());
+    std::string key =
+        platform::CreateKey(dev_ctx, GetDimsForKey(multi_input),
+                            multi_input.size(), ctx.OutputName("Out"), dt);
+    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
 
     const std::string key_prim = key + "@concat_p";
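
For readers skimming the diff, here is a minimal standalone sketch (not part of the patch) of what the new GetDimsForKey helper contributes to the cache key, using a plain std::vector<int> as a hypothetical stand-in for Tensor::dims(): the key now carries the full dims of the first input plus the leading dim of every remaining input, so concat calls whose later inputs differ in size presumably no longer map to the same cached primitive.

// Sketch only: GetDimsForKeySketch, Dims and the sample shapes below are
// illustrative stand-ins, not Paddle APIs.
#include <cstddef>
#include <iostream>
#include <vector>

using Dims = std::vector<int>;  // stand-in for vectorize<int>(tensor->dims())

static std::vector<int> GetDimsForKeySketch(const std::vector<Dims>& inputs) {
  std::vector<int> dims_key = inputs[0];  // full dims of the first input
  for (std::size_t i = 1; i < inputs.size(); ++i) {
    dims_key.push_back(inputs[i][0]);     // leading dim of each later input
  }
  return dims_key;
}

int main() {
  // Three inputs concatenated along axis 0: {2,3,4}, {5,3,4}, {7,3,4}.
  std::vector<Dims> inputs = {{2, 3, 4}, {5, 3, 4}, {7, 3, 4}};
  for (int d : GetDimsForKeySketch(inputs)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 2 3 4 5 7
}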