From 947e19b8396e254032eb0b8dbf5d827d18668703 Mon Sep 17 00:00:00 2001
From: fary86
Date: Sat, 22 Aug 2020 10:28:33 +0800
Subject: [PATCH] Fix bug of switch layer join

---
 .../pipeline/jit/parse/data_converter.cc      | 96 ++++++++++++++++++-
 .../ccsrc/pipeline/jit/parse/parse_base.h     |  2 +-
 .../pipeline/jit/static_analysis/prim.cc      |  5 +-
 mindspore/ccsrc/utils/convert_utils.cc        | 24 ++++-
 mindspore/core/abstract/abstract_value.cc     |  6 +-
 mindspore/core/abstract/utils.cc              | 12 ++-
 mindspore/core/ir/scalar.h                    | 12 +--
 .../utils/check_gradient.py                   |  2 +-
 tests/ut/python/ops/test_control_ops.py       | 70 ++++++++++++++
 tests/ut/python/ops/test_ops.py               |  7 +-
 tests/ut/python/ops/test_ops_reid.py          | 12 +--
 .../python/parameter_feature/test_var_grad.py | 16 ++--
 12 files changed, 229 insertions(+), 35 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
index f9be48c85d..72e17fed95 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
@@ -283,9 +283,99 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
   MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
   return false;
 }
+
+bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
+  if (dtype == nullptr) {
+    *data = std::make_shared<Int32Imm>(obj);
+    return true;
+  }
+
+  auto int_dtype = dyn_cast<Int>(dtype);
+  if (int_dtype != nullptr) {
+    switch (int_dtype->nbits()) {
+      case 8:
+        *data = std::make_shared<Int8Imm>(static_cast<int8_t>(obj));
+        break;
+      case 16:
+        *data = std::make_shared<Int16Imm>(obj);
+        break;
+      case 32:
+        *data = std::make_shared<Int32Imm>(obj);
+        break;
+      case 64:
+        *data = std::make_shared<Int64Imm>(obj);
+        break;
+      default:
+        *data = std::make_shared<Int32Imm>(obj);
+    }
+    return true;
+  }
+
+  auto uint_dtype = dyn_cast<UInt>(dtype);
+  if (uint_dtype != nullptr) {
+    switch (uint_dtype->nbits()) {
+      case 8:
+        *data = std::make_shared<UInt8Imm>(obj);
+        break;
+      case 16:
+        *data = std::make_shared<UInt16Imm>(obj);
+        break;
+      case 32:
+        *data = std::make_shared<UInt32Imm>(obj);
+        break;
+      case 64:
+        *data = std::make_shared<UInt64Imm>(obj);
+        break;
+      default:
+        *data = std::make_shared<UInt32Imm>(obj);
+    }
+    return true;
+  }
+
+  auto float_dtype = dyn_cast<Float>(dtype);
+  if (float_dtype != nullptr) {
+    switch (float_dtype->nbits()) {
+      case 32:
+        *data = std::make_shared<FP32Imm>(obj);
+        break;
+      case 64:
+        *data = std::make_shared<FP64Imm>(obj);
+        break;
+      default:
+        *data = std::make_shared<FP32Imm>(obj);
+    }
+    return true;
+  }
+
+  return false;
+}
+
+bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
+  if (dtype == nullptr) {
+    *data = std::make_shared<FP32Imm>(obj);
+    return true;
+  }
+
+  auto float_dtype = dyn_cast<Float>(dtype);
+  if (float_dtype == nullptr) {
+    return false;
+  }
+
+  switch (float_dtype->nbits()) {
+    case 32:
+      *data = std::make_shared<FP32Imm>(obj);
+      break;
+    case 64:
+      *data = std::make_shared<FP64Imm>(obj);
+      break;
+    default:
+      *data = std::make_shared<FP32Imm>(obj);
+  }
+  return true;
+}
 }  // namespace
-bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) {
+bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
   // check parameter valid
   if (data == nullptr) {
     MS_LOG(ERROR) << "Data is null pointer";
     return false;
@@ -299,9 +389,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
   } else if (py::isinstance<py::bool_>(obj)) {
     converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
   } else if (py::isinstance<py::int_>(obj)) {
-    converted = std::make_shared<Int32Imm>(py::cast<int>(obj));
+    ret = ConvertIntegerWithType(py::cast<int>(obj), &converted, dtype);
   } else if (py::isinstance<py::float_>(obj)) {
-    converted = std::make_shared<FP32Imm>(py::cast<float>(obj));
+    ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
   } else if (py::isinstance<py::str>(obj)) {
     converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
   } else if (py::isinstance<py::dict>(obj)) {
diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h
index ddc774e3d5..2581821413 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h
@@ -139,7 +139,7 @@ enum ClassInstanceTypeDef {
 };
 
 // Convert python object to ValuePtr
-bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false);
+bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, TypePtr dtype = nullptr);
 
 // Convert python obj to graph
 FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index b25fcdd38b..23a95553c3 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -407,9 +407,9 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
 
 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
   // Convert to AbstractValue based on type and shape
+  auto out_dtype = output["dtype"];
   if (output["value"].is_none()) {
     auto out_shape = output["shape"];
-    auto out_dtype = output["dtype"];
     py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
     py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
 
@@ -417,7 +417,8 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
-  bool converted = parse::ConvertData(output["value"], &converted_ret);
+  TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
+  bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
   if (!converted) {
     MS_LOG(EXCEPTION) << "Convert data failed";
   }
diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc
index 0b41330d58..6fb4ff13ef 100644
--- a/mindspore/ccsrc/utils/convert_utils.cc
+++ b/mindspore/ccsrc/utils/convert_utils.cc
@@ -45,14 +45,34 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
     MS_LOG(EXCEPTION) << "value is null";
   }
   py::object ret;
-  if (value->isa<Int32Imm>()) {
-    MS_LOG(DEBUG) << "int";
+  if (value->isa<Int8Imm>()) {
+    MS_LOG(DEBUG) << "int8";
+    py::int_ v = value->cast<Int8ImmPtr>()->value();
+    ret = v;
+  } else if (value->isa<Int16Imm>()) {
+    MS_LOG(DEBUG) << "int16";
+    py::int_ v = value->cast<Int16ImmPtr>()->value();
+    ret = v;
+  } else if (value->isa<Int32Imm>()) {
+    MS_LOG(DEBUG) << "int32";
     py::int_ v = value->cast<Int32ImmPtr>()->value();
     ret = v;
   } else if (value->isa<Int64Imm>()) {
     MS_LOG(DEBUG) << "int64";
     py::int_ v = value->cast<Int64ImmPtr>()->value();
     ret = v;
+  } else if (value->isa<UInt8Imm>()) {
+    MS_LOG(DEBUG) << "uint8";
+    py::int_ v = value->cast<UInt8ImmPtr>()->value();
+    ret = v;
+  } else if (value->isa<UInt16Imm>()) {
+    MS_LOG(DEBUG) << "uint16";
+    py::int_ v = value->cast<UInt16ImmPtr>()->value();
+    ret = v;
+  } else if (value->isa<UInt32Imm>()) {
+    MS_LOG(DEBUG) << "uint32";
+    py::int_ v = value->cast<UInt32ImmPtr>()->value();
+    ret = v;
   } else if (value->isa<UInt64Imm>()) {
     MS_LOG(DEBUG) << "uint64";
     py::int_ v = value->cast<UInt64ImmPtr>()->value();
     ret = v;
diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc
index 154122c5aa..389f269d5a 100644
--- a/mindspore/core/abstract/abstract_value.cc
+++ b/mindspore/core/abstract/abstract_value.cc
@@ -97,8 +97,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   }
   auto value_self = GetValueTrack();
   MS_EXCEPTION_IF_NULL(value_self);
-  ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
   TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
+  if (res_type == kAnyType) {
+    MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString()
+                            << ", type2 = " << other->GetTypeTrack()->ToString();
+  }
+  ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
   if (res_value == value_self) {
     return shared_from_base<AbstractScalar>();
   }
diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc
index 20eeab0de5..3110323a55 100644
--- a/mindspore/core/abstract/utils.cc
+++ b/mindspore/core/abstract/utils.cc
@@ -50,9 +50,17 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
   if (*shape1 == *shape2) {
     return shape1;
   }
+  // lengths of two shapes are not same, join failed
   if (shape1->shape().size() != shape2->shape().size()) {
-    MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString();
-    return shape1;
+    // special case: shape(1), shape() -> shape(1)
+    if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
+      return shape1;
+    }
+    if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
+      return shape2;
+    }
+    MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
+                             << ", shape2 = " << shape2->ToString();
   }
   std::vector<int> dims;
   bool has_dynamic_shape = false;
