|
|
|
@ -29,6 +29,37 @@ void Split::SetSizeSplits(const std::vector<int> &size_splits) {
|
|
|
|
|
}
|
|
|
|
|
void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; }
|
|
|
|
|
|
|
|
|
|
int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
|
|
|
|
if (this->primitive_ == nullptr) {
|
|
|
|
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
|
|
|
|
if (this->primitive_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new primitiveT failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
this->primitive_->value.type = schema::PrimitiveType_Split;
|
|
|
|
|
}
|
|
|
|
|
if (this->primitive_->value.type != schema::PrimitiveType_Split) {
|
|
|
|
|
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (this->primitive_->value.value == nullptr) {
|
|
|
|
|
auto attr = new (std::nothrow) schema::SplitT();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
attr->splitDim = GetValue<int32_t>(prim.GetAttr("axis"));
|
|
|
|
|
attr->numberSplit = GetValue<int32_t>(prim.GetAttr("output_num"));
|
|
|
|
|
this->primitive_->value.value = attr;
|
|
|
|
|
if (this->primitive_->value.value == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive value is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
|
|
|
|
|
int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
|
|
|
|
@ -99,12 +130,14 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
|
|
|
|
|
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
|
|
|
|
|
int split_dim_i = input_shape[split_dim];
|
|
|
|
|
// support split size is -1 in the end.
|
|
|
|
|
if (i == number_split - 1 && size_split[i] == -1) {
|
|
|
|
|
if (size_split.empty()) {
|
|
|
|
|
split_dim_i = input_shape[split_dim] / number_split;
|
|
|
|
|
} else if (i == number_split - 1 && size_split[i] == -1) {
|
|
|
|
|
for (size_t j = 0; j < size_split.size() - 1; ++j) {
|
|
|
|
|
split_dim_i -= size_split[j];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i];
|
|
|
|
|
split_dim_i = size_split[i];
|
|
|
|
|
}
|
|
|
|
|
output_shape[split_dim] = split_dim_i;
|
|
|
|
|
outputs_[i]->set_shape(output_shape);
|
|
|
|
|