!12703 Add a switch to control grad for scalar

From: @ginfung
Reviewed-by: 
Signed-off-by:
pull/12703/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 614cf339ad

@ -34,6 +34,7 @@
#include "pybind_api/api_register.h" #include "pybind_api/api_register.h"
#include "ir/signature.h" #include "ir/signature.h"
#include "debug/trace.h" #include "debug/trace.h"
#include "utils/ms_context.h"
namespace mindspore { namespace mindspore {
// namespace to support composite operators definition // namespace to support composite operators definition
@ -403,7 +404,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
if (tail_type_ == kGradFirst) { if (tail_type_ == kGradFirst) {
if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) { (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
(*sequeue)[1]->BuildType()->isa<Number>()))) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))})); ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
} else { } else {
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{}))); ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@ -416,7 +418,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
if (tail_type_ == kGradAll) { if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]); MS_EXCEPTION_IF_NULL((*sequeue)[i]);
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) { (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
(*sequeue)[i]->BuildType()->isa<Number>())) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
} }
} else { } else {

@ -490,7 +490,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
} }
AbstractBasePtr par_abs = param_node->abstract(); AbstractBasePtr par_abs = param_node->abstract();
if (par_abs->isa<abstract::AbstractUndetermined>() || if (par_abs->isa<abstract::AbstractUndetermined>() ||
(par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) { (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
par_abs->BuildType()->isa<Number>())) {
new_paras.push_back(param_node); new_paras.push_back(param_node);
} }
} }

@ -98,7 +98,8 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name)
AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>(); bool broaden = value->isa<MetaTensor>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
return abstract::FromValue(value, broaden); return abstract::FromValue(value, broaden);
} }

@ -95,7 +95,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH); .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext") (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.") .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")

@ -210,7 +210,9 @@ class _MindSporeFunction:
return None return None
new_inputs = [] new_inputs = []
for i in args_list: for i in args_list:
if isinstance(i, (Tensor, int, float)): if isinstance(i, Tensor):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i) new_inputs.append(i)
return self._executor(tuple(new_inputs), phase) return self._executor(tuple(new_inputs), phase)

@ -533,6 +533,7 @@ def set_context(**kwargs):
save_graphs variable_memory_max_size save_graphs variable_memory_max_size
save_graphs_path save_graphs_path
env_config_path env_config_path
grad_for_scalar
=========================== =========================== ================= =========================== =========================== =================
Args: Args:
@ -602,6 +603,7 @@ def set_context(**kwargs):
enable_sparse (bool): Whether to enable sparsity feature. Default: False. enable_sparse (bool): Whether to enable sparsity feature. Default: False.
max_call_depth (int): Specify the maximum depth of function call. Default: 1000. max_call_depth (int): Specify the maximum depth of function call. Default: 1000.
env_config_path (str): Config path for DFX. env_config_path (str): Config path for DFX.
grad_for_scalar (bool): Whether to compute gradients for scalar (int or float) inputs. Default: False.
Raises: Raises:
ValueError: If input key is not an attribute in context. ValueError: If input key is not an attribute in context.

@ -22,6 +22,7 @@
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "abstract/utils.h" #include "abstract/utils.h"
#include "utils/ms_context.h"
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
@ -88,7 +89,13 @@ std::string AbstractBase::ToString() const {
return buffer.str(); return buffer.str();
} }
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); } AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
return AbstractBase::Broaden(config);
} else {
return Clone();
}
}
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other); MS_EXCEPTION_IF_NULL(other);

@ -171,6 +171,12 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
return args_spec_list[0]; return args_spec_list[0];
} }
auto depends = args_spec_list[0]->Broaden(); auto depends = args_spec_list[0]->Broaden();
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
// For a scalar, the value must be set to kAnyValue explicitly, because broadening a scalar does not change its value when grad_for_scalar is disabled.
if (depends->isa<AbstractScalar>()) {
depends->set_value(kAnyValue);
}
}
return depends; return depends;
} }

@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false); set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
set_param<bool>(MS_CTX_ENABLE_SPARSE, false); set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false); set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
backend_policy_ = policy_map_[policy]; backend_policy_ = policy_map_[policy];
} }

@ -76,6 +76,7 @@ enum MsCtxParam : unsigned {
MS_CTX_SAVE_GRAPHS_FLAG, MS_CTX_SAVE_GRAPHS_FLAG,
MS_CTX_ENABLE_PARALLEL_SPLIT, MS_CTX_ENABLE_PARALLEL_SPLIT,
MS_CTX_ENABLE_INFER_OPT, MS_CTX_ENABLE_INFER_OPT,
MS_CTX_GRAD_FOR_SCALAR,
MS_CTX_TYPE_BOOL_END, MS_CTX_TYPE_BOOL_END,
// parameter of type int // parameter of type int

@ -609,7 +609,9 @@ class Cell(Cell_):
new_inputs = [] new_inputs = []
for i in inputs: for i in inputs:
if isinstance(i, (Tensor, int, float)): if isinstance(i, Tensor):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i) new_inputs.append(i)
if self._auto_parallel_mode: if self._auto_parallel_mode:

