Modify code to support dynamic graph.

pull/1421/head
rick_sanchez 6 years ago committed by kpy
parent 72fd41786c
commit e2a322b6b7

@ -19,14 +19,15 @@ Interfaces for parser module in c++.
from .parser import (Parser, create_obj_instance, generate_scope, from .parser import (Parser, create_obj_instance, generate_scope,
get_bprop_method_of_class, get_class_instance_type, get_bprop_method_of_class, get_class_instance_type,
get_class_member_namespace_symbol, create_slice_obj, get_class_member_namespace_symbol, create_slice_obj,
get_dataclass_attributes, get_dataclass_methods, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name, get_default_input, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj) is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
from .serialize import * from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type', 'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member',
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj'] 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
'create_slice_obj', 'create_ellipsis_obj']

@ -209,6 +209,14 @@ def get_object_key(obj):
obj_id = instance_id + obj_id obj_id = instance_id + obj_id
return obj_id, obj_key return obj_id, obj_key
def get_default_input(obj):
if hasattr(obj, '__parameter__'):
return obj.default_input
if isinstance(obj, tuple):
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
args = tuple(convert(x) for x in obj)
return args
return obj
def is_class_member(node): def is_class_member(node):
"""Check the attr is class member variable.""" """Check the attr is class member variable."""
@ -221,6 +229,9 @@ def is_class_member(node):
return True return True
return False return False
def get_obj_id(obj):
"""Get the obj id."""
return str(id(obj))
def get_obj_type(obj): def get_obj_type(obj):
"""Get the obj type.""" """Get the obj type."""

@ -328,9 +328,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
DropEdge(node, index, inp); DropEdge(node, index, inp);
} else { } else {
MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
if (inp->func_graph() != nullptr) {
AddFuncGraph(inp->func_graph());
}
if (IsValueNode<FuncGraph>(inp)) { if (IsValueNode<FuncGraph>(inp)) {
MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
AddFuncGraph(GetValueNode<FuncGraphPtr>(inp)); AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
@ -372,9 +369,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
for (auto &node : acq) { for (auto &node : acq) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg = node->func_graph(); auto fg = node->func_graph();
if (fg != nullptr) { if (fg != nullptr) {
AddFuncGraph(fg);
fg->AddNode(node); fg->AddNode(node);
} }
ProcessInputs(node, kIncEdge); ProcessInputs(node, kIncEdge);

@ -28,7 +28,7 @@ namespace py = pybind11;
class ParamValuePy : public ParamValue { class ParamValuePy : public ParamValue {
public: public:
ParamValuePy() : value_(py::none()) {} ParamValuePy() : value_(py::none()) {}
explicit ParamValuePy(py::object value) : value_(value) {} explicit ParamValuePy(const py::object &value) : value_(value) {}
~ParamValuePy() override = default; ~ParamValuePy() override = default;
py::object value() { return value_; } py::object value() { return value_; }

@ -75,7 +75,7 @@ py::function PrimitivePy::GetComputeFunction() {
py::function vm_fn = get_fn(python_obj_); py::function vm_fn = get_fn(python_obj_);
if (py::isinstance<py::none>(vm_fn)) { if (py::isinstance<py::none>(vm_fn)) {
MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>(); MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
vm_fn = mindspore::GetComputeFunction(Primitive::name()); vm_fn = mindspore::GetComputeFunction(Primitive::name());
} }
return vm_fn; return vm_fn;

@ -81,6 +81,7 @@ Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
: MetaTensor(tensor), device_address_(tensor.device_address_) { : MetaTensor(tensor), device_address_(tensor.device_address_) {
init(tensor.data_, data_type); init(tensor.data_, data_type);
dirty_ = tensor.is_dirty(); dirty_ = tensor.is_dirty();
id_ = tensor.id();
} }
Tensor &Tensor::operator=(const Tensor &tensor) { Tensor &Tensor::operator=(const Tensor &tensor) {
@ -89,6 +90,7 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
dirty_ = tensor.is_dirty(); dirty_ = tensor.is_dirty();
device_address_ = tensor.device_address(); device_address_ = tensor.device_address();
data_ = tensor.data_; data_ = tensor.data_;
id_ = tensor.id();
} }
return *this; return *this;
} }
@ -208,6 +210,7 @@ void Tensor::init(const py::array &input, const TypeId &data_type) {
data_ = input; data_ = input;
} }
dirty_ = true; dirty_ = true;
id_ = std::to_string((uintptr_t)(this));
} }
void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) { void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
@ -254,6 +257,7 @@ void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *co
MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
break; break;
} }
id_ = std::to_string((uintptr_t)(this));
} }
TypePtr Tensor::SetDtype(const TypePtr type_ptr) { TypePtr Tensor::SetDtype(const TypePtr type_ptr) {

@ -263,9 +263,11 @@ class Tensor : public MetaTensor {
DeviceAddressPtr device_address() const { return device_address_; } DeviceAddressPtr device_address() const { return device_address_; }
void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
py::array data_sync(); py::array data_sync();
std::string id() const { return id_; }
private: private:
bool dirty_{true}; bool dirty_{true};
std::string id_{""};
DeviceAddressPtr device_address_{nullptr}; DeviceAddressPtr device_address_{nullptr};
}; };

@ -501,10 +501,16 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
} }
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &params_list, bool applyJ) { const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
bool applyJ) {
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
auto weights_node = weights;
if (weights == nullptr && !args.empty()) {
weights_node = ret->NewCNode(args);
}
ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
@ -537,7 +543,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
inputs.push_back(NewValueNode(1)); inputs.push_back(NewValueNode(1));
AnfNodePtr ptrBprop = ret->NewCNode(inputs); AnfNodePtr ptrBprop = ret->NewCNode(inputs);
doGetGrad(ret, out, ptrBprop, weights, opsTupleItem); doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
return ret; return ret;
} }

