!2807 Apply indexed_slices

Merge pull request !2807 from riemann_penn/apply_indexed_slices
pull/2807/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2f27c4e3e8

@ -45,7 +45,8 @@ FuncGraph::FuncGraph()
hyper_param_count_(0),
is_generated_(false),
return_(nullptr),
manager_(std::weak_ptr<FuncGraphManager>()) {
manager_(std::weak_ptr<FuncGraphManager>()),
stub_(false) {
debug_info_ = std::make_shared<GraphDebugInfo>();
}

@ -344,6 +344,9 @@ class FuncGraph : public FuncGraphBase {
void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
bool HasEffect(const CNodePtr &cnode);
bool stub() const { return stub_; }
void set_stub(bool stub) { stub_ = stub; }
private:
// graph is manipulated by manager and others
friend FuncGraphManager;
@ -402,6 +405,7 @@ class FuncGraph : public FuncGraphBase {
// CNode order which relates to origin code order
std::list<CNodePtr> order_;
bool stub_;
};
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {

@ -218,6 +218,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
(*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
(*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
(*target_func_graph)->set_is_generate(func_graph->is_generated());
(*target_func_graph)->set_stub(func_graph->stub());
TraceManager::EndTrace();
}
@ -629,6 +630,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
new_func_graph->set_is_generate(func_graph->is_generated());
new_func_graph->set_stub(func_graph->stub());
for (auto &item : func_graph->parameter_default_value()) {
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
}

@ -30,6 +30,7 @@
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "optimizer/opt.h"
#include "utils/context/ms_context.h"
#include "utils/symbolic.h"
#include "pybind_api/api_register.h"
#include "./common.h"
@ -115,36 +116,43 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
}
return item.second;
}
// Try best match
py::function py_fn_subclass;
size_t subclass_match_cnt = 0;
for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first;
if (sign.size() != types.size()) {
continue;
return py::none();
}
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
std::vector<AnfNodePtr> parameters;
ParameterPtr undetermined_param = nullptr;
auto stub = std::make_shared<FuncGraph>();
for (size_t i = 0; i < types.size(); ++i) {
auto param = stub->add_parameter();
parameters.push_back(param);
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
undetermined_param = param;
}
auto match = true;
for (size_t i = 0; i < sign.size(); ++i) {
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
!IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
match = false;
break;
}
if (undetermined_param != nullptr) {
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < types.size(); ++i) {
if (types[i]->type_id() == kObjectTypeFunction) {
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
inputs.push_back(stub->NewCNode(call_prim));
} else {
inputs.push_back(parameters[i]);
}
}
if (!match) {
continue;
}
py_fn_subclass = item.second;
subclass_match_cnt++;
}
if (subclass_match_cnt > 1) {
MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
}
if (subclass_match_cnt == 1) {
MS_LOG(DEBUG) << "Found one subclass match";
return py_fn_subclass;
auto stub_output = stub->NewCNode(inputs);
stub->set_output(stub_output);
stub->set_stub(true);
return stub;
}
return py::none();
return nullptr;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
@ -159,6 +167,11 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
return func_graph;
}
auto stub = GenerateStubFunc(types);
if (stub != nullptr) {
MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString();
return stub;
}
std::ostringstream oss;
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
<< "`, corresponding location info:\n";

