|
|
|
@ -657,14 +657,17 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i
|
|
|
|
|
if (cnode != nullptr) {
|
|
|
|
|
cnode->set_abstract(op_exec_info->abstract);
|
|
|
|
|
}
|
|
|
|
|
std::string obj_id = GetId(out_real);
|
|
|
|
|
node_abs_map_[obj_id] = op_exec_info->abstract;
|
|
|
|
|
|
|
|
|
|
// Save info for building grad graph
|
|
|
|
|
if (grad()->grad_flag() && grad()->in_grad_process()) {
|
|
|
|
|
std::string obj_id = GetId(out_real);
|
|
|
|
|
node_abs_map_[obj_id] = op_exec_info->abstract;
|
|
|
|
|
grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
|
|
|
|
|
grad()->SaveAllResult(op_exec_info, cnode, out_real);
|
|
|
|
|
// Update the abstract and device address of value node with tensor in grad graph
|
|
|
|
|
UpdateAbstractAndDeviceAddress(op_exec_info, out_real);
|
|
|
|
|
} else {
|
|
|
|
|
node_abs_map_.clear();
|
|
|
|
|
}
|
|
|
|
|
*ret = out_real;
|
|
|
|
|
}
|
|
|
|
@ -810,8 +813,8 @@ abstract::AbstractBasePtr ForwardExecutor::CheckConstValue(const PrimitivePyPtr
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_abs);
|
|
|
|
|
new_abs = new_abs->Broaden(config);
|
|
|
|
|
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
|
|
|
|
|
node_abs_map_[id] = new_abs;
|
|
|
|
|
}
|
|
|
|
|
node_abs_map_[id] = new_abs;
|
|
|
|
|
}
|
|
|
|
|
return new_abs;
|
|
|
|
|
}
|
|
|
|
|