|
|
|
@ -37,6 +37,9 @@ class ConcatKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (axis == 0 && ins.size() < 10) {
|
|
|
|
|
size_t output_offset = 0;
|
|
|
|
|
for (auto* in : ins) {
|
|
|
|
|
if (!in || in->numel() == 0UL) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto in_stride = framework::stride_numel(in->dims());
|
|
|
|
|
auto out_stride = framework::stride_numel(out->dims());
|
|
|
|
|
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
|
|
|
|
@ -45,9 +48,13 @@ class ConcatKernel : public framework::OpKernel<T> {
|
|
|
|
|
output_offset += in_stride[axis];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<framework::Tensor> inputs(ins.size());
|
|
|
|
|
std::vector<framework::Tensor> inputs;
|
|
|
|
|
for (size_t j = 0; j < ins.size(); ++j) {
|
|
|
|
|
inputs[j] = *ins[j];
|
|
|
|
|
if (ins[j] && ins[j]->numel() > 0) {
|
|
|
|
|
inputs.push_back(*ins[j]);
|
|
|
|
|
} else {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
|
|
|
|
@ -82,7 +89,8 @@ class ConcatGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Collect the output tensors whose variable name is not kEmptyVarName.
|
|
|
|
|
std::vector<framework::Tensor*> outputs;
|
|
|
|
|
for (size_t j = 0; j < outs.size(); ++j) {
|
|
|
|
|
if (out_var_names[j] != framework::kEmptyVarName) {
|
|
|
|
|
if (out_var_names[j] != framework::kEmptyVarName &&
|
|
|
|
|
outs[j]->numel() != 0UL) {
|
|
|
|
|
outs[j]->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
outputs.push_back(outs[j]);
|
|
|
|
|
} else {
|
|
|
|
|