diff --git a/mindspore/core/ir/scalar.h b/mindspore/core/ir/scalar.h
index 62c5f35ba5..7d76bcc1c5 100644
--- a/mindspore/core/ir/scalar.h
+++ b/mindspore/core/ir/scalar.h
@@ -105,7 +105,7 @@ class Int8Imm : public IntergerImm {
 
   std::string DumpText() const override {
     std::ostringstream oss;
-    oss << "I8(" << v_ << ")";
+    oss << "I8(" << int(v_) << ")";
     return oss.str();
   }
 
@@ -131,7 +131,7 @@ class Int16Imm : public IntergerImm {
 
   std::string DumpText() const override {
     std::ostringstream oss;
-    oss << "I16(" << v_ << ")";
+    oss << "I16(" << int(v_) << ")";
     return oss.str();
   }
 
@@ -157,7 +157,7 @@ class Int32Imm : public IntergerImm {
 
   std::string DumpText() const override {
    std::ostringstream oss;
-    oss << "I32(" << v_ << ")";
+    oss << "I32(" << int(v_) << ")";
     return oss.str();
   }
 
@@ -211,7 +211,7 @@ class UInt8Imm : public IntergerImm {
 
   std::string DumpText() const override {
     std::ostringstream oss;
-    oss << "U8(" << v_ << ")";
+    oss << "U8(" << unsigned(v_) << ")";
     return oss.str();
   }
 
@@ -239,7 +239,7 @@ class UInt16Imm : public IntergerImm {
 
   std::string DumpText() const override {
     std::ostringstream oss;
-    oss << "U16(" << v_ << ")";
+    oss << "U16(" << unsigned(v_) << ")";
     return oss.str();
   }
 
@@ -267,7 +267,7 @@ class UInt32Imm : public IntergerImm {
 
   std::string DumpText() const override {
     std::ostringstream oss;
-    oss << "U32(" << v_ << ")";
+    oss << "U32(" << unsigned(v_) << ")";
     return oss.str();
   }
 
diff --git a/tests/mindspore_test_framework/utils/check_gradient.py b/tests/mindspore_test_framework/utils/check_gradient.py
index c2252b8a78..81490e7ee1 100644
--- a/tests/mindspore_test_framework/utils/check_gradient.py
+++ b/tests/mindspore_test_framework/utils/check_gradient.py
@@ -324,7 +324,7 @@ class ScalarGradChecker(_GradChecker):
         self.input_selector = [i for i in range(self.nin)]
 
     def get_sens(self, i):
-        return 1
+        return 1.0
 
     def check_against_numeric(self, out_index):
         args = list(self.args)
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index 483485ea64..1b5b44755c 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -911,3 +911,73 @@ def test_recursive_call():
     with pytest.raises(RuntimeError):
         net(input_data)
     context.set_context(max_call_depth=old_max_call_depth)
+
+
+def test_switch_layer_shape_join_failed():
+    class AddFuncNet(nn.Cell):
+        def __init__(self, funcs, new_func):
+            super(AddFuncNet, self).__init__()
+            self.funcs = funcs
+            self.new_func = new_func
+
+        def construct(self, i, inputs):
+            final_funcs = self.funcs + (self.new_func,)
+            x = final_funcs[i](inputs)
+            return x
+
+    class ReLUTuple(nn.Cell):
+        def __init__(self):
+            super(ReLUTuple, self).__init__()
+            self.op = nn.ReLU()
+
+        def construct(self, x):
+            return self.op(x[0])
+
+    func1 = nn.Softmax()
+    func2 = nn.ReLU()
+    func3 = ReLUTuple()
+
+    funcs = (func1, func2)
+    net = AddFuncNet(funcs, func3)
+
+    inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+    with pytest.raises(ValueError) as err:
+        net(i, inp)
+
+
+def test_switch_layer_dtype_join_failed():
+    class Cast(nn.Cell):
+        def __init__(self, dtype):
+            super(Cast, self).__init__()
+            self.op = P.Cast()
+            self.dtype = dtype
+
+        def construct(self, x):
+            y = self.op(x, self.dtype)
+            return y + y
+
+    class SwitchNegNet(nn.Cell):
+        def __init__(self, funcs):
+            super(SwitchNegNet, self).__init__()
+            self.funcs = funcs
+            self.op = P.Neg()
+
+        def construct(self, i, inputs):
+            x = self.funcs[i](inputs)
+            x = self.op(x)
+            return x
+
+    func1 = nn.ReLU()
+    func2 = Cast(mstype.int32)
+    funcs = (func1, func2)
+    net = SwitchNegNet(funcs)
+
+    inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    i = Tensor(0, mstype.int32)
+
+    with pytest.raises(TypeError) as err:
+        net(i, inp)
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index b64e42dbb4..49ed9625b5 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -33,6 +33,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
     pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
 from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
     import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
+from ....ops_common import convert
 
 class InputBackward(nn.Cell):
     def __init__(self, network):
@@ -1699,7 +1700,7 @@ test_case_nn_ops = [
     ('ResizeBilinear', {
         'block': P.ResizeBilinear((5, 5)),
         'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)],
-        'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)]}),
+        'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float32)]}),
     ('ResizeBilinearGrad', {
         'block': G.ResizeBilinearGrad(),
         'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32), Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)],
@@ -1708,7 +1709,7 @@ test_case_nn_ops = [
     ('ROIAlign', {
         'block': P.ROIAlign(7, 7, 0.03125, 2),
         'desc_inputs': [[2, 256, 192, 320], [1024, 5]],
-        'desc_bprop': [[7, 7]]}),
+        'desc_bprop': [[1024, 256, 7, 7]]}),
     ('ROIAlignGrad', {
         'block': G.ROIAlignGrad((1, 1, 1, 1), 2, 2, 0.5, 2),
         'desc_inputs': [[1, 1, 2, 2], [1, 5]],
@@ -2311,7 +2312,7 @@ test_case_other_ops = [
     ('IOU', {
         'block': P.IOU(),
         'desc_inputs': [Tensor(np.ones((256, 4), np.float16)), Tensor(np.ones((128, 4), np.float16))],
-        'desc_bprop': [[128, 256]]}),
+        'desc_bprop': [convert([128, 256], np.float16)]}),
     ('Summary', {
         'block': SummaryNet(),
         'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
diff --git a/tests/ut/python/ops/test_ops_reid.py b/tests/ut/python/ops/test_ops_reid.py
index b3b3e1d470..5cf320d5c9 100644
--- a/tests/ut/python/ops/test_ops_reid.py
+++ b/tests/ut/python/ops/test_ops_reid.py
@@ -118,29 +118,29 @@ test_case_reid_ops = [
         'desc_inputs': [[256, 8]],
         'desc_bprop': [[256, 8]]}),
     ('Pow', {
-        'block': P.Pow(),  # 输入有标量插件产生了段错误。
+        'block': P.Pow(),
         'desc_const': [2.0],
         'desc_inputs': [[1, 512]],
         'desc_bprop': [[1, 512]]}),
     ('LogicalNot', {
         'block': P.LogicalNot(),
         'desc_inputs': [convert([256], np.bool_)],
-        'desc_bprop': [[256]]}),  # 自定义算子 input bool没转换,gongchen提单。
+        'desc_bprop': [convert([256], np.bool_)]}),
     ('Equal', {
         'block': P.Equal(),
         'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
-        'desc_bprop': [[256]]}),
+        'desc_bprop': [convert([256], np.bool_)]}),
     ('Greater', {
         'block': P.Greater(),
         'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
-        'desc_bprop': [[256]]}),
+        'desc_bprop': [convert([256], np.bool_)]}),
     ('Dropout', {
         'block': nn.Dropout(),
         'desc_inputs': [[1, 512, 7, 7]],
-        'desc_bprop': [[1, 512, 7, 7]]}),  # 输入有标量插件产生了段错误。
+        'desc_bprop': [[1, 512, 7, 7]]}),
     ('MatMul', {
         'block': P.MatMul(),
-        'desc_inputs': [[64, 512], [512, 64]],  # fp16不行。很有问题。
+        'desc_inputs': [[64, 512], [512, 64]],
         'desc_bprop': [[64, 64]]}),
     ('Maximum', {
         'block': P.Maximum(),
diff --git a/tests/ut/python/parameter_feature/test_var_grad.py b/tests/ut/python/parameter_feature/test_var_grad.py
index f0358394e7..0902aff522 100644
--- a/tests/ut/python/parameter_feature/test_var_grad.py
+++ b/tests/ut/python/parameter_feature/test_var_grad.py
@@ -77,8 +77,8 @@ class Bprop(Cell):
         self.grad = grad_op
         self.with_sens = False
         self.sens = sens
-        if sens:
-            self.sens = Tensor(sens, dtype=mstype.float32)
+        if sens is not None:
+            self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32)
             self.with_sens = True
 
     def construct(self, *inputs):
@@ -108,7 +108,7 @@ def test_all_var_args_grad_with_sens():
 
     x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
-    sens = Tensor(1.0, dtype=mstype.float32)
+    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     net = VarNet(SecondNet())
     grad_net = GradNet(net)
     _ = grad_net(x, y, sens)
@@ -160,7 +160,7 @@ def test_grad_all_var_args_with_sens():
 
     x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
-    sens = Tensor(1.0, dtype=mstype.float32)
+    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     net = VarNet(SecondNet())
     grad_net = GradNet(net)
     _ = grad_net(x, y, sens)
@@ -178,7 +178,7 @@ def test_grad_var_args_with_sens():
 
     x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
-    sens = Tensor(1.0, dtype=mstype.float32)
+    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     net = VarNet(SecondNet())
     grad_net = GradNet(net)
     _ = grad_net(x, y, sens)
@@ -237,7 +237,7 @@ def test_var_args_grad():
 
     x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
-    sens = Tensor(1.0, dtype=mstype.float32)
+    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     net = VarNet(SecondNet())
     grad_net = GradNet(net)
     _ = grad_net(x, y, sens)
@@ -285,14 +285,14 @@ def test_grad_within_if_else():
             self.net = net
             grad_op = C.GradOperation(
                 name='grad', get_all=False, get_by_list=True, sens_param=True)
-            self.grad = Bprop(self.net, True, self.weights, grad_op, 1.0)
+            sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+            self.grad = Bprop(self.net, True, self.weights, grad_op, sens)
 
         def construct(self, *inputs):
             return self.grad(*inputs)
 
     x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
     y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
-    _ = Tensor(1.0, dtype=mstype.float32)
     net = VarNet(SecondNet())
     grad_net = GradNet(net)
     out = grad_net(x, y)