@ -23,8 +23,8 @@
#include "pipeline/static_analysis/param_validator.h"
#include "pipeline/static_analysis/prim.h"
#include "pipeline/static_analysis/utils.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "utils/symbolic.h"
namespace mindspore {
namespace abstract {
@ -56,79 +56,6 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
return AbstractFunction::MakeAbstractFunction(jv);
}
class UndeterminedShapeType {
public:
explicit UndeterminedShapeType(const std::string &env_str) {
// param_name indices_shape indices_type values_shape values_type dense_shape
// export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1
// 2:Float32:3 1 2"
std::vector<string> fields;
string tmp;
std::stringstream input(env_str);
while (std::getline(input, tmp, ':')) {
fields.push_back(tmp);
}
if (fields.size() != fields_num) {
MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
}
param_name_ = fields[0];
indices_shape_ = GetShape(fields[1]);
indices_type_ = StringToType(fields[2]);
values_shape_ = GetShape(fields[3]);
values_type_ = StringToType(fields[4]);
auto dense_shape_vec = GetShape(fields[5]);
AbstractBasePtrList dense_shape_list;
(void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
[](const auto &elem) { return FromValue(elem, false); });
dense_shape_ = dense_shape_list;
}
~UndeterminedShapeType() = default;
const std::string &param_name() { return param_name_; }
const std::vector<int> &indices_shape() { return indices_shape_; }
const TypePtr &indices_type() { return indices_type_; }
const std::vector<int> &values_shape() { return values_shape_; }
const TypePtr &values_type() { return values_type_; }
const AbstractBasePtrList &dense_shape() { return dense_shape_; }
private:
std::string param_name_;
std::vector<int> indices_shape_;
TypePtr indices_type_;
std::vector<int> values_shape_;
TypePtr values_type_;
AbstractBasePtrList dense_shape_;
static const size_t fields_num;
std::vector<int> GetShape(const std::string &shape_str);
};
std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
std::vector<int> ret;
std::istringstream iss(shape_str);
int elem;
while (iss.good()) {
iss >> elem;
ret.emplace_back(elem);
}
return ret;
}
const size_t UndeterminedShapeType::fields_num = 6;
std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs;
void InitUndeterminedFromEnv(const std::string &sparse_shape_types) {
std::string tmp;
std::stringstream input(sparse_shape_types);
g_undetermined_configs.clear();
while (std::getline(input, tmp, ';')) {
auto config = UndeterminedShapeType(tmp);
g_undetermined_configs.insert(std::make_pair(config.param_name(), config));
MS_LOG(DEBUG) << "Undetermined config from env: " << tmp;
}
}
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
@ -142,45 +69,14 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
}
if (!key->sparse_grad().empty()) {
// Will be fixed once undetermined type ready
if (g_undetermined_configs.empty()) {
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types;
if (sparse_shape_types.empty()) {
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2";
}
InitUndeterminedFromEnv(sparse_shape_types);
}
auto shape_types = g_undetermined_configs.find(key->sparse_grad());
if (shape_types == g_undetermined_configs.end()) {
MS_LOG(EXCEPTION) << "Param " << key->ToString()
<< " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES";
}
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString();
AbstractBasePtrList sparse_list;
// indices
auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.indices_type());
auto indices =
std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types->second.indices_shape()));
sparse_list.emplace_back(indices);
// values
auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.values_type());
auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types->second.values_shape()));
sparse_list.emplace_back(dout);
// dense_shape
sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types->second.dense_shape()));
return std::make_shared<AbstractTuple>(sparse_list);
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
bool enable_sparse = context->enable_sparse();
if (enable_sparse && dflt->isa<AbstractTensor>()) {
auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
}
if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
return dflt;
}
@ -242,10 +138,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
if (type->type_id() != kObjectTypeRefKey) {
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
}
auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
return ret;
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
}
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,

@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
}
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr;
}
@ -110,7 +110,7 @@ class InlinerBase : public AnfVisitor {
// G
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr;
}
// Do not inline GraphKernel to Cell.

@ -1367,7 +1367,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
std::string env = common::GetEnv("SLICE_ENV");
if (!env.empty()) {
MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
abstract::InitUndeterminedFromEnv(env);
}
}

@ -232,8 +232,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
ValuePtr value = param_value->value();
constexpr bool broaden = true;
AbstractBasePtr ptr = abstract::FromValue(value, broaden);
ptr->set_sparse_grad(param_value->sparse_grad());
ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad());
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);

@ -155,8 +155,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
"Set the GraphKernel switch to on or off.")
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
.def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
.def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
.def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
.def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

