|
|
|
@ -30,6 +30,7 @@
|
|
|
|
|
#include "pipeline/parse/parse_base.h"
|
|
|
|
|
#include "ir/value.h"
|
|
|
|
|
#include "ir/tensor.h"
|
|
|
|
|
#include "ir/param_value_py.h"
|
|
|
|
|
#include "utils/base_ref_extends.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
@ -426,7 +427,17 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
|
|
|
|
|
MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
|
|
|
|
|
<< " add Parameter count " << func_graph->hyper_param_count() << ".";
|
|
|
|
|
}
|
|
|
|
|
*ret_val = args[index];
|
|
|
|
|
if (index < args.size()) {
|
|
|
|
|
*ret_val = args[index];
|
|
|
|
|
} else {
|
|
|
|
|
auto param = dyn_cast<Parameter>(params[index]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(param);
|
|
|
|
|
if (!param->has_default()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
|
|
|
|
|
}
|
|
|
|
|
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
|
|
|
|
|
*ret_val = param_value->value().attr("data");
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|