@ -129,7 +129,7 @@ class GradOperation : public MetaFuncGraph {
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams, FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
bool applyJ = false); const std::vector<AnfNodePtr> &args = {}, bool applyJ = false);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
bool sens_param() const { return sens_param_; } bool sens_param() const { return sens_param_; }
bool get_all_; bool get_all_;

@ -285,6 +285,10 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
// and add cast op on other inputs to keep the same type with assigned parameter. // and add cast op on other inputs to keep the same type with assigned parameter.
for (size_t i = 0; i < args_spec_list.size(); ++i) { for (size_t i = 0; i < args_spec_list.size(); ++i) {
AnfNodePtr param = params_list[i]; AnfNodePtr param = params_list[i];
if (args_spec_list[i] == nullptr) {
op_inputs.push_back(param);
continue;
}
SignatureEnumRW sig = SignatureEnumRW::kRWDefault; SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
// If sig_size is 0 use defalut. // If sig_size is 0 use defalut.
if (sig_size > 0 && i < sig_size) { if (sig_size > 0 && i < sig_size) {
@ -292,6 +296,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
} else if (has_var && i >= sig_size) { } else if (has_var && i >= sig_size) {
sig = signature[sig_size - 1].rw; sig = signature[sig_size - 1].rw;
} }
TypePtr type = args_spec_list[i]->GetTypeTrack(); TypePtr type = args_spec_list[i]->GetTypeTrack();
if (type && type->type_id() == kObjectTypeRef) { if (type && type->type_id() == kObjectTypeRef) {
if (sig == SignatureEnumRW::kRWRead) { if (sig == SignatureEnumRW::kRWRead) {

@ -551,6 +551,10 @@ AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
} }
void DFunctor::CallDoutHoleOnTape() { void DFunctor::CallDoutHoleOnTape() {
if (!is_top_) {
return;
}
// Call dout hole of all adjoint. // Call dout hole of all adjoint.
for (auto &f : func_graph_to_functor_) { for (auto &f : func_graph_to_functor_) {
for (auto &adjoint : f.second->anfnode_to_adjoin_) { for (auto &adjoint : f.second->anfnode_to_adjoin_) {

@ -55,6 +55,8 @@ class DFunctor {
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
// Register functor objects to form a global view. // Register functor objects to form a global view.
void Init(const DFunctorPtr &functor, bool is_top = false); void Init(const DFunctorPtr &functor, bool is_top = false);
bool IsInScope(const AnfNodePtr &node);
// Clear resources. // Clear resources.
static void Clear(); static void Clear();
@ -62,7 +64,6 @@ class DFunctor {
// Map one morphism. // Map one morphism.
AdjointPtr MapMorphism(const AnfNodePtr &morph); AdjointPtr MapMorphism(const AnfNodePtr &morph);
bool IsFreeMorphism(const AnfNodePtr &node); bool IsFreeMorphism(const AnfNodePtr &node);
bool IsInScope(const AnfNodePtr &node);
// Map morphism that's not attached to output. // Map morphism that's not attached to output.
void MapFreeMorphism(); void MapFreeMorphism();
void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din); void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);

@ -23,7 +23,7 @@
namespace mindspore { namespace mindspore {
namespace ad { namespace ad {
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) { FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto gradkv = func_graph->transforms().find("grad"); auto gradkv = func_graph->transforms().find("grad");
if (gradkv != func_graph->transforms().end()) { if (gradkv != func_graph->transforms().end()) {
@ -46,14 +46,18 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
auto user_defined = f->KUserDefined(func_graph); auto user_defined = f->KUserDefined(func_graph);
if (user_defined != nullptr) { if (user_defined != nullptr) {
multi_graph_sink(user_defined); multi_graph_sink(user_defined);
DFunctor::Clear(); if (is_top) {
DFunctor::Clear();
}
return user_defined; return user_defined;
} }
f->Init(f, true); f->Init(f, is_top);
f->MapObject(); f->MapObject();
f->MapMorphism(); f->MapMorphism();
auto ret = f->k_graph(); auto ret = f->k_graph();
DFunctor::Clear(); if (is_top) {
DFunctor::Clear();
}
multi_graph_sink(ret); multi_graph_sink(ret);
return ret; return ret;
@ -71,5 +75,7 @@ MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr
MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim); MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim);
return fg; return fg;
} }
void CleanRes() { DFunctor::Clear(); }
} // namespace ad } // namespace ad
} // namespace mindspore } // namespace mindspore

@ -28,9 +28,10 @@ namespace mindspore {
namespace ad { namespace ad {
using ResourcePtr = std::shared_ptr<pipeline::Resource>; using ResourcePtr = std::shared_ptr<pipeline::Resource>;
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources); FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true);
FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &); MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &);
void CleanRes();
} // namespace ad } // namespace ad
} // namespace mindspore } // namespace mindspore

@ -167,7 +167,8 @@ class InlinerBase : public AnfVisitor {
auto params = fg->parameters(); auto params = fg->parameters();
auto old_size = params.size(); auto old_size = params.size();
if (old_size != new_params.size()) { if (old_size != new_params.size()) {
MS_LOG(EXCEPTION) << "Parameter size not match."; MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
<< fg->output()->DebugString(10);
} }
for (size_t i = 0; i < old_size; i++) { for (size_t i = 0; i < old_size; i++) {
(void)mng->Replace(params[i], new_params[i]); (void)mng->Replace(params[i], new_params[i]);

@ -276,6 +276,8 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa
bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
static bool IsCtrlSink() { static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance(); auto ms_ctx = MsContext::GetInstance();
std::string device_target = ms_ctx->device_target(); std::string device_target = ms_ctx->device_target();

@ -35,6 +35,7 @@ bool SymbolResolveAction(const ResourcePtr &res);
bool AbstractSpecializeAction(const ResourcePtr &res); bool AbstractSpecializeAction(const ResourcePtr &res);
bool GeOptimizeAction(const ResourcePtr &res); bool GeOptimizeAction(const ResourcePtr &res);
bool VmOptimizeAction(const ResourcePtr &res); bool VmOptimizeAction(const ResourcePtr &res);
bool PynativeOptimizeAction(const ResourcePtr &res);
bool TaskEmitAction(const ResourcePtr &res); bool TaskEmitAction(const ResourcePtr &res);
bool ExecuteAction(const ResourcePtr &res); bool ExecuteAction(const ResourcePtr &res);

@ -32,6 +32,7 @@
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "debug/trace.h" #include "debug/trace.h"
#include "optimizer/ad/grad.h"
namespace mindspore { namespace mindspore {
namespace parse { namespace parse {
@ -338,6 +339,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>(); std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env; converted = env;
} else if (py::hasattr(obj, "__parameter__")) {
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
ret = ConvertData(to_convert, &converted);
} else { } else {
ret = ConvertOtherObj(obj, &converted); ret = ConvertOtherObj(obj, &converted);
} }

@ -60,6 +60,7 @@ const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key"; const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key";
const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member"; const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member";
const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type"; const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type";
const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id";
const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type"; const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance"; const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance";
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
@ -83,6 +84,7 @@ const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";
// define the common name // define the common name
const char NAMED_PRIMITIVE_ITER[] = "iter"; const char NAMED_PRIMITIVE_ITER[] = "iter";

@ -278,5 +278,7 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_control", ControlGroup}, {"opt_control", ControlGroup},
{"opt_prepare", PrepareGroup}, {"opt_prepare", PrepareGroup},
{"cconv", CconvPass}}; {"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
extern std::vector<PassItem> kGePasses; extern std::vector<PassItem> kGePasses;
extern std::vector<PassItem> kVmPasses; extern std::vector<PassItem> kVmPasses;
extern std::vector<PassItem> kPynativePasses;
bool CconvPass(const ResourcePtr &res); bool CconvPass(const ResourcePtr &res);
bool ValidatePass(const ResourcePtr &res); bool ValidatePass(const ResourcePtr &res);

@ -608,7 +608,7 @@ void Pipeline::Run() {
MS_LOG(INFO) << "End"; MS_LOG(INFO) << "End";
} }
void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) { void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) {
std::size_t size = args.size(); std::size_t size = args.size();
for (std::size_t i = 0; i < size; i++) { for (std::size_t i = 0; i < size; i++) {
@ -625,7 +625,6 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
arg_list->push_back(converted); arg_list->push_back(converted);
} }
ResourcePtr res = GetResource(phase);
MS_EXCEPTION_IF_NULL(res); MS_EXCEPTION_IF_NULL(res);
auto graph = res->func_graph(); auto graph = res->func_graph();
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
@ -647,6 +646,10 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
} }
} }
void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
ProcessVmArgInner(args, GetResource(phase), arg_list);
}
py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
std::size_t size = args.size(); std::size_t size = args.size();
if (!py::isinstance<py::str>(phase)) { if (!py::isinstance<py::str>(phase)) {
@ -874,6 +877,8 @@ void ClearResAtexit() {
compile::ClearConvertCache(); compile::ClearConvertCache();
pipeline::GetMethodMap().clear(); pipeline::GetMethodMap().clear();
pipeline::ExecutorPy::ClearRes(); pipeline::ExecutorPy::ClearRes();
pipeline::ReclaimOptimizer();
pynative::PynativeExecutor::GetInstance()->Clean();
#ifdef ENABLE_GE #ifdef ENABLE_GE
transform::DfGraphManager::GetInstance().ClearGraph(); transform::DfGraphManager::GetInstance().ClearGraph();
transform::DfGraphConvertor::get_adpt_map().clear(); transform::DfGraphConvertor::get_adpt_map().clear();

@ -139,6 +139,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, bool need_run); const std::vector<int64_t> &input_indexes, bool need_run);
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list);
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -22,23 +22,93 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <mutex>
#include <stack>
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pynative/base.h" #include "pynative/base.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "ir/anf.h"
#include "pipeline/resource.h"
#include "operator/composite/composite.h"
namespace mindspore { namespace mindspore {
namespace pynative { namespace pynative {
namespace py = pybind11; namespace py = pybind11;
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::tuple RunOp(const py::args &args); py::tuple RunOp(const py::args &args);
py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args);
void ClearPyNativeSession(); void ClearPyNativeSession();
struct GraphInfo {
std::unordered_map<std::string, AnfNodePtr> param_map;
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> obj_node_map;
AnfNodePtr output;
std::vector<std::string> objects;
};
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
public:
static std::shared_ptr<PynativeExecutor> GetInstance() {
std::lock_guard<std::mutex> i_lock(instance_lock_);
if (executor_ == nullptr) {
executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
resource_ = std::make_shared<pipeline::Resource>();
}
return executor_;
}
void NewGraph(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
void Clear();
void Clean();
bool grad_flag() { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
AnfNodePtr GetObjNode(const py::object &obj);
FuncGraphPtr curr_g() { return curr_g_; }
void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, -1);
}
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
}
AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
py::object Run(const py::tuple &args, const py::object &phase);
void Pushp();
void Popp();
FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
size_t arg_size);
~PynativeExecutor();
private:
PynativeExecutor();
static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_;
static ResourcePtr resource_;
bool grad_flag_;
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::stack<FuncGraphPtr> graph_p_;
FuncGraphPtr top_g_;
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
} // namespace pynative } // namespace pynative
} // namespace mindspore } // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save