@ -321,21 +321,19 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
return true;
}
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup},
{"add_control_depend", AddControlDependPass},
{"opt_control", ControlGroup},
{"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};
std::vector<PassItem> kGePasses = {
{"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass},
{"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
} // namespace pipeline

@ -146,37 +146,35 @@ MethodMap &GetMethodMap() {
}},
{kObjectTypeTensorType,
{
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul
{"__truediv__", std::string("truediv")}, // C.truediv
{"__floordiv__", std::string("floordiv")}, // C.floordiv
{"__mod__", std::string("mod")}, // C.mod
{"__pow__", std::string("pow_")}, // C.pow
{"__floor__", std::string("array_floor")}, // C.array_floor
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
{"__pos__", std::string("array_uadd")}, // C.array_uadd
{"__neg__", std::string("array_usub")}, // C.array_usub
{"__eq__", std::string("eq")}, // C.eq
{"__ne__", std::string("ne")}, // C.ne
{"__lt__", std::string("lt")}, // C.lt
{"__gt__", std::string("gt")}, // C.gt
{"__le__", std::string("le")}, // C.le
{"__ge__", std::string("ge")}, // C.ge
{"__matmul__", prim::kPrimDot}, // P.dot,
{"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
{"transpose", std::string("transpose")}, // P.transpose
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul
{"__truediv__", std::string("truediv")}, // C.truediv
{"__floordiv__", std::string("floordiv")}, // C.floordiv
{"__mod__", std::string("mod")}, // C.mod
{"__pow__", std::string("pow_")}, // C.pow
{"__floor__", std::string("array_floor")}, // C.array_floor
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
{"__pos__", std::string("array_uadd")}, // C.array_uadd
{"__neg__", std::string("array_usub")}, // C.array_usub
{"__eq__", std::string("eq")}, // C.eq
{"__ne__", std::string("ne")}, // C.ne
{"__lt__", std::string("lt")}, // C.lt
{"__gt__", std::string("gt")}, // C.gt
{"__le__", std::string("le")}, // C.le
{"__ge__", std::string("ge")}, // C.ge
{"__matmul__", prim::kPrimDot}, // P.dot,
{"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
{"transpose", std::string("transpose")}, // P.transpose
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
}},
{kObjectTypeIndexedSlicesType,
{
{"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices
{"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape

@ -55,7 +55,6 @@ ValuePtr AbstractBase::BuildValue() const {
AbstractBasePtr AbstractBase::Broaden() const {
AbstractBasePtr clone = Clone();
clone->set_value(kAnyValue);
clone->set_sparse_grad(sparse_grad_);
return clone;
}
@ -68,8 +67,7 @@ std::string AbstractBase::ToString() const {
MS_EXCEPTION_IF_NULL(type_);
MS_EXCEPTION_IF_NULL(shape_);
buffer << type_name() << "("
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
<< " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
return buffer.str();
}
@ -78,25 +76,16 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden()
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
if (*this == *other) {
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return shared_from_base<AbstractBase>();
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
if (res_value == value_self) {
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return shared_from_base<AbstractBase>();
}
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return std::make_shared<AbstractScalar>(res_value, res_type);
}
AbstractBasePtr AbstractType::Clone() const {
@ -452,16 +441,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {
if (sparse_grad() == other->sparse_grad()) {
return shared_from_base<AbstractBase>();
}
return shared_from_base<AbstractBase>();
}
auto element = element_->Join(other_tensor->element_);
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
auto ret = std::make_shared<AbstractTensor>(element, shape);
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
return std::make_shared<AbstractTensor>(element, shape);
}
bool AbstractTensor::operator==(const AbstractTensor &other) const {
@ -501,8 +485,6 @@ AbstractBasePtr AbstractTensor::Clone() const {
ShapePtr shp = shape();
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_sparse_grad(sparse_grad());
clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
return clone;
}
@ -512,8 +494,6 @@ AbstractBasePtr AbstractTensor::Broaden() const {
auto shp = shape();
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
return broaden;
}
@ -524,8 +504,6 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
return broaden;
}
@ -538,8 +516,7 @@ std::string AbstractTensor::ToString() const {
MS_EXCEPTION_IF_NULL(value_track);
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
<< " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
return buffer.str();
}

@ -44,7 +44,7 @@ class AbstractBase : public Base {
public:
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
const BaseShapePtr &shape = kNoShape)
: value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
: value_(value), type_(type), shape_(shape) {}
~AbstractBase() override = default;
MS_DECLARE_PARENT(AbstractBase, Base)
@ -53,17 +53,11 @@ class AbstractBase : public Base {
virtual bool operator==(const AbstractBase &other) const;
void set_value(const ValuePtr &value) { value_ = value; }
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
has_indexed_slices_grad_ = has_indexed_slices_grad;
}
void set_type(const TypePtr &type) { type_ = type; }
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
const std::string &value_desc() const { return value_desc_; }
ValuePtr GetValueTrack() const { return value_; }
const std::string &sparse_grad() const { return sparse_grad_; }
const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
TypePtr GetTypeTrack() const { return type_; }
BaseShapePtr GetShapeTrack() const { return shape_; }
@ -91,8 +85,6 @@ class AbstractBase : public Base {
TypePtr type_;
BaseShapePtr shape_;
std::string value_desc_; // store initial value description for error report
std::string sparse_grad_;
bool has_indexed_slices_grad_;
};
class AbstractScalar : public AbstractBase {

@ -126,7 +126,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
}
MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
<< ", is stub: " << fg->stub();
if (fg->stub()) {
return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
}
return std::make_shared<EvalResult>(ret_base, nullptr);
}

@ -25,6 +25,7 @@
#include <vector>
#include "pipeline/static_analysis/static_analysis.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace abstract {
@ -59,6 +60,13 @@ class Evaluator : public Base {
}
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
return true;

@ -146,10 +146,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
using mindspore::parse::PyObjectWrapper;
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
@ -167,6 +164,14 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
auto ret_abstract = AbstractEval(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
return ret_abstract;
}
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
@ -181,9 +186,6 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
}
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
@ -509,15 +511,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
} // end anonymous namespace
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
return ret_abstract;
}
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
return ret_abstract;
}
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
@ -546,15 +543,10 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
}
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
return ret_abstract;
}
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
return ret_abstract;
}
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
if (nargs_ != args.size()) {
@ -914,8 +906,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();
MS_EXCEPTION_IF_NULL(ref_value);
ret->set_sparse_grad(ref_value->sparse_grad());
ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
}
@ -930,8 +920,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
x = SensitivityTransform(x);
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
abs_scalar->set_sparse_grad(x->sparse_grad());
abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
}
};
@ -943,15 +931,10 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag) {
auto ret_abstract = AbstractEval(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
return ret_abstract;
}
auto ret_abstract = AbstractEval(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
return ret_abstract;
}
// Inputs: data, item
if (args_spec_list.size() != 2) {

@ -349,7 +349,6 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);

@ -321,7 +321,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
AbstractFunctionPtr func = real_a->GetUnique();
SpecializeStatusCode errcode;
ScopeGuard scope_guard(node->scope());
AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
if (repl == nullptr) {
if (errcode == kSpecializeFindUniqueArgvalDead) {
const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
@ -340,7 +340,8 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
return repl;
}
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractFunctionPtr &func,
const AbstractBasePtrList &args,
SpecializeStatusCode *errcode) {
MS_EXCEPTION_IF_NULL(abs);
@ -384,7 +385,14 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
<< ", graph: " << context->func_graph()->get_return()->DebugString();
if (context->func_graph()->stub()) {
MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
<< ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
<< ", " << node->ToString();
return node;
}
FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
v->set_flag(kFuncGraphFlagUndetermined, false);
return BuildValueNode(v, abs);
}
@ -613,7 +621,8 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
return kSpecializeSuccess;
} else if (choices->empty()) {
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
<< func->type_name();
return kSpecializeFindUniqueArgvalDead;
} else {
if (IsPolyFunc(func, argvals)) {

@ -118,8 +118,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
// Build a specialized node from given argvals;
AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractBasePtrList &argvals);
AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
const AbstractBasePtrList &args, SpecializeStatusCode *errcode);
AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractFunctionPtr &func, const AbstractBasePtrList &args,
SpecializeStatusCode *errcode);
// Find the unique argument values which can be used to specialize a primitive or graph function.
SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,

