|
|
|
@ -39,14 +39,22 @@ class SplitOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
if (num > 0) {
|
|
|
|
|
int64_t in_axis_dim = in_dims[axis];
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
|
|
|
|
|
"tensor split does not result"
|
|
|
|
|
" in an equal division");
|
|
|
|
|
size_t out_axis_dim = in_axis_dim / num;
|
|
|
|
|
for (size_t i = 0; i < outs_number; ++i) {
|
|
|
|
|
auto dim = in_dims;
|
|
|
|
|
dim[axis] = out_axis_dim;
|
|
|
|
|
outs_dims.push_back(dim);
|
|
|
|
|
if (ctx->IsRuntime() || in_axis_dim > 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
|
|
|
|
|
"tensor split does not result"
|
|
|
|
|
" in an equal division");
|
|
|
|
|
size_t out_axis_dim = in_axis_dim / num;
|
|
|
|
|
for (size_t i = 0; i < outs_number; ++i) {
|
|
|
|
|
auto dim = in_dims;
|
|
|
|
|
dim[axis] = out_axis_dim;
|
|
|
|
|
outs_dims.push_back(dim);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t i = 0; i < outs_number; ++i) {
|
|
|
|
|
auto dim = in_dims;
|
|
|
|
|
dim[axis] = -1;
|
|
|
|
|
outs_dims.push_back(dim);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (sections.size() > 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(sections.size(), outs_number,
|
|
|
|
|