From acec35d4d4bfb69064238f4283ab8be1ffd89b06 Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 11 Jan 2021 10:38:22 +0800 Subject: [PATCH] support non tensor inputs --- .../frontend/operator/composite/composite.cc | 58 ++++----- .../frontend/operator/composite/composite.h | 8 +- mindspore/ccsrc/pipeline/jit/pass.cc | 18 +++ mindspore/ccsrc/pipeline/jit/pipeline.cc | 29 ++--- mindspore/common/api.py | 6 +- .../multitype_ops/zeros_like_impl.py | 22 +++- ...test_ms_function_pass_non_tensor_inputs.py | 60 ++++++++++ ...t_outermost_net_pass_non_tensor_inputs.py} | 16 +-- .../ut/python/pynative_mode/ops/test_grad.py | 20 +--- .../python/pynative_mode/test_framstruct.py | 113 +----------------- .../pynative_mode/test_high_order_grad.py | 22 ---- .../python/pynative_mode/test_parse_method.py | 2 +- 12 files changed, 153 insertions(+), 221 deletions(-) create mode 100644 tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py rename tests/ut/python/pipeline/parse/{test_outermost_net_pass_scalar_tuple_list_dict.py => test_outermost_net_pass_non_tensor_inputs.py} (79%) diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index a2fa345eca..e498e730ad 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -62,8 +62,6 @@ ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalar {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe}, {"__ge__", kPrimScalarGe}}; -const MetaFuncGraphPtr kTail = std::make_shared("tail"); - // copy from python API: reduce. // Apply a function of two arguments cumulatively to the items of a sequence, // from left to right, so as to reduce the sequence to a single value.For example, @@ -384,8 +382,8 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { .def(py::init<>()); })); -FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { - MS_EXCEPTION_IF_NULL(a_tuple); +FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) { + MS_EXCEPTION_IF_NULL(sequeue); FuncGraphPtr ret = std::make_shared(); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); @@ -393,31 +391,24 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu AnfNodePtr ptrTup = ret->add_parameter(); std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeTuple)); - - int64_t tuple_size = SizeToLong(a_tuple->size()); - for (int64_t i = 1; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)})); + PrimitivePtr op = nullptr; + if (sequeue->isa()) { + elems.push_back(NewValueNode(prim::kPrimMakeTuple)); + op = prim::kPrimTupleGetItem; + } else { + elems.push_back(NewValueNode(prim::kPrimMakeList)); + op = prim::kPrimListGetItem; } - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { - MS_EXCEPTION_IF_NULL(a_list); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ret->debug_info()->set_name("tail"); - AnfNodePtr ptrList = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeList)); - - int64_t list_size = SizeToLong(a_list->size()); - for (int64_t i = 1; i < list_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)})); + for (size_t i = 1; i < sequeue->size(); ++i) { + if (do_grad_) { + MS_EXCEPTION_IF_NULL((*sequeue)[i]); + if ((*sequeue)[i]->isa()) { + elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); + } + } else { + elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); + } } ret->set_output(ret->NewCNode(elems)); @@ -430,14 +421,8 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) } AbstractBasePtr a = args_spec_list[0]; - abstract::AbstractTuplePtr a_tuple = dyn_cast(a); - if (a_tuple != nullptr) { - return GenerateTupleFuncGraph(a_tuple); - } - - abstract::AbstractListPtr a_list = dyn_cast(a); - if (a_list != nullptr) { - return GenerateListFuncGraph(a_list); + if (a->isa() || a->isa()) { + return GenerateSequeueFuncGraph(a->cast()); } MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString(); @@ -614,7 +599,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An CNodePtr inputs_bprop = nullptr; if (get_all_) { - inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp}); + TailPtr tail = std::make_shared("tail", true); + inputs_bprop = func_graph->NewCNode({NewValueNode(tail), ptr_bapp}); } // Gradients wrt inputs and parameters diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h index 21a4588958..981efcd3e2 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.h +++ b/mindspore/ccsrc/frontend/operator/composite/composite.h @@ -99,15 +99,17 @@ extern ValuePtr kCompositeHyperMap; class Tail : public MetaFuncGraph { public: - explicit Tail(const std::string &name) : MetaFuncGraph(name) {} + explicit Tail(const std::string &name, bool do_grad = false) : MetaFuncGraph(name), do_grad_(do_grad) {} ~Tail() override = default; MS_DECLARE_PARENT(Tail, MetaFuncGraph) FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); - FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); + FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue); friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } + + private: + bool do_grad_; }; using TailPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 4a9cd00d5d..bbeab5cd59 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -446,10 +446,28 @@ bool TransformTopGraphPass(const ResourcePtr &res) { bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); } +void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + std::vector new_paras; + for (const auto ¶m : func_graph->parameters()) { + auto param_node = param->cast(); + if (param_node->has_default()) { + new_paras.push_back(param_node); + continue; + } + AbstractBasePtr par_abs = param_node->abstract(); + if (par_abs->isa()) { + new_paras.push_back(param_node); + } + } + func_graph->set_parameters(new_paras); +} + bool ValidatePass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); Validate(func_graph); + UpdateFuncGraphParameter(func_graph); return true; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 362c3af8c4..632e3ce247 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -69,6 +69,10 @@ namespace pipeline { using Tensor = mindspore::tensor::Tensor; using MetaTensor = mindspore::tensor::MetaTensor; using TensorOrderMap = std::map>; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractListPtr; using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTensorPtr; using mindspore::abstract::AbstractTuple; @@ -93,15 +97,10 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) return oss.str(); } -void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) { - MS_EXCEPTION_IF_NULL(arg); - auto tensor_arg = arg->cast(); - if (tensor_arg == nullptr) { - MS_EXCEPTION(TypeError) << "For 'graph mode', the " << idx << "th arg: " << arg->ToString() << " is not a tensor."; - } - if (tensor_arg->is_parameter()) { - MS_EXCEPTION(TypeError) << "The inputs could not be Parameter."; - } +AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { + MS_EXCEPTION_IF_NULL(value); + bool broaden = value->isa(); + return abstract::FromValue(value, broaden); } } // namespace @@ -117,8 +116,7 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_mapisa() || converted->isa(); - args_spec.push_back(abstract::FromValue(converted, broaden)); + args_spec.push_back(ArgsToAbstract(converted)); } if (g_args_cache.count(args_spec) == 0) { static int64_t key = 0; @@ -484,11 +482,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons if (!succ) { MS_LOG(EXCEPTION) << "Args convert error"; } - if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { - CheckArgIsTensor(converted, i); - } - bool broaden = true; - args_spec.push_back(abstract::FromValue(converted, broaden)); + args_spec.push_back(ArgsToAbstract(converted)); } resource->set_args_spec(args_spec); @@ -814,9 +808,6 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { if (!parse::ConvertData(args[i], &converted)) { MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; } - if (!converted->isa()) { - MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor."; - } } } return *ret_val; diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 4af7e483cd..1909173a89 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -208,7 +208,11 @@ class _MindSporeFunction: if context.get_context("precompile_only"): return None - return self._executor(args_list, phase) + new_inputs = [] + for i in args_list: + if isinstance(i, Tensor): + new_inputs.append(i) + return self._executor(tuple(new_inputs), phase) def ms_function(fn=None, obj=None, input_signature=None): diff --git a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py index 7872952aec..067e0cd013 100644 --- a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py @@ -18,7 +18,6 @@ from ...composite import base from ... import functional as F - zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True) """ `zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type @@ -31,11 +30,13 @@ def _zeros_like_scala(x): """Returns 0 which has the same dtype as x where x is a scalar.""" return 0 + @zeros_like_leaf.register("Bool") def _zeros_like_bool(x): """Returns False if x is a bool.""" return False + newenv = base.EnvInstance_() @@ -100,6 +101,25 @@ def _zeros_like_abstract_error(x): return x +@zeros_like_leaf.register("Dictionary") +def _zeros_like_dict(x): + """ + Derivation of a AbstractError. + + Args: + x (dict): the input + + Returns: + dict, keys are same as input's keys, and value are same as zeros_like of input'value. + """ + keys = x.keys() + values = x.values() + new_values = () + for ele in values: + new_values += (zeros_like_leaf(ele),) + return F.make_dict(keys, new_values) + + # zeros_like is an object that will generate graph of zero_like operation for different type zeros_like = base.HyperMap(zeros_like_leaf) """`zeros_like` is an object that will generate graph of `zero_like` operation for different type.""" diff --git a/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py new file mode 100644 index 0000000000..35c4c64486 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test ms_function pass non_tensor inputs""" +import numpy as np + +from mindspore import Tensor, ms_function, Parameter +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True) + + +@ms_function +def compute(x, y, tuple_p, list_q, dict_w): + return x + y - tuple_p[0] + list_q[1] - dict_w["x"] + + +def test_scalar_compute(): + int_x = 1 + int_y = 2 + p = (3, 4) + q = [5, 6] + w = {"x": 7, "y": 8} + ret = compute(int_x, int_y, p, q, w) + assert ret == -1 + + +def test_tensor_compute(): + tensor_x = Tensor(np.ones((2, 3, 4), np.float32)) + tensor_y = Tensor(np.ones((2, 3, 4), np.float32) * 2) + p = (Tensor(np.ones((2, 3, 4), np.float32) * 3), Tensor(np.ones((2, 3, 4), np.float32) * 4)) + q = [Tensor(np.ones((2, 3, 4), np.float32) * 5), Tensor(np.ones((2, 3, 4), np.float32) * 6)] + w = {"x": Tensor(np.ones((2, 3, 4), np.float32) * 7), "y": Tensor(np.ones((2, 3, 4), np.float32) * 8)} + compute(tensor_x, tensor_y, p, q, w) + + +@ms_function +def tensor_reduce(tensor_x, axis, tensor_y): + reduce_sum = P.ReduceSum() + ret = reduce_sum(tensor_x, axis) + tensor_y + return ret + + +def test_tensor_reduce(): + tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32)) + axis = (0, 1) + tensor_y = Parameter(Tensor(np.ones((4, 5), np.float32) * 2)) + tensor_reduce(tensor_x, axis, tensor_y) diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py similarity index 79% rename from tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py rename to tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py index 1964ad63e3..85fc8504bd 100644 --- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py +++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" test outermost net pass scalar tuple list dict""" -import pytest +""" test outermost net pass non_tensor inputs""" import numpy as np import mindspore.nn as nn @@ -28,7 +27,7 @@ def test_outermost_net_pass_scalar_tuple_list_dict(): class TestNet(nn.Cell): def __init__(self): super(TestNet, self).__init__() - self.support_non_tensor_inputs = True + self.support_non_tensor_inputs = False def construct(self, tuple_a, z, list_m, w, s, dict_n): return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] @@ -58,12 +57,5 @@ def test_outermost_net_pass_scalar_tuple_list_dict(): forward_net(arg_t1, z, arg_l1, x, 6, args_d1) grad_net = GradNet(forward_net) - with pytest.raises(TypeError) as err: - grad_net(arg_t0, z, arg_l0, w, 6, args_d0) - assert "For 'graph mode', the 0th arg" in str(err.value) - - grad_net.support_non_tensor_inputs = True - with pytest.raises(ValueError) as err: - grad_net(arg_t0, z, arg_l0, w, 6, args_d0) - assert "Not support set 'support_non_tensor_inputs' to the 'True' for grad net, only support forward net." \ - in str(err.value) + grad_net(arg_t0, z, arg_l0, w, 6, args_d0) + grad_net(arg_t1, z, arg_l1, x, 6, args_d1) diff --git a/tests/ut/python/pynative_mode/ops/test_grad.py b/tests/ut/python/pynative_mode/ops/test_grad.py index b928690793..1f47a7e2fe 100644 --- a/tests/ut/python/pynative_mode/ops/test_grad.py +++ b/tests/ut/python/pynative_mode/ops/test_grad.py @@ -20,9 +20,9 @@ import mindspore.ops.operations as P from mindspore import Tensor, context from mindspore.common.api import ms_function from mindspore.ops import composite as C -from mindspore.ops import functional as F from ...ut_filter import non_graph_engine + # pylint: disable=unused-argument def setup_module(module): context.set_context(mode=context.PYNATIVE_MODE) @@ -86,24 +86,6 @@ def test_cast_grad(): assert np.all(gout[0].asnumpy() == expect) -def test_scalar_cast_grad(): - """ test_scalar_cast_grad """ - input_x = 255.5 - input_t = ms.int8 - - def fx_cast(x): - output = F.scalar_cast(x, input_t) - return output - - @ms_function - def grad_fx_cast(input_x): - return grad(fx_cast)(input_x) - - gfn = grad_fx_cast(input_x) - expect_dx = 1 - assert gfn == expect_dx - - @non_graph_engine def test_reshape_grad(): """ test_reshape_grad """ diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index ab5a80bbb7..781dc30698 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -14,31 +14,28 @@ # ============================================================================ """ test_framstruct """ import numpy as np -import pytest import mindspore as ms import mindspore.nn as nn from mindspore import context from mindspore.common import dtype as mstype from mindspore.common.parameter import Parameter, ParameterTuple -from mindspore.common.tensor import Tensor from mindspore.ops import composite as C from mindspore.ops import operations as P from ..ut_filter import non_graph_engine from ....mindspore_test_framework.utils.check_gradient import ( ms_function, check_jacobian, Tensor, NNGradChecker, - OperationGradChecker, check_gradient, ScalarGradChecker) + OperationGradChecker, check_gradient) context.set_context(mode=context.PYNATIVE_MODE) + def setup_module(module): context.set_context(mode=context.PYNATIVE_MODE) -grad = C.GradOperation() grad_all = C.GradOperation(get_all=True) grad_by_list = C.GradOperation(get_by_list=True) -grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) @ms_function @@ -79,9 +76,7 @@ def dynamic_make_tuple(x, lower, upper): def test_dynamic_make_tuple(): - # Dynamicly recursively creating static type is invalid in mindspore, as mindspore is a static language. - with pytest.raises(RuntimeError): - dynamic_make_tuple(2, 1, 5) + assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2) def test_make_tuple(): @@ -273,15 +268,6 @@ def rec(x): return rec(x - 1) return x -@ms_function -def grad_rec(input_x): - return grad(rec)(input_x) - -def test_grad_rec(): - """ test_grad_rec """ - res = grad_rec(3) - assert res == 1 - def test_me_rec(): """ test_me_rec """ @@ -303,13 +289,6 @@ def test_while2(): assert res == 6 -def test_grad_while2(): - @ms_function - def df_t2_while(input_x, input_y): - return grad(t2_while)(input_x, input_y) - assert df_t2_while(2, 3) == 3 - - def if_test(a, b): """ if_test """ if a > b: @@ -327,24 +306,6 @@ def test_grad_if(): assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0) -# While loop is not unrolled in forward and backward graphs. -def test_dont_unroll_while(): - def dont_unroll_while(x, y): - i = 2 - out = y - x - while i < 10: - out = mul(x, y) - i = i + 1 - return out - - @ms_function() - def invoke_while(x, y): - return grad(dont_unroll_while)(x, y) - - res = invoke_while(2, 3) - assert res == 3 - - class ConvNet(nn.Cell): def __init__(self): super(ConvNet, self).__init__() @@ -445,13 +406,6 @@ def test_factorial(): assert res == 6 -def test_grad_factorial(): - @ms_function - def df_factorial(x): - return grad(factorial)(x) - assert df_factorial(3) == 11 - - @ms_function def factorial2(n): """ factorial """ @@ -523,17 +477,13 @@ def _for(x): ret = ret * i return ret + @ms_function def grad_for(x): """ grad_for """ return grad_all(_for)(x) -def test_grad_for(): - """ test_grad_for """ - assert grad_for(5) == (60,) - - @ms_function def try_tail(x): """ try_tail """ @@ -675,15 +625,6 @@ def test_arithmetic_simplify_08(): assert np.all(res.asnumpy() == expect) -def test_ScalarGradChecker(): - """ test_ScalarGradChecker """ - - def scalar_f(x, y): - return x * y - - check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1) - - def test_GradCheckerPrimitive(): """ test_GradCheckerPrimitive """ matmul = P.MatMul() @@ -737,15 +678,6 @@ def test_OperationGradChecker(): input_selector=[1], sampling_times=2) -def test_ScalarJacobianChecker(): - """ test_ScalarJacobianChecker """ - - def scalar_f(x, y): - return x * y - - check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0]) - - def test_OperationJacobianChecker(): """ test_OperationJacobianChecker """ @@ -795,13 +727,6 @@ def multi_outputs(x, y): return 2 * z, 2 * z -def test_grad_multi_outputs(): - @ms_function - def df_multi_outputs(x, y): - return grad_all_with_sens(multi_outputs)(x, y, (1, 1)) - assert df_multi_outputs(2, 3) == (4, 4) - - @ms_function def while_sp(x, y, z): out = x @@ -874,13 +799,6 @@ def grad_refactor_3(a): return 3 * a -def test_grad_refactor_3(): - @ms_function - def df_refactor_3(x): - return grad_all(grad_refactor_3)(x) - assert df_refactor_3(3) == (3,) - - def grad_refactor_4(a): """ if_test """ if a > 3: @@ -899,13 +817,6 @@ def grad_refactor_5(a): return a -def test_grad_refactor_5(): - @ms_function - def df_refactor_5(x): - return grad_all(grad_refactor_5)(x) - assert df_refactor_5(1) == (1,) - - def grad_refactor_6(a, b): """ if_test """ if a > b: @@ -925,13 +836,6 @@ def grad_refactor_while(x): return rval -def test_grad_refactor_9(): - @ms_function - def df_refactor_while(input_x): - return grad_all(grad_refactor_while)(input_x) - assert df_refactor_while(3) == (6,) - - def grad_refactor__while_1(x): """ _while """ ret = x * x @@ -1009,13 +913,6 @@ def grad_refactor_14(a, b): return inner1(b) + inner2(a) + inner3(a) -def test_grad_refactor_14(): - @ms_function - def df_refactor_14(x, y): - return grad_all(grad_refactor_14)(x, y) - assert df_refactor_14(2, 3) == (3, 9) - - # pylint: disable=using-constant-test class IfDeferInline(nn.Cell): def __init__(self, mul_size): @@ -1044,6 +941,8 @@ def test_dict_const(): def __init__(self): super(Net, self).__init__() self.res = {'1': 10} + def construct(self): return self.res + Net()() diff --git a/tests/ut/python/pynative_mode/test_high_order_grad.py b/tests/ut/python/pynative_mode/test_high_order_grad.py index e41df500c3..a26ca71024 100644 --- a/tests/ut/python/pynative_mode/test_high_order_grad.py +++ b/tests/ut/python/pynative_mode/test_high_order_grad.py @@ -109,25 +109,3 @@ def first_derivative_if(x): def second_derivative_if(x): """ second_derivative_if """ return grad(first_derivative_if)(x) - - -def test_high_order_grad_1(): - """ test_high_order_grad_1 """ - # 18 - assert third_derivative(2) == 18 - # 18 * y * y * y, 18 * x * x * x - assert third_derivative_dual(4, 5) == (2250, 1152) - # 18 * x - assert second_derivative_all(3) == 54 - - -def test_high_order_grad_2(): - """ test_high_order_grad_2 """ - # 2 - assert second_derivative_if(12) == 2 - - -def test_high_order_grad_3(): - """ test_high_order_grad_2 """ - # 6 * x - assert second_derivative_if(4) == 24 diff --git a/tests/ut/python/pynative_mode/test_parse_method.py b/tests/ut/python/pynative_mode/test_parse_method.py index 0a8c1767db..7c494f0ef7 100644 --- a/tests/ut/python/pynative_mode/test_parse_method.py +++ b/tests/ut/python/pynative_mode/test_parse_method.py @@ -325,7 +325,7 @@ def invoke_dataclass2(x, y): def test_access_attr_error(): """ test_access """ with pytest.raises(AttributeError): - invoke_dataclass2(1, 2) + invoke_dataclass2(2, 1) def myfunc(x):