@ -90,7 +90,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
max_device_memory_ = kDefaultMaxDeviceMemory;
print_file_path_ = "";
enable_graph_kernel_ = false;
enable_sparse_flag_ = false;
enable_sparse_ = false;
}
std::shared_ptr<MsContext> MsContext::GetInstance() {

@ -161,8 +161,8 @@ class MsContext {
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
bool enable_graph_kernel() const { return enable_graph_kernel_; }
bool enable_sparse_flag() const { return enable_sparse_flag_; }
void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }
bool enable_sparse() const { return enable_sparse_; }
void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }
private:
MsContext(const std::string &backend_policy, const std::string &target);
@ -207,7 +207,7 @@ class MsContext {
float max_device_memory_;
std::string print_file_path_;
bool enable_graph_kernel_;
bool enable_sparse_flag_;
bool enable_sparse_;
};
} // namespace mindspore

@ -51,18 +51,13 @@ class Parameter:
requires_grad (bool): True if the parameter requires gradient. Default: True.
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
broadcast and gradients communication would not be applied on parameters. Default: False.
sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false.
"""
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
sparse_grad="", has_indexed_slices_grad=False):
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
self._value = ParamValue()
self.set_parameter_data(default_input)
self.name = name
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self.sparse_grad = sparse_grad
self.has_indexed_slices_grad = has_indexed_slices_grad
self._is_init = False
self._sliced = False
self.is_param_ps = False
@ -181,28 +176,6 @@ class Parameter:
raise TypeError("`requires_grad` parameter must be bool type")
self._value.requires_grad = value
@property
def sparse_grad(self):
"""Return whether the parameter's gradient is sparse."""
return self._value.sparse_grad
@sparse_grad.setter
def sparse_grad(self, value=""):
if not isinstance(value, str):
raise TypeError("`sparse_grad` parameter must be str type")
self._value.sparse_grad = value
@property
def has_indexed_slices_grad(self):
"""Return whether the parameter's gradient is indexed_slices."""
return self._value.has_indexed_slices_grad
@has_indexed_slices_grad.setter
def has_indexed_slices_grad(self, value=False):
if not isinstance(value, bool):
raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
self._value.has_indexed_slices_grad = value
@property
def data(self):
return self.default_input

@ -367,14 +367,6 @@ class _Context:
def check_bprop(self, check_bprop_flag):
self._context_handle.set_check_bprop_flag(check_bprop_flag)
@property
def enable_sparse(self):
return self._context_handle.get_enable_sparse_flag()
@enable_sparse.setter
def enable_sparse(self, enable_sparse_flag):
self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
@property
def max_device_memory(self):
return self._context_handle.get_max_device_memory()
@ -408,6 +400,13 @@ class _Context:
full_file_name = print_file_path
self._context_handle.set_print_file_path(full_file_name)
@property
def enable_sparse(self):
return self._context_handle.get_enable_sparse()
@enable_sparse.setter
def enable_sparse(self, enable_sparse):
self._context_handle.set_enable_sparse(enable_sparse)
def check_input_format(x):
import re
@ -601,7 +600,7 @@ def set_context(**kwargs):
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
suffix to the file.
enable_sparse (bool): Whether to enable sparse feature. Default: False.
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
Raises:
ValueError: If input key is not an attribute in context.

@ -188,8 +188,8 @@ class Adam(Optimizer):
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
Args:

@ -93,8 +93,8 @@ class FTRL(Optimizer):
<https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
Note:
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
Args:

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save