|
|
|
@ -78,12 +78,13 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
|
|
|
|
auto input = GetRealInput(node_inputs[i]);
|
|
|
|
|
|
|
|
|
|
if (HasAbstractMonad(input)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (input->isa<Parameter>()) {
|
|
|
|
|
auto input_parameter = input->cast<ParameterPtr>();
|
|
|
|
|
is_parameter.push_back(ParameterRequireGrad(input_parameter));
|
|
|
|
|
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || IsValueNode<tensor::Tensor>(input) ||
|
|
|
|
|
IsValueNode<RefKey>(input)) {
|
|
|
|
|
} else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
|
|
|
|
|
is_parameter.push_back(false);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -174,6 +175,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
|
|
|
|
|
|
|
|
|
// extract input element length
|
|
|
|
|
for (auto &input : node_inputs) {
|
|
|
|
|
if (HasAbstractMonad(input)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsValueNode<RefKey>(input)) {
|
|
|
|
|
auto func_graph = node->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
@ -182,8 +186,7 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
|
|
|
|
|
}
|
|
|
|
|
inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
|
|
|
|
|
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || input->isa<Parameter>() ||
|
|
|
|
|
IsValueNode<tensor::Tensor>(input)) {
|
|
|
|
|
} else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
|
|
|
|
|
// extract input shape from parameter and apply node
|
|
|
|
|
inputs_type_len.push_back(GetInputsTypeLen(input));
|
|
|
|
|
}
|
|
|
|
|