|
|
|
@ -725,23 +725,15 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
|
|
|
|
|
return BaseRefToPyData(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data, size_t* count) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_data);
|
|
|
|
|
if (*count >= data.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
|
|
|
|
<< " less than the number of elements required. ";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!output->isa<AbstractTuple>()) {
|
|
|
|
|
ValuePtr value = output->BuildValue();
|
|
|
|
|
if (value != kAnyValue) {
|
|
|
|
|
return ValuePtrToPyData(value);
|
|
|
|
|
}
|
|
|
|
|
if (!output->isa<AbstractTensor>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output can only be tensor except for constants, but got "
|
|
|
|
|
<< output->BuildValue()->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
if (*count >= data.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
|
|
|
|
<< " less than the number of elements required. ";
|
|
|
|
|
}
|
|
|
|
|
auto shape = output->BuildShape();
|
|
|
|
|
if (cnode_data->isa<AbstractTensor>()) {
|
|
|
|
|
BaseShapePtr shape = cnode_data->BuildShape();
|
|
|
|
|
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
|
|
|
|
|
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
|
|
|
|
|
if (shape_act != tensor_exp.shape()) {
|
|
|
|
@ -751,16 +743,58 @@ py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data,
|
|
|
|
|
return data[(*count)++];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto tuple_output = output->cast<AbstractTuplePtr>();
|
|
|
|
|
AbstractBasePtrList elements = tuple_output->elements();
|
|
|
|
|
size_t size = elements.size();
|
|
|
|
|
if (!cnode_data->isa<AbstractTuple>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could "
|
|
|
|
|
<< "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString()
|
|
|
|
|
<< ".";
|
|
|
|
|
}
|
|
|
|
|
auto data_tp = cnode_data->cast<AbstractTuplePtr>();
|
|
|
|
|
auto elements = data_tp->elements();
|
|
|
|
|
size_t size = data_tp->size();
|
|
|
|
|
py::tuple tp = py::tuple(size);
|
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
|
tp[i] = StructureOutput(elements[i], data, count);
|
|
|
|
|
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
|
|
|
|
|
}
|
|
|
|
|
return std::move(tp);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_node);
|
|
|
|
|
|
|
|
|
|
if (output_node->isa<ValueNode>()) {
|
|
|
|
|
return ValuePtrToPyData(GetValueNode(output_node));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (*count >= data.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
|
|
|
|
|
<< " less than the number of elements required. ";
|
|
|
|
|
}
|
|
|
|
|
if (output_node->isa<Parameter>()) {
|
|
|
|
|
return data[(*count)++];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto output_c = output_node->cast<CNodePtr>();
|
|
|
|
|
if (output_c == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got "
|
|
|
|
|
<< output_node->ToString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (output_c->IsApply(prim::kPrimMakeTuple)) {
|
|
|
|
|
auto input_list = output_c->inputs();
|
|
|
|
|
size_t size = input_list.size();
|
|
|
|
|
py::tuple tp = py::tuple(size - 1);
|
|
|
|
|
for (size_t i = 1; i < size; i++) {
|
|
|
|
|
tp[i - 1] = StructureOutput(input_list[i], data, count);
|
|
|
|
|
}
|
|
|
|
|
return std::move(tp);
|
|
|
|
|
}
|
|
|
|
|
if (output_c->IsApply(prim::kPrimDepend)) {
|
|
|
|
|
return StructureOutput(output_c->input(1), data, count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ExtractGeneralCnodeRet(output_c->abstract(), data, count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::vector<MeTensorPtr>& inputs,
|
|
|
|
|
const std::string& phase) {
|
|
|
|
|
std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW);
|
|
|
|
@ -806,11 +840,10 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve
|
|
|
|
|
std::shared_ptr<py::object> ret = nullptr;
|
|
|
|
|
|
|
|
|
|
#ifdef ENABLE_GE
|
|
|
|
|
AnfNodePtr root = graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(root);
|
|
|
|
|
AbstractBasePtr output = root->abstract();
|
|
|
|
|
AnfNodePtr output_node = graph->get_return()->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_node);
|
|
|
|
|
size_t count = 0;
|
|
|
|
|
py::object oj = StructureOutput(output, outputs, &count);
|
|
|
|
|
py::object oj = StructureOutput(output_node, outputs, &count);
|
|
|
|
|
ret = std::make_shared<py::object>(oj);
|
|
|
|
|
#else
|
|
|
|
|
if (outputs.size() == 1) {
|
|
|
|
|