|
|
|
@ -59,6 +59,16 @@ static const mkldnn::engine& GetMKLDNNEngine(
|
|
|
|
|
return dev_ctx.GetEngine();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// From a multi-input, gather only nonempty inputs
|
|
|
|
|
static const std::vector<const Tensor*> ReduceMultiInput(
|
|
|
|
|
const std::vector<const Tensor*>& inputs) {
|
|
|
|
|
std::vector<const Tensor*> reduced(inputs.size());
|
|
|
|
|
auto end_it = std::copy_if(inputs.begin(), inputs.end(), reduced.begin(),
|
|
|
|
|
[](const Tensor* t) { return t->numel() > 0; });
|
|
|
|
|
reduced.resize(std::distance(reduced.begin(), end_it));
|
|
|
|
|
return reduced;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ConcatPrimitiveFactory {
|
|
|
|
|
public:
|
|
|
|
@ -120,7 +130,7 @@ template <typename T>
|
|
|
|
|
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto multi_input = ctx.MultiInput<Tensor>("X");
|
|
|
|
|
auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
|
|
|
|
|
EnforceLayouts(multi_input);
|
|
|
|
|
Tensor* output = ctx.Output<Tensor>("Out");
|
|
|
|
|
int concat_axis = ctx.Attr<int>("axis");
|
|
|
|
|