make python Parameter inherit from Tensor

pull/3473/head
Wei Luning 5 years ago
parent 2b56562770
commit a05c38bb63
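
The thrust of the commit, visible across the hunks below: the Python Parameter class becomes a Tensor subclass, so a C++ Parameter node's default_param() now stores the tensor value itself (a ValuePtr / tensor::TensorPtr) instead of a ParamValue wrapper, while the remaining bookkeeping (name, requires_grad, layerwise_parallel, clone info) stays on a ParamValue reachable through the Python object's _value attribute. A minimal Python sketch of that shape — not the actual mindspore.Parameter implementation:

from mindspore import Tensor                      # illustrative import only
from mindspore._c_expression import ParamValue    # pybind class shown in this diff; import path assumed

class Parameter(Tensor):                          # the point of the commit: a Parameter is a Tensor
    def __init__(self, default_input, name, requires_grad=True):
        super().__init__(default_input)           # the tensor data lives on the object itself
        self._value = ParamValue()                # metadata only: name, requires_grad, clone info, ...
        self._value.name = name
        self._value.requires_grad = requires_grad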

@ -21,12 +21,12 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_class_member_namespace_symbol, create_slice_obj,
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key,
get_default_input, get_parse_method_of_class, get_scope_name,
get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol)
from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member',
'get_object_key', 'get_class_instance_type', 'is_class_member',
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',

@ -206,16 +206,6 @@ def get_object_key(obj):
return obj_id, obj_key
def get_default_input(obj):
if hasattr(obj, '__parameter__'):
return obj.default_input
if isinstance(obj, tuple):
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
args = tuple(convert(x) for x in obj)
return args
return obj
def is_class_member(node):
"""Check the attr is class member variable."""
type_ = node.__class__.__name__
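
With Parameter itself a Tensor, the get_default_input helper (and its __all__ export above) is no longer needed to unwrap a Parameter into its backing data before handing it to the runtime; the pynative and ms_function hunks later in this diff drop the same unwrapping. A tiny illustration, with run_op as a hypothetical call site:

# before this commit: unwrap the Parameter to its default_input tensor first
run_op(get_default_input(param))
# after: a Parameter already is a Tensor, so it is passed through unchanged
run_op(param)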

@ -76,7 +76,7 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
if (AnfAlgo::IsParameterWeight(pk_node)) {
const auto &param_value = pk_node->default_param();
MS_EXCEPTION_IF_NULL(param_value);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value->value());
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
MS_EXCEPTION_IF_NULL(tensor);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),

@ -42,12 +42,12 @@
namespace mindspore {
namespace session {
static std::shared_ptr<std::map<ParamValuePtr, ParameterPtr>> python_paras;
static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
void ClearPythonParasMap() { python_paras = nullptr; }
namespace {
const int kSummaryGetItem = 2;
ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
@ -209,8 +209,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
auto param = graph->NewParameter();
MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) {
auto param_value_new = std::make_shared<ParamValue>();
param->set_default_param(param_value_new);
param->set_default_param(input_tensor);
}
// set the kernel info of parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
@ -390,7 +389,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
ParameterPtr new_parameter = nullptr;
// if the parameter's python object already has a backend parameter, reuse it
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
}
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {
@ -667,7 +666,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
auto param_value = GetParamDefaultValue(anf);
ParameterPtr new_parameter = nullptr;
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
}
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {

@ -1670,7 +1670,7 @@ class IrParser {
// load parameter default value from serialized file
py::object default_obj = LoadObject(lexer_.GetTokenText());
auto param_value_new = py::cast<ParamValuePtr>(default_obj);
auto param_value_new = py::cast<tensor::TensorPtr>(default_obj);
param->set_default_param(param_value_new);
tok = lexer_.GetNextToken();

@ -318,8 +318,9 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << parameter->ToString();
auto param = parameter->cast<ParameterPtr>();
if (param->has_default()) {
auto tensor = param->default_param()->value();
if (tensor) {
auto tensor_v = param->default_param();
if (tensor_v && tensor_v->isa<tensor::Tensor>()) {
auto tensor = tensor_v->cast<tensor::TensorPtr>();
auto &shape = tensor->shape();
std::ostringstream shape_str;
std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(shape_str, ","));

@ -38,7 +38,12 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
if (!para_ptr->has_default()) {
return false;
}
return para_ptr->default_param()->requires_grad();
auto obj = py::cast(para_ptr->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
if (param_value == nullptr) {
return false;
}
return param_value->requires_grad();
}
} // namespace parallel
} // namespace mindspore
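
The parallel-module helpers no longer read requires_grad or clone flags straight off default_param(); as ParameterRequireGrad above shows, they cast the node's default value back to a Python object and fetch the ParamValue through its _value attribute. A rough Python-side view of that contract, assuming the usual mindspore.Parameter constructor:

import numpy as np
from mindspore import Tensor, Parameter    # standard public API assumed

w = Parameter(Tensor(np.ones((2, 2), dtype=np.float32)), name="w", requires_grad=True)
# w behaves as a Tensor and is what set_default_param(...) receives on the C++ side,
# while the flags the parallel code needs still live on the ParamValue:
meta = w._value                            # what py::cast<ParamValuePtr>(obj.attr("_value")) picks up
print(meta.name, meta.requires_grad)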

@ -41,6 +41,7 @@
#include "frontend/parallel/context.h"
#include "frontend/parallel/ops_info/tmp_identity_info.h"
#include "frontend/parallel/ops_info/reshape_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "pipeline/jit/parse/python_adapter.h"
@ -122,12 +123,7 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
bool requires_grad = input_parameter->default_param()->requires_grad();
is_parameter.push_back(requires_grad);
} else {
is_parameter.push_back(false);
}
is_parameter.push_back(ParameterRequireGrad(input_parameter));
} else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
is_parameter.push_back(false);
}
@ -798,12 +794,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
std::vector<bool> is_parameter;
auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(casted_target_parameter);
if (casted_target_parameter->has_default()) {
bool requires_grad = casted_target_parameter->default_param()->requires_grad();
is_parameter.push_back(requires_grad);
} else {
is_parameter.push_back(false);
}
is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
}

@ -1295,11 +1295,8 @@ void CoverSliceShape(const FuncGraphPtr &root) {
g_RefMap.clear();
}
bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_node) {
MS_EXCEPTION_IF_NULL(root);
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
MS_EXCEPTION_IF_NULL(parameter_node);
FuncGraphManagerPtr manager = root->manager();
MS_EXCEPTION_IF_NULL(manager);
auto cloned_parameter = parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter);
@ -1307,8 +1304,12 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_nod
if (!cloned_parameter->has_default()) {
return false;
}
bool cloned = cloned_parameter->default_param()->cloned();
auto obj = py::cast(cloned_parameter->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
if (param_value == nullptr) {
return false;
}
bool cloned = param_value->cloned();
if (!cloned) {
return false;
}
@ -1324,12 +1325,16 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter);
if (!ParameterIsCloned(root, cloned_parameter_node)) {
if (!ParameterIsCloned(cloned_parameter_node)) {
continue;
}
auto obj = py::cast(cloned_parameter->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
if (param_value == nullptr) {
continue;
}
// get the cloned index
int32_t cloned_index = cloned_parameter->default_param()->cloned_index();
int32_t cloned_index = param_value->cloned_index();
// find the be cloned parameter
bool found_be_cloned_parameter = false;
@ -1344,12 +1349,18 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
}
const auto &param_value_cloned = be_cloned_parameter->default_param();
if (!param_value_cloned->be_cloned()) {
auto obj_in = py::cast(param_value_cloned);
auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value"));
if (param_value_in == nullptr) {
continue;
}
if (!param_value_in->be_cloned()) {
continue;
}
// get the be cloned index
auto &be_cloned_index = param_value_cloned->be_cloned_index();
auto &be_cloned_index = param_value_in->be_cloned_index();
if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
found_be_cloned_parameter = true;
cloned_from_parameter = be_cloned_parameter;
@ -2103,10 +2114,7 @@ std::string NodeParameterName(const CNodePtr &node) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
const auto &param_value = input_parameter->default_param();
if (param_value->requires_grad()) {
return param_value->name();
}
input_parameter->name();
}
}
}

@ -233,8 +233,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
for (const auto &param : func_graph->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
const auto &param_value = param_node->default_param();
ValuePtr value = param_value->value();
ValuePtr value = param_node->default_param();
constexpr bool broaden = true;
AbstractBasePtr ptr = abstract::FromValue(value, broaden);

@ -68,6 +68,8 @@ PYBIND11_MODULE(_c_expression, m) {
py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
.def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
.def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
.def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
"Get Parameter Tensor Layout Dictionary.")
.def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),

@ -205,41 +205,6 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
return true;
}
bool ConvertDataType(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting type object";
auto typeptr = obj.cast<TypePtr>();
if (typeptr == nullptr) {
MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null";
return false;
}
*data = typeptr;
return true;
}
bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting MetaTensor object.";
auto m_tensor = obj.cast<MetaTensorPtr>();
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null.";
return false;
}
*data = m_tensor;
return true;
}
bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting tensor object";
auto m_tensor = obj.cast<TensorPtr>();
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null";
return false;
}
*data = m_tensor;
return true;
}
bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting slice object";
@ -364,11 +329,11 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::isinstance<MetaFuncGraph>(obj)) {
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
} else if (py::isinstance<Type>(obj)) {
ret = ConvertDataType(obj, &converted);
converted = obj.cast<TypePtr>();
} else if (py::isinstance<Tensor>(obj)) {
ret = ConvertTensor(obj, &converted);
converted = obj.cast<TensorPtr>();
} else if (py::isinstance<MetaTensor>(obj)) {
ret = ConvertMetaTensor(obj, &converted);
converted = obj.cast<MetaTensorPtr>();
} else if (py::isinstance<EnvInstance>(obj)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env;

@ -85,7 +85,6 @@ const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";
// define the common name
const char NAMED_PRIMITIVE_LEN[] = "len";

@ -103,10 +103,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
}
if (para_node == nullptr) {
auto node = top_graph->AddWeightParameter(param_name);
auto param_value = py::cast<ParamValuePtr>(python_adapter::GetPyObjAttr(obj, "_value"));
node->set_default_param(param_value);
auto value = py::cast<tensor::MetaTensorPtr>(obj);
node->set_default_param(value);
// set_abstract for parameter
ValuePtr value = param_value->value();
constexpr bool broaden = true;
node->set_abstract(abstract::FromValue(value, broaden));
para_node = node;

@ -719,7 +719,11 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
if (!param_ptr->has_default()) {
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
}
arg_list->push_back(param_ptr->default_param()->value());
if (!param_ptr->default_param()->isa<Tensor>()) {
MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString()
<< "] is not initialized, need to call `.init_data()`";
}
arg_list->push_back(param_ptr->default_param());
}
}
}
@ -782,6 +786,24 @@ FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::stri
#endif
}
void ExecutorPy::UpdataParamNodeDefaultInput(const std::string &phase,
const std::unordered_map<std::string, tensor::TensorPtr> &params_value) {
FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase
<< ")!";
auto &params = func_graph->parameters();
for (const auto &param : params) {
MS_EXCEPTION_IF_NULL(param);
auto param_cast = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_cast);
auto iter = params_value.find(param_cast->name());
if (iter != params_value.end()) {
param_cast->set_default_param(iter->second);
}
}
}
void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) {
#if ENABLE_GE
RunGEInitGraph(init_params, phase);
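
UpdataParamNodeDefaultInput is the new C++ entry point that lets Python push freshly initialized tensors back into the compiled graph's weight parameters by name; _updata_param_node_default_input further down in this diff drives it. A hedged sketch of that call, with net, executor and phase assumed from the surrounding compile flow:

# old Parameter object -> its (possibly re-created) initialized replacement
replace = net.init_parameters_data(auto_parallel_mode=False)
# forward only the entries whose identity actually changed, keyed by parameter name
new_param = {p.name: replace[p] for p in replace if id(p) != id(replace[p])}
executor.updata_param_node_default_input(phase, new_param)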

@ -88,6 +88,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase,
const py::object &broadcast_params = {});
void UpdataParamNodeDefaultInput(const std::string &phase,
const std::unordered_map<std::string, tensor::TensorPtr> &params);
void RunInitGraph(const py::dict &init_params, const std::string &phase);
py::dict GetParameterLayout(const std::string &phase);
py::dict GetCNodeStrategy(const std::string &phase);

@ -146,12 +146,6 @@ static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
return id;
}
py::object GetTupleObj(const py::object &obj) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
return obj_tuple;
}
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
for (size_t i = 0; i < dtypes.size(); ++i) {
@ -242,7 +236,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
py::tuple input_mask(args.size());
for (size_t i = 0; i < args.size(); ++i) {
input_mask[i] = py::hasattr(args[i], "__parameter__");
py_args[i] = GetTupleObj(args[i]);
py_args[i] = args[i];
}
auto signature = prim->signatures();
std::vector<SignatureEnumDType> dtypes;
@ -366,9 +360,6 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i];
if (py::hasattr(input, "__parameter__")) {
input = py::getattr(input, "data");
}
auto tensor = py::cast<tensor::TensorPtr>(input);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
new_tensor->set_device_address(tensor->device_address());
@ -878,8 +869,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
auto free_param = df_builder_->add_parameter();
free_param->set_name(param_name);
auto free_param_new = py::cast<ParamValuePtr>(obj.attr("_value"));
free_param->set_default_param(free_param_new);
free_param->set_default_param(py::cast<tensor::TensorPtr>(obj));
free_param->debug_info()->set_name(param_name);
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
graph_info_map_[df_builder_].param_map[obj_id] = free_param;
@ -1074,8 +1064,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
for (const auto &param : df_builder_->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
const auto &param_value = param_node->default_param();
ValuePtr value = param_value->value();
ValuePtr value = param_node->default_param();
AbstractBasePtr ptr = abstract::FromValue(value, true);
if (ptr == nullptr) {
MS_LOG(EXCEPTION) << "Args convert error";

@ -187,7 +187,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_name);
SetParamToTensorProto(param, initializer_proto);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param()->value());
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
if (tensor) {
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
}

@ -449,7 +449,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP
initializer_proto->set_name(param_ptr->ToString());
SetTensorProtoInfo(param_ptr, initializer_proto);
// set value for initializer
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param()->value());
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param());
if (tensor) {
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
}

@ -52,7 +52,7 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string &param_name,
if (param_node->name() == param_name) {
TensorPtr tensor;
if (param_node->has_default()) {
tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param()->value());
tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
}
if (tensor == nullptr) {
shape->push_back(ONE_SHAPE);

@ -448,7 +448,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
if (!param->has_default()) {
MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
}
auto tensor = param->default_param()->value();
auto tensor = param->default_param();
*ret_val = py::cast(tensor);
}
return true;

@ -124,10 +124,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
auto param_value = std::make_shared<ParamValue>();
MS_EXCEPTION_IF_NULL(param_value);
param_value->set_value(tensor_info);
node->set_default_param(param_value);
node->set_default_param(tensor_info);
}
anfnode_build_map_[value_proto.name()] = node;
return true;

@ -24,22 +24,19 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
(void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue")
.def(py::init())
.def("clone", &ParamValue::Clone)
.def_property("data", &ParamValue::value, &ParamValue::set_value)
.def_property("name", &ParamValue::name, &ParamValue::set_name)
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
&ParamValue::set_layerwise_parallel)
.def(py::pickle(
[](const ParamValue &p) { // __getstate__
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
p.layerwise_parallel());
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 6) {
std::runtime_error("Invalid state for ParamValue!");
}
ParamValuePtr p = std::make_shared<ParamValue>();
p->set_value(t[0].cast<tensor::TensorPtr>());
p->set_name(t[1].cast<std::string>());
p->set_requires_grad(t[2].cast<bool>());
p->set_layerwise_parallel(t[3].cast<bool>());

@ -372,7 +372,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
.def(py::pickle(
[](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(TensorPy::AsNumpy(t));
return py::make_tuple(TensorPy::SyncAsNumpy(t));
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {

@ -255,7 +255,6 @@ def ms_function(fn=None, obj=None, input_signature=None):
process_obj = obj
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]
args = (x.default_input if hasattr(x, 'default_input') else x for x in args)
return _MindSporeFunction(func, input_signature, process_obj)(*args)
return staging_specialize
@ -354,28 +353,8 @@ class _Executor:
raise RuntimeError("Failure to init and dataset subgraph!")
return True
def _build_data_graph(self, obj, params, phase):
if params is None:
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
elif isinstance(params, OrderedDict):
self._executor.build_data_graph(params, phase)
else:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
def _params_init_data(self, obj, params, auto_parallel_mode=False):
"""Init parameters' data."""
if params is not None:
for key, param in params.items():
if not auto_parallel_mode:
param.init_data()
elif key not in obj.parameter_layout_dict:
logger.debug("Layout dict does not contain the key %s.", key)
param.init_data(set_sliced=True)
else:
layout = obj.parameter_layout_dict[key]
param.init_data(layout, set_sliced=True)
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
def _build_data_graph(self, obj, phase):
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
def _set_dataset_mode(self, args_list):
"""set dataset mode."""
@ -386,7 +365,7 @@ class _Executor:
else:
_set_dataset_mode_config('normal')
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
"""
Compiles graph.
@ -394,7 +373,6 @@ class _Executor:
obj (Function/Cell): The function or cell instance need compile.
args (tuple): Function or cell input arguments.
phase (str): The name of compile phase. Default: 'predict'.
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
auto_parallel_mode: When set to True, use auto parallel mode to compile graph.
@ -435,10 +413,8 @@ class _Executor:
if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
self._params_init_data(obj, params, auto_parallel_mode)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode:
obj.load_parameter_slice(params)
replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
self._updata_param_node_default_input(phase, replace)
# set parallel inputs in sink mode
if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
@ -446,16 +422,20 @@ class _Executor:
# the following GE init process is not needed when use vm or ms backend
if enable_ge:
self._build_data_graph(obj, params, phase)
self._build_data_graph(obj, phase)
if "export" not in phase:
init_phase = "init_subgraph" + "." + str(obj.create_time)
_exec_init_graph(obj, init_phase)
elif not enable_ge and "export" in phase:
self._build_data_graph(obj, params, phase)
self._build_data_graph(obj, phase)
return phase, True
def _updata_param_node_default_input(self, phase, replace):
new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
return self._executor.updata_param_node_default_input(phase, new_param)
def _get_strategy(self, obj):
real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
return self._executor.get_strategy(real_phase)
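
Net effect for callers of _Executor.compile: the params OrderedDict and _params_init_data are gone; parameter initialization now happens inside compile via obj.init_parameters_data(), and the refreshed tensors reach the graph through updata_param_node_default_input. Illustrative call, assuming an _Executor instance, a built net and input tensors:

# the params keyword is removed from compile(); initialization and the
# default-input update are handled internally
phase, _ = executor.compile(net, inputs, phase='train', auto_parallel_mode=False)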

Some files were not shown because too many files have changed in this diff.