From 9e633b6c12db05357c709b1e0941425c0c789356 Mon Sep 17 00:00:00 2001 From: panyifeng Date: Thu, 23 Apr 2020 15:42:11 +0800 Subject: [PATCH] validate bprop rules --- mindspore/ccsrc/ir/dtype.cc | 1 + mindspore/ccsrc/operator/ops.cc | 1 + mindspore/ccsrc/operator/ops.h | 1 + mindspore/ccsrc/optimizer/ad/dfunctor.cc | 8 -- mindspore/ccsrc/optimizer/ad/dfunctor.h | 7 +- mindspore/ccsrc/optimizer/ad/kprim.cc | 46 +++++---- mindspore/ccsrc/optimizer/irpass.cc | 1 + mindspore/ccsrc/optimizer/irpass.h | 1 + .../optimizer/irpass/special_op_eliminate.h | 19 ++++ mindspore/ccsrc/pipeline/pass.cc | 1 + .../ccsrc/pipeline/static_analysis/prim.cc | 7 ++ mindspore/common/dtype.py | 3 + mindspore/ops/_grad/grad_array_ops.py | 6 +- mindspore/ops/_grad/grad_math_ops.py | 4 +- mindspore/ops/_grad/grad_nn_ops.py | 2 +- .../multitype_ops/zeros_like_impl.py | 4 + mindspore/ops/functional.py | 1 + mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/other_ops.py | 63 +++++++++++++ tests/ut/python/model/test_bert_cell.py | 8 +- tests/ut/python/model/test_mix_precision.py | 2 +- tests/ut/python/ops/test_ops.py | 22 ++--- .../python/pynative_mode/test_cell_bprop.py | 39 ++++++-- .../python/pynative_mode/test_framstruct.py | 93 +++++++++++++++++++ .../pynative_mode/test_insert_grad_of.py | 2 +- 25 files changed, 275 insertions(+), 70 deletions(-) diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/ccsrc/ir/dtype.cc index 97291a3dc0..0ba25f2f66 100644 --- a/mindspore/ccsrc/ir/dtype.cc +++ b/mindspore/ccsrc/ir/dtype.cc @@ -695,6 +695,7 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(m_sub, "String").def(py::init()); (void)py::class_>(m_sub, "RefKeyType").def(py::init()); (void)py::class_>(m_sub, "RefType").def(py::init()); + (void)py::class_>(m_sub, "TypeAnything").def(py::init()); })); const TypePtr kTypeExternal = std::make_shared(); diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index 91a54e1fdb..407efe5689 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -213,6 +213,7 @@ const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_orig const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); +const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); const PrimitivePtr kPrimPrint = std::make_shared("Print"); const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index d84b2e4738..e938e5c64e 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -220,6 +220,7 @@ extern const PrimitivePtr kPrimInsertGradientOf; extern const PrimitivePtr kPrimPrintShapeType; extern const PrimitivePtr kPrimPrint; extern const PrimitivePtr kPrimSameTypeShape; +extern const PrimitivePtr kPrimCheckBprop; extern const PrimitivePtr kPrimDepend; extern const PrimitivePtr kPrimStateSetItem; extern const PrimitivePtr kPrimScalarSummary; diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc index 33f919e2ac..de368dbdd2 100644 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/optimizer/ad/dfunctor.cc @@ -309,14 +309,6 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { auto bprop = primal->transforms().find("bprop"); if (bprop != primal->transforms().end()) { FuncGraphPtr bprop_graph 
= bprop->second.func_graph(); - const size_t param_diff = 1; - if (bprop_graph->output()->isa() && - bprop_graph->output()->cast()->size() + param_diff != bprop_graph->parameters().size()) { - // It does not matter with the final tangents, just a tip for debugging - MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope " - << primal->output()->scope()->name() - << " output must be a tuple and output number should be the same with inputs."; - } resources_->manager()->AddFuncGraph(bprop_graph); if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) { diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.h b/mindspore/ccsrc/optimizer/ad/dfunctor.h index 3059736171..1358cc8f28 100644 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/optimizer/ad/dfunctor.h @@ -127,7 +127,7 @@ class KPrim { AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, std::vector *const transf_args); - void AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg); + void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check); Registry bprop_registry_; std::unordered_map bprop_registry_meta_; @@ -137,10 +137,7 @@ template FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { MS_EXCEPTION_IF_NULL(primal); MS_EXCEPTION_IF_NULL(bprop_fg); - - if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { - AddCheckTypeShapeOp(bprop_fg); - } + CheckBprop(bprop_fg, primal->ToString()); auto debug_info = std::make_shared(); debug_info->set_name(primal->ToString()); diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc index 2c8ddbfa82..c74670e55d 100644 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/optimizer/ad/kprim.cc @@ -50,9 +50,13 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { grad_op_child_scope_prefix + prim->name()); ScopeGuard scope_guard(scope); py::function fn = prim->GetBpropFunction(); + if (fn == nullptr || py::isinstance(fn)) { + MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; + return nullptr; + } FuncGraphPtr func_graph = parse::ParsePythonCode(fn); if (func_graph == nullptr) { - MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << "."; + MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; return nullptr; } return func_graph; @@ -153,31 +157,23 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp } } -void KPrim::AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg) { +void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { // bprop_fg has been checked in caller - auto same_type_shape = prim::GetPythonOps("same_type_shape", "mindspore.ops.functional")->cast(); - MS_EXCEPTION_IF_NULL(same_type_shape); - - std::vector bout_input; - bout_input.push_back(NewValueNode(prim::kPrimMakeTuple)); - - auto fg_out = bprop_fg->output(); - MS_EXCEPTION_IF_NULL(fg_out); - auto cnode = fg_out->cast(); - MS_EXCEPTION_IF_NULL(cnode); - - auto &inputs = cnode->inputs(); - auto params = bprop_fg->parameters(); - std::vector sub_input; - for (size_t i = 1; i < inputs.size(); ++i) { - sub_input.clear(); - sub_input.push_back(NewValueNode(same_type_shape)); - sub_input.push_back(inputs[i]); - sub_input.push_back(params[i - 1]); - bout_input.push_back(bprop_fg->NewCNode(sub_input)); - } - AnfNodePtr cbout = 
bprop_fg->NewCNode(bout_input); - bprop_fg->set_output(cbout); + auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast(); + MS_EXCEPTION_IF_NULL(check_bprop); + check_bprop->set_attr("prim_to_check", std::make_shared(prim_to_check)); + + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2); + AnfNodePtr params = bprop_fg->NewCNode(inputs); + + inputs.clear(); + inputs.push_back(NewValueNode(check_bprop)); + inputs.push_back(bprop_fg->output()); + inputs.push_back(params); + AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); + bprop_fg->set_output(bprop_out); } FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 3b44700e1c..2bd013cb08 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -67,6 +67,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); + check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode); // Env Item Eliminate diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h index 0af22c5cd0..02bfee65d6 100644 --- a/mindspore/ccsrc/optimizer/irpass.h +++ b/mindspore/ccsrc/optimizer/irpass.h @@ -45,6 +45,7 @@ class OptimizeIRPassLib { SubstitutionPtr reduce_eliminate_; SubstitutionPtr partial_eliminate_; SubstitutionPtr same_eliminate_; + SubstitutionPtr check_bprop_eliminate_; SubstitutionPtr reset_defer_inline_; // Env Item Eliminate diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index 2dd27a89c3..e06ccd862b 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -109,6 +109,25 @@ class SameEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}; }; +// {prim::kPrimCheckBprop, X, Y} -> X +class CheckBpropEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + x_ = nullptr; + AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node); + return x_; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } + } + + private: + AnfNodePtr x_{nullptr}; +}; + // Reset defer_inline flag class ResetDeferInline : public AnfVisitor { public: diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index 6ce6c4603d..d9f805fdc9 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { }); opt::OptPassConfig a_3 = opt::OptPassConfig({ irpass.same_eliminate_, + irpass.check_bprop_eliminate_, irpass.replace_applicator_, }); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc 
b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index d71ad8f710..293f31707e 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -295,6 +295,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic["shape"] = shape; dic["dtype"] = arg_slice->BuildType(); dic["value"] = BuildValue(arg_slice->BuildValue()); + } else if (abs_base->isa()) { + auto value = abs_base->cast()->ref(); + dic = ConvertAbstractToPython(value); } else if (abs_base->isa()) { auto arg_tuple = dyn_cast(abs_base); size_t len = arg_tuple->size(); @@ -327,6 +330,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic["shape"] = py::none(); dic["dtype"] = py::none(); dic["value"] = py::none(); + } else if (abs_base->isa()) { + dic["shape"] = py::none(); + dic["dtype"] = abs_base->BuildType(); + dic["value"] = py::none(); } else { auto value = abs_base->BuildValue(); if ((*value == *kAnyValue)) { diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 702e01effb..e6b9779f39 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -85,13 +85,16 @@ list_ = typing.List() tuple_ = typing.Tuple() tensor = typing.TensorType() function = typing.Function() +function_type = typing.Function symbolic_key = typing.SymbolicKeyType() env_type = typing.EnvType() +env_type_type = typing.EnvType type_type = typing.TypeType() type_none = typing.TypeNone() string = typing.String() type_refkey = typing.RefKeyType() tensor_type = typing.TensorType +anything_type = typing.TypeAnything number_type = (int8, int16, diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 35d37b3ada..b9281a7456 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -211,11 +211,11 @@ def get_bprop_slice(self): def bprop(x, begin, size, out, dout): dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout) - return (dx,) + return (dx, zeros_like(begin), zeros_like(size)) def bprop_gpu(x, begin, size, out, dout): dx = dx = G.SliceGrad()(dout, x, begin, size) - return (dx,) + return (dx, zeros_like(begin), zeros_like(size)) if context.get_context('device_target') == "GPU": return bprop_gpu @@ -262,7 +262,7 @@ def get_bprop_gather_v2(self): # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) perm_2 = _generate_inverse_index(x_shp, axis) params_grad = transpose(params_grad, perm_2) - return params_grad, zeros_like(indices) + return params_grad, zeros_like(indices), zeros_like(axis) return bprop diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index c334050218..2f39fe8745 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -505,7 +505,7 @@ def get_bprop_reducemax(self): def bprop(x, axis, out, dout): dx = _min_or_max_grad(x, axis, out, dout) - return (dx,) + return (dx, zeros_like(axis)) return bprop @@ -528,7 +528,7 @@ def get_bprop_reducemin(self): def bprop(x, axis, out, dout): dx = _min_or_max_grad(x, axis, out, dout) - return (dx,) + return (dx, zeros_like(axis)) return bprop diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index e43d3d5d3a..baccdbbbb2 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -436,7 +436,7 @@ def get_bprop_onehot(self): """Grad definition for `OneHot` operation.""" def bprop(indices, depth, on_value, off_value, out, dout): - return zeros_like(indices), 
zeros_like(depth) + return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value) return bprop diff --git a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py index 1c1a4f1d12..1308bfd62a 100644 --- a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py @@ -31,6 +31,10 @@ def _zeros_like_scala(x): """Returns 0 which has the same dtype as x where x is a scalar.""" return 0 +@zeros_like_leaf.register("Bool") +def _zeros_like_bool(x): + """Returns False if x is a bool.""" + return False newenv = base.EnvInstance_() diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 4135133e85..4cae11aed1 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -56,6 +56,7 @@ tensor_pow = P.Pow() tensor_mod = P.FloorMod() strided_slice = P.StridedSlice() same_type_shape = P.SameTypeShape() +check_bprop = P.CheckBprop() equal = P.Equal() not_equal = P.NotEqual() assign_sub = P.AssignSub() diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index c75c2031d7..868d3b359e 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -67,7 +67,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, SparseSoftmaxCrossEntropyWithLogits, Tanh, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, ApplyRMSProp, ApplyCenteredRMSProp) -from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey +from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop from . import _quant_ops from ._quant_ops import * @@ -179,6 +179,7 @@ __all__ = [ 'GeSwitch', 'Merge', 'SameTypeShape', + 'CheckBprop', 'CheckValid', 'BoundingBoxEncode', 'BoundingBoxDecode', diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 12a8a2cfde..5e66050d9a 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -269,3 +269,66 @@ class MakeRefKey(Primitive): def __call__(self): pass + + +class CheckBprop(PrimitiveWithInfer): + """ + Checks whether data type and shape of corresponding element from tuple x and y are the same. + + Raises: + TypeError: If not the same. + + Inputs: + - **input_x** (tuple[Tensor]) - The input_x contains the outputs of bprop to be checked. + - **input_y** (tuple[Tensor]) - The input_y contains the inputs of bprop to check against. + + Outputs: + (tuple[Tensor]), the input_x, + if data type and shape of corresponding elements from `input_x` and `input_y` are the same. 
+ + Examples: + >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),) + >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),) + >>> out = P.CheckBprop()(input_x, input_y) + """ + + @prim_attr_register + def __init__(self): + """init CheckBprop""" + + def infer_shape(self, xshapes, yshapes): + tips = f'Bprop of {self.prim_to_check}' + if len(xshapes) < len(yshapes): + raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," + f" but got {len(xshapes)}.") + checking_range = len(yshapes) + for i in range(checking_range): + xshape = xshapes[i] + yshape = yshapes[i] + if not xshape or not yshape: + continue + if xshape != yshape: + raise TypeError(f"{tips}, the shape of {i}th output should be {yshape}," + f" but got {xshape}.") + return xshapes + + def infer_dtype(self, xdtypes, ydtypes): + tips = f'Bprop of {self.prim_to_check}' + if len(xdtypes) < len(ydtypes): + raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," + f" but got {len(xdtypes)}.") + checking_range = len(ydtypes) + for i in range(checking_range): + xdtype = xdtypes[i] + ydtype = ydtypes[i] + if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type): + continue + if isinstance(ydtype, mstype.function_type): + if not isinstance(xdtype, mstype.env_type_type): + raise TypeError(f"{tips}, the dtype of {i}th output should be {mstype.env_type_type}," + f" but got {xdtype}.") + continue + if xdtype != ydtype: + raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype}," + f" but got {xdtype}.") + return xdtypes diff --git a/tests/ut/python/model/test_bert_cell.py b/tests/ut/python/model/test_bert_cell.py index fdaaac397b..2cb642c75f 100644 --- a/tests/ut/python/model/test_bert_cell.py +++ b/tests/ut/python/model/test_bert_cell.py @@ -317,7 +317,7 @@ test_case_cell_ops = [ initializer_range=0.02, dropout_prob=0.1), 'desc_inputs': [[1, 768], [1, 768]], - 'desc_bprop': [[1, 128, 768]]}), # maybe not right + 'desc_bprop': [[1, 768]]}), ('BertTransformer_2', { 'block': bert_trans(), 'desc_inputs': [[1, 128, 768], [1, 128, 128]]}), @@ -331,7 +331,7 @@ test_case_cell_ops = [ 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), Tensor(np.random.rand(128).astype(np.int32)), [128]], 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], - 'num_output': 3}), # maybe not right + 'num_output': 3}), ('BertModel_1', { 'block': BertModel(config=BertConfig(batch_size=1, @@ -342,7 +342,7 @@ test_case_cell_ops = [ 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), Tensor(np.random.rand(128).astype(np.int32)), [128]], 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], - 'num_output': 3}), # maybe not right + 'num_output': 3}), ('BertModel_2', { 'block': BertModel(config=BertConfig(batch_size=1, @@ -354,7 +354,7 @@ test_case_cell_ops = [ 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), Tensor(np.random.rand(128).astype(np.int32)), [128]], 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], - 'num_output': 3}), # maybe not right + 'num_output': 3}), ('BertPretrainingLoss', { 'block': BertPretrainingLoss(config=BertConfig(batch_size=1)), diff --git a/tests/ut/python/model/test_mix_precision.py b/tests/ut/python/model/test_mix_precision.py index 0a8b185e8c..0c762f42b9 100644 --- a/tests/ut/python/model/test_mix_precision.py +++ b/tests/ut/python/model/test_mix_precision.py @@ -175,7 +175,7 @@ class GetParamGrad(nn.Cell): def test_grad_conv_prelu(): shapes = [[64, 64, 112, 112]] - outshape = 
[[64, 64, 56, 56]] + outshape = [[64, 64, 112, 112]] net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True) inputs = [convert(shp, dtype=np.float16) for shp in shapes] sens_shape = outshape[0] diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index d6622e76f4..6121933d70 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -585,7 +585,7 @@ test_case_nn_ops = [ ('ReLUV2', { 'block': P.ReLUV2(), 'desc_inputs': [[1, 3, 4, 4]], - 'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}), + 'desc_bprop': [[1, 3, 4, 4], ([1, 1, 4, 4, 2], {'dtype': np.uint8})]}), ('ReLUGrad', { 'block': G.ReluGrad(), 'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]], @@ -626,7 +626,7 @@ test_case_nn_ops = [ ('MaxPoolWithArgmax', { 'block': P.MaxPoolWithArgmax(ksize=2, strides=2), 'desc_inputs': [[128, 32, 32, 64]], - 'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}), + 'desc_bprop': [[128, 32, 16, 32], ([128, 32, 4, 33], {'dtype': np.uint16})]}), ('SoftmaxCrossEntropyWithLogits', { 'block': P.SoftmaxCrossEntropyWithLogits(), 'desc_inputs': [[1, 10], [1, 10]], @@ -639,7 +639,7 @@ test_case_nn_ops = [ ('LogSoftmax', { 'block': P.LogSoftmax(), 'desc_inputs': [[64, 2]], - 'desc_bprop': [[160, 30522]]}), + 'desc_bprop': [[64, 2]]}), ('LogSoftmaxGrad', { 'block': G.LogSoftmaxGrad(), 'desc_inputs': [[16, 1234], [16, 1234]], @@ -648,7 +648,7 @@ test_case_nn_ops = [ ('LayerNorm', { 'block': P.LayerNorm(), 'desc_inputs': [[2, 16], [16], [16]], - 'desc_bprop': [[2, 16], [2, 16], [2, 16]]}), + 'desc_bprop': [[2, 16], [2, 1], [2, 1]]}), ('LayerNormGrad', { 'block': G.LayerNormGrad(), 'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]], @@ -845,7 +845,7 @@ test_case_nn_ops = [ 'block': P.OneHot(), 'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)], 'desc_inputs': [Tensor(np.array([64]).astype(np.int32))], - 'desc_bprop': [[64, 2]]}), + 'desc_bprop': [[1, 3]]}), ('ReduceProd_0', { 'block': P.ReduceProd(), 'desc_const': [0], @@ -950,7 +950,7 @@ test_case_array_ops = [ 'block': P.Cast(), 'desc_const': [mstype.int32], 'desc_inputs': [[2, 3, 4, 5]], - 'desc_bprop': [Tensor(np.ones((2, 3, 3, 5)).astype(np.int32))]}), + 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5)).astype(np.int32))]}), ('ExpandDims', { 'block': P.ExpandDims(), 'desc_const': [0], @@ -1002,12 +1002,12 @@ test_case_array_ops = [ 'desc_inputs': [ (Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)), Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)))], - 'desc_bprop': [[4, 2]]}), + 'desc_bprop': [([4, 2], {'dtype': np.int32})]}), ('ConcatV2_1', { 'block': P.Concat(axis=2), 'desc_inputs': [(Tensor(np.array([[[0, 1, 2]], [[2, 1, 2]]]).astype(np.int32)), Tensor(np.array([[[0, 1]], [[2, 1]]]).astype(np.int32)))], - 'desc_bprop': [[2, 1, 5]]}), + 'desc_bprop': [([2, 1, 5], {'dtype': np.int32})]}), ('ConcatV2_2', { 'block': NetForConcat(), 'desc_inputs': [[2, 2]], @@ -1042,7 +1042,7 @@ test_case_array_ops = [ ('Pack_2', { 'block': NetForPackInput(P.Pack()), 'desc_inputs':[[2, 2]], - 'desc_bprop':[[2, 2, 2]], + 'desc_bprop':[[1, 2, 2]], }), ('Pack_3', { 'block': NetForPackInput(P.Pack()), @@ -1077,7 +1077,7 @@ test_case_array_ops = [ ('SpaceToBatch_2', { 'block': P.SpaceToBatch(2, [[1, 1], [0, 4]]), 'desc_inputs': [[1, 3, 2, 2]], - 'desc_bprop': [[4, 3, 2, 4]], + 'desc_bprop': [[4, 3, 2, 3]], }), ('BatchToSpace_1', { 'block': P.BatchToSpace(2, [[0, 0], [0, 0]]), @@ -1124,7 +1124,7 @@ test_case_other_ops = [ 'desc_const': [(3, 3)], 'desc_inputs': (Tensor(np.ones((2, 2), np.int32)), 
Tensor(np.ones((2,), np.int32))), - 'desc_bprop': [[3, 3]]}), + 'desc_bprop': [([3, 3], {'dtype': np.int32})]}), ('SmoothL1Loss', { 'block': P.SmoothL1Loss(), 'desc_inputs': [[256, 4], [256, 4]], diff --git a/tests/ut/python/pynative_mode/test_cell_bprop.py b/tests/ut/python/pynative_mode/test_cell_bprop.py index c69b80412e..bd9f46d21d 100644 --- a/tests/ut/python/pynative_mode/test_cell_bprop.py +++ b/tests/ut/python/pynative_mode/test_cell_bprop.py @@ -229,12 +229,6 @@ class TwoInputBprop(nn.Cell): def bprop(self, x, y, out, dout): return 5 * x, 8 * y -class TwoInput(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.Mul() - def construct(self, x, y): - return self.op(x, y) class TwoInputWithParameter(nn.Cell): def __init__(self): @@ -301,8 +295,37 @@ class MulAddWithWrongOutputNum(nn.Cell): def construct(self, x, y): return 2 * x + y def bprop(self, x, y, out, dout): - return 2 * dout, 2 * y, out + return 2 * dout, def test_grad_mul_add_with_wrong_output_num(): mul_add = MulAddWithWrongOutputNum() - C.grad_all(mul_add)(1, 2) + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, 2) + +class MulAddWithWrongOutputType(nn.Cell): + def __init__(self): + super(MulAddWithWrongOutputType, self).__init__() + def construct(self, x, y): + return 2 * x + y + def bprop(self, x, y, out, dout): + return 2 * dout, 2 + +def test_grad_mul_add_with_wrong_output_type(): + mul_add = MulAddWithWrongOutputType() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) + + +class MulAddWithWrongOutputShape(nn.Cell): + def __init__(self): + super(MulAddWithWrongOutputShape, self).__init__() + self.ones = Tensor(np.ones([2,])) + def construct(self, x, y): + return 2 * x + y + def bprop(self, x, y, out, dout): + return 2, self.ones + +def test_grad_mul_add_with_wrong_output_shape(): + mul_add = MulAddWithWrongOutputShape() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index eb3b76765a..7e504c405f 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -32,6 +32,8 @@ from ....mindspore_test_framework.utils.check_gradient import ( OperationGradChecker, check_gradient, ScalarGradChecker) from ....mindspore_test_framework.utils.bprop_util import bprop import mindspore.context as context +from mindspore.ops._grad.grad_base import bprop_getters +from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer def setup_module(module): @@ -721,3 +723,94 @@ def test_grad_if_defer_inline(): inp = Tensor(np.ones([128, 96]).astype(np.float32)) grads = C.grad_all(network)(inp) assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) + +def test_bprop_with_wrong_output_num(): + class BpropWithWrongOutputNum(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum') + + def __call__(self, x, y): + return x + + def infer_shape(self, x_shape, yshape): + return x_shape + + def infer_dtype(self, x_type, y_type): + return x_type + + @bprop_getters.register(BpropWithWrongOutputNum) + def get_bprop_with_wrong_output_num(self): + """Generate bprop for BpropWithWrongOutputNum""" + def bprop(x, y, out, dout): + return (dout,) + return bprop + + class BpropWithWrongOutputNumCell(nn.Cell): + def __init__(self): + super(BpropWithWrongOutputNumCell, self).__init__() + def 
construct(self, x, y): + return BpropWithWrongOutputNum()(x, y) + with pytest.raises(TypeError): + C.grad_all(BpropWithWrongOutputNumCell())(1, 2) + +def test_bprop_with_wrong_output_type(): + class BpropWithWrongOutputType(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType') + + def __call__(self, x): + return x + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + return x_type + + @bprop_getters.register(BpropWithWrongOutputType) + def get_bprop_with_wrong_output_type(self): + """Generate bprop for BpropWithWrongOutputType""" + def bprop(x, out, dout): + return (1,) + return bprop + + class BpropWithWrongOutputTypeCell(nn.Cell): + def __init__(self): + super(BpropWithWrongOutputTypeCell, self).__init__() + def construct(self, x): + return BpropWithWrongOutputType()(x) + with pytest.raises(TypeError): + C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) + +def test_bprop_with_wrong_output_shape(): + class BpropWithWrongOutputShape(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape') + + def __call__(self, x): + return x + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + return x_type + + @bprop_getters.register(BpropWithWrongOutputShape) + def get_bprop_with_wrong_output_shape(self): + """Generate bprop for BpropWithWrongOutputShape""" + ones = Tensor(np.ones([2,]).astype(np.int32)) + def bprop(x, out, dout): + return (ones,) + return bprop + + class BpropWithWrongOutputShapeCell(nn.Cell): + def __init__(self): + super(BpropWithWrongOutputShapeCell, self).__init__() + def construct(self, x): + return BpropWithWrongOutputShape()(x) + with pytest.raises(TypeError): + C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) diff --git a/tests/ut/python/pynative_mode/test_insert_grad_of.py b/tests/ut/python/pynative_mode/test_insert_grad_of.py index a11c5fa2b1..0527365a98 100644 --- a/tests/ut/python/pynative_mode/test_insert_grad_of.py +++ b/tests/ut/python/pynative_mode/test_insert_grad_of.py @@ -79,7 +79,7 @@ def test_InsertGradientOf_2(): summary = P.ScalarSummary() def debug_gradient(dx): """ debug_gradient """ - dx = summary("dx: ", dx) + summary("dx: ", dx) return dx debug = P.InsertGradientOf(debug_gradient)
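
Illustrative usage (a minimal sketch, not part of the diff): the CheckBprop node that KPrim::CheckBprop now inserts turns a malformed user-defined bprop into a hard TypeError at compile time, instead of the old DEBUG-only message removed from DFunctor::KUserDefined. The example below is distilled from the new test_cell_bprop.py cases; the import paths and grad_all usage are assumed to match what that test file already uses.

    import pytest
    import mindspore.nn as nn
    from mindspore.ops import composite as C


    class MulAddWithWrongOutputNum(nn.Cell):
        """construct() takes two inputs, but the custom bprop returns only one gradient."""

        def construct(self, x, y):
            return 2 * x + y

        def bprop(self, x, y, out, dout):
            # Missing the gradient for y; the inserted CheckBprop flags the size mismatch.
            return (2 * dout,)


    def test_wrong_bprop_output_num_is_rejected():
        mul_add = MulAddWithWrongOutputNum()
        # With this patch the mismatch raises TypeError during graph compilation.
        with pytest.raises(TypeError):
            C.grad_all(mul_add)(1, 2)

The same mechanism covers dtype and shape mismatches (see MulAddWithWrongOutputType / MulAddWithWrongOutputShape in the patch), and CheckBprop is later stripped by the new check_bprop_eliminate_ pass so it has no runtime cost.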