|
|
|
@ -31,14 +31,16 @@ class SplitByrefOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto in_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto outs_names = ctx->Outputs("Out");
|
|
|
|
|
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
|
|
|
|
|
std::vector<int> sections = static_cast<std::vector<int>>(
|
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("sections"));
|
|
|
|
|
auto sections = ctx->Attrs().Get<std::vector<int>>("sections");
|
|
|
|
|
const size_t outs_number = outs_names.size();
|
|
|
|
|
std::vector<framework::DDim> outs_dims;
|
|
|
|
|
outs_dims.reserve(outs_number);
|
|
|
|
|
|
|
|
|
|
if (num > 0) {
|
|
|
|
|
int64_t in_axis_dim = in_dims[0];
|
|
|
|
|
int64_t in_axis_dim = 0;
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
in_axis_dim = in_dims[0];
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
|
|
|
|
|
"tensor split does not result"
|
|
|
|
|
" in an equal division");
|
|
|
|
|