@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) {
AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false); AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false); AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true); AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
abs_s_anything->set_value(kAnyValue);
AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
ASSERT_EQ(*res_s1, *abs_s_anything); ASSERT_EQ(*res_s1, *abs_s_anything);
// AbstractTuple join;
std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
AbstractBasePtr abs_t1 = FromValue(list1, true);
AbstractBasePtr abs_t2 = FromValue(list2, true);
AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
ASSERT_EQ(res_t1, abs_t1);
abs_s1 = FromValue(static_cast<int64_t>(1), false); abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything})); AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
res_t1 = t1->Join(t2); AbstractBasePtr res_t1 = t1->Join(t2);
ASSERT_EQ(res_t1, t1); ASSERT_EQ(res_t1, t1);
res_t1 = t1->Join(t3); res_t1 = t1->Join(t3);

@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) {
// add infer and renormalize // add infer and renormalize
std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>(); std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true); tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true); tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2); args_spec_list.push_back(abstract_v2);
AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list); AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);

@ -74,20 +74,17 @@ TEST_F(TestData, test_build_value) {
AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false); AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2})); AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
ValuePtr func_tuple_built = abs_func_tuple->BuildValue(); ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
ASSERT_EQ(*func_tuple_built, ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
// BuildValue(List(AbstractFunction)) should return kAnyValue; // BuildValue(List(AbstractFunction)) should return kAnyValue;
AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2})); AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
ValuePtr func_list_built = abs_func_list->BuildValue(); ValuePtr func_list_built = abs_func_list->BuildValue();
ASSERT_EQ(*func_list_built, ASSERT_EQ(*func_list_built, ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
// BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2})); abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
func_tuple_built = abs_func_tuple->BuildValue(); func_tuple_built = abs_func_tuple->BuildValue();
ASSERT_EQ(*func_tuple_built, ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
} }
TEST_F(TestData, test_build_type) { TEST_F(TestData, test_build_type) {
@ -129,7 +126,7 @@ TEST_F(TestData, test_build_shape) {
AbstractBasePtr abstract_tup = FromValue(vec, true); AbstractBasePtr abstract_tup = FromValue(vec, true);
std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape()); std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape());
ASSERT_TRUE(shape_tuple); ASSERT_TRUE(shape_tuple);
const std::vector<BaseShapePtr>& ptr_vec = shape_tuple->shape(); const std::vector<BaseShapePtr> &ptr_vec = shape_tuple->shape();
ASSERT_EQ(ptr_vec.size(), 2); ASSERT_EQ(ptr_vec.size(), 2);
ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]); ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]);
@ -148,14 +145,14 @@ TEST_F(TestData, test_clone) {
ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack()); ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack());
ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack()); ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack());
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AbstractFunctionPtr f1 =
AnalysisContext::DummyContext()); std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
AbstractBasePtr f2 = f1->Clone(); AbstractBasePtr f2 = f1->Clone();
ASSERT_TRUE(*f2 == *f1); ASSERT_TRUE(*f2 == *f1);
AbstractList l1 = AbstractList({s1, s2}); AbstractList l1 = AbstractList({s1, s2});
AbstractBasePtr l2 = l1.Clone(); AbstractBasePtr l2 = l1.Clone();
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get()); AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
ASSERT_TRUE(l2_cast != nullptr); ASSERT_TRUE(l2_cast != nullptr);
ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack()); ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack());
@ -184,19 +181,19 @@ TEST_F(TestData, test_broaden) {
AbstractBasePtr s2 = s1->Broaden(); AbstractBasePtr s2 = s1->Broaden();
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1)); ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>()); ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AbstractFunctionPtr f1 =
AnalysisContext::DummyContext()); std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
AbstractBasePtr f2 = f1->Broaden(); AbstractBasePtr f2 = f1->Broaden();
ASSERT_TRUE(f2 == f1); ASSERT_TRUE(f2 == f1);
AbstractList l1 = AbstractList({s1, s2}); AbstractList l1 = AbstractList({s1, s2});
AbstractBasePtr l2 = l1.Broaden(); AbstractBasePtr l2 = l1.Broaden();
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get()); AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
ASSERT_TRUE(l2_cast != nullptr); ASSERT_TRUE(l2_cast != nullptr);
AbstractBasePtr csr = AbstractJoin(l2_cast->elements()); AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>()); ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
} }
} // namespace abstract } // namespace abstract

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
""" test_framstruct """ """ test_framstruct """
import numpy as np import numpy as np
import pytest
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
@ -76,9 +75,7 @@ def dynamic_make_tuple(x, lower, upper):
def test_dynamic_make_tuple(): def test_dynamic_make_tuple():
# Dynamically recursively creating static type is invalid in mindspore, as mindspore is a static language. assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
with pytest.raises(RuntimeError):
dynamic_make_tuple(2, 1, 5)
def test_make_tuple(): def test_make_tuple():

Loading…
Cancel
Save