Enhance concat op to support empty input. (#17015)

* enhance_concat, test=develop
enh_seq_pool
Author: jerrywgz, committed by GitHub · commit a72907bbf4 (parent 83c4f7721f)

@@ -37,6 +37,9 @@ class ConcatKernel : public framework::OpKernel<T> {
     if (axis == 0 && ins.size() < 10) {
       size_t output_offset = 0;
       for (auto* in : ins) {
+        if (!in || in->numel() == 0UL) {
+          continue;
+        }
         auto in_stride = framework::stride_numel(in->dims());
         auto out_stride = framework::stride_numel(out->dims());
         StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
@@ -45,9 +48,13 @@ class ConcatKernel : public framework::OpKernel<T> {
         output_offset += in_stride[axis];
       }
     } else {
-      std::vector<framework::Tensor> inputs(ins.size());
+      std::vector<framework::Tensor> inputs;
       for (size_t j = 0; j < ins.size(); ++j) {
-        inputs[j] = *ins[j];
+        if (ins[j] && ins[j]->numel() > 0) {
+          inputs.push_back(*ins[j]);
+        } else {
+          continue;
+        }
       }
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
       paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
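
The forward change simply drops inputs with zero elements before copying or handing the tensors to the ConcatFunctor, which matches plain NumPy semantics: a zero-sized block contributes nothing to the result. A minimal NumPy sketch of that equivalence (toy shapes, not the kernel's actual data structures):

```python
import numpy as np

a = np.random.rand(2, 3).astype('float32')
empty = np.zeros((0, 3), dtype='float32')
b = np.random.rand(4, 3).astype('float32')

# Filtering out empty inputs first ...
filtered = [t for t in (a, empty, b) if t.size > 0]
# ... gives the same result as concatenating everything.
assert np.array_equal(np.concatenate(filtered, axis=0),
                      np.concatenate((a, empty, b), axis=0))
```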
@@ -82,7 +89,8 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     // get output tensor that the name is not kEmptyVarName
     std::vector<framework::Tensor*> outputs;
     for (size_t j = 0; j < outs.size(); ++j) {
-      if (out_var_names[j] != framework::kEmptyVarName) {
+      if (out_var_names[j] != framework::kEmptyVarName &&
+          outs[j]->numel() != 0UL) {
         outs[j]->mutable_data<T>(ctx.GetPlace());
         outputs.push_back(outs[j]);
       } else {
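
The backward kernel mirrors this: an output slot whose tensor has zero elements is skipped, since the gradient of concat is just a split of the upstream gradient back into the input shapes, and an empty input receives an empty slice. A rough NumPy sketch of that split (hypothetical sizes list; the real kernel works on framework::Tensor objects):

```python
import numpy as np

dout = np.random.rand(6, 3).astype('float32')  # upstream gradient of the concat output
sizes = [2, 0, 4]                              # per-input sizes along the concat axis

grads, offset = [], 0
for n in sizes:
    grads.append(dout[offset:offset + n])      # empty slice when n == 0, safe to skip
    offset += n

assert [g.shape[0] for g in grads] == sizes
```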

@@ -64,5 +64,16 @@ class TestConcatOp3(TestConcatOp):
         pass


+class TestConcatOp4(TestConcatOp):
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x2 = np.random.random((0, 3, 4, 5)).astype('float32')
+        self.axis = 0
+
+    def test_check_grad(self):
+        pass
+
+
 if __name__ == '__main__':
     unittest.main()
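
Assuming the usual OpTest flow in test_concat_op.py, the reference output for the new case is presumably built with np.concatenate, which already accepts the zero-sized x2, so the expected output shape is (4, 3, 4, 5). A hedged sketch of how such a case is typically assembled (the 'X'/'Out'/'axis' keys follow the concat op's conventions; the actual TestConcatOp base class may differ in detail):

```python
import numpy as np

x0 = np.random.random((2, 3, 4, 5)).astype('float32')
x1 = np.random.random((2, 3, 4, 5)).astype('float32')
x2 = np.random.random((0, 3, 4, 5)).astype('float32')  # the new empty input
axis = 0

inputs = {'X': [('x0', x0), ('x1', x1), ('x2', x2)]}
attrs = {'axis': axis}
outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)}

assert outputs['Out'].shape == (4, 3, 4, 5)
```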
