|
|
|
@ -254,12 +254,9 @@ size_t GetInputsTypeLen(const AnfNodePtr &input) {
|
|
|
|
|
return input_type_len;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Given the node, return the element length of input and output
|
|
|
|
|
std::vector<std::vector<size_t>> ExtractInputAndOutputTypeLengthByNode(const CNodePtr &node) {
|
|
|
|
|
std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
std::vector<size_t> inputs_type_len;
|
|
|
|
|
std::vector<size_t> outputs_type_len;
|
|
|
|
|
std::vector<std::vector<size_t>> all_types;
|
|
|
|
|
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
|
|
|
|
|
|
|
|
|
// extract input element length
|
|
|
|
@ -277,9 +274,13 @@ std::vector<std::vector<size_t>> ExtractInputAndOutputTypeLengthByNode(const CNo
|
|
|
|
|
inputs_type_len.push_back(GetInputsTypeLen(input));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
all_types.push_back(inputs_type_len);
|
|
|
|
|
return inputs_type_len;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// extract output element length
|
|
|
|
|
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
std::vector<TypePtr> outputs_type;
|
|
|
|
|
// extract output element type
|
|
|
|
|
auto primary_output_type = node->Type();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primary_output_type);
|
|
|
|
|
if (primary_output_type->isa<mindspore::Tuple>()) {
|
|
|
|
@ -289,7 +290,7 @@ std::vector<std::vector<size_t>> ExtractInputAndOutputTypeLengthByNode(const CNo
|
|
|
|
|
for (auto &ele : elements) {
|
|
|
|
|
if (ele->isa<mindspore::TensorType>()) {
|
|
|
|
|
auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
|
|
|
|
|
outputs_type_len.push_back(GetLengthOfDataType(ele_element_type));
|
|
|
|
|
outputs_type.push_back(ele_element_type);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
|
|
|
|
|
}
|
|
|
|
@ -298,14 +299,12 @@ std::vector<std::vector<size_t>> ExtractInputAndOutputTypeLengthByNode(const CNo
|
|
|
|
|
// in this case, the output is a single tensor
|
|
|
|
|
if (primary_output_type->isa<mindspore::TensorType>()) {
|
|
|
|
|
auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
|
|
|
|
|
outputs_type_len.push_back(GetLengthOfDataType(element_type));
|
|
|
|
|
outputs_type.push_back(element_type);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
all_types.push_back(outputs_type_len);
|
|
|
|
|
|
|
|
|
|
return all_types;
|
|
|
|
|
return outputs_type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Be careful the argument is cnode_full_name, not the op_name
|
|
|
|
@ -366,11 +365,20 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// Set the data type for inputs and outputs of this OperatorInfo
|
|
|
|
|
std::vector<std::vector<size_t>> type_lengths = ExtractInputAndOutputTypeLengthByNode(cnode);
|
|
|
|
|
if (operator_info->SetInputAndOutputTypeLength(type_lengths[0], type_lengths[1]) != SUCCESS) {
|
|
|
|
|
auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
|
|
|
|
|
auto outputs_type = ExtractOutputTypeByNode(cnode);
|
|
|
|
|
std::vector<size_t> outputs_type_length;
|
|
|
|
|
outputs_type_length.reserve(outputs_type.size());
|
|
|
|
|
std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
|
|
|
|
|
GetLengthOfDataType);
|
|
|
|
|
if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// When the 'inputs' contains numerical values for some operators, these values should be extracted from
|
|
|
|
|
// ANF graph
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
|