Add converter method for operator 'Split'.

pull/6343/head
wsc 4 years ago
parent 9da592a99f
commit 5953855106

@ -418,6 +418,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<StridedSlice>(prim, inputs, quantType);
} else if (op_type == "Cast") {
return NewPrimitiveC<Cast>(prim, inputs, quantType);
} else if (op_type == "Split") {
return NewPrimitiveC<Split>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {

@ -29,6 +29,37 @@ void Split::SetSizeSplits(const std::vector<int> &size_splits) {
}
void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; }
int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Split;
}
if (this->primitive_->value.type != schema::PrimitiveType_Split) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SplitT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->splitDim = GetValue<int32_t>(prim.GetAttr("axis"));
attr->numberSplit = GetValue<int32_t>(prim.GetAttr("output_num"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
@ -99,12 +130,14 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
int split_dim_i = input_shape[split_dim];
// support split size is -1 in the end.
if (i == number_split - 1 && size_split[i] == -1) {
if (size_split.empty()) {
split_dim_i = input_shape[split_dim] / number_split;
} else if (i == number_split - 1 && size_split[i] == -1) {
for (size_t j = 0; j < size_split.size() - 1; ++j) {
split_dim_i -= size_split[j];
}
} else {
split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i];
split_dim_i = size_split[i];
}
output_shape[split_dim] = split_dim_i;
outputs_[i]->set_shape(output_shape);

@ -35,6 +35,7 @@ class Split : public PrimitiveC {
void SetNumberSplit(int number_split);
void SetSizeSplits(const std::vector<int> &size_splits);
void SetSplitDim(int split_dim);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
#else
Split() = default;

@ -419,10 +419,13 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
}
#endif
} else if (value->isa<Number>()) {
MS_LOG(INFO) << "Value is a number.";
return RET_OK;
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;
}
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;
}
return RET_OK;
}

Loading…
Cancel
Save