diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index bf9188c205..87c10a9ed7 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -291,14 +291,14 @@ class _PynativeExecutor:
     def __init__(self):
         self._executor = PynativeExecutor_.get_instance()
 
-    def new_graph(self, obj, *args):
-        self._executor.new_graph(obj, *args)
+    def new_graph(self, obj, *args, **kwargs):
+        self._executor.new_graph(obj, *args, *(kwargs.values()))
 
-    def end_graph(self, obj, output, *args):
-        self._executor.end_graph(obj, output, *args)
+    def end_graph(self, obj, output, *args, **kwargs):
+        self._executor.end_graph(obj, output, *args, *(kwargs.values()))
 
-    def grad(self, grad, obj, weights, *args):
-        self._executor.grad_net(grad, obj, weights, *args)
+    def grad(self, grad, obj, weights, *args, **kwargs):
+        self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
 
     def clear(self, flag=""):
         self._executor.clear(flag)
@@ -306,7 +306,8 @@ class _PynativeExecutor:
     def set_grad_flag(self, flag):
         self._executor.set_grad_flag(flag)
 
-    def __call__(self, *args):
+    def __call__(self, *args, **kwargs):
+        args = args + tuple(kwargs.values())
         return self._executor(args, "")
 
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 024f93f5d3..c68eef9e2e 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """cell"""
+import inspect
 import time
 import gc
 from collections import OrderedDict
@@ -222,19 +223,27 @@ class Cell:
         else:
             object.__delattr__(self, name)
 
-    def __call__(self, *inputs):
+    def __call__(self, *inputs, **kwargs):
         if context.get_context("mode") == context.GRAPH_MODE:
+            if kwargs:
+                raise ValueError("For 'graph' mode, the outermost network does not support "
+                                 "keyword arguments or variable keyword arguments.")
             out = self.compile_and_run(*inputs)
             return out
+
+        if kwargs:
+            bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
+            inputs = bound_args.args
+            kwargs = bound_args.kwargs
         for item in inputs:
             if isinstance(item, numpy.ndarray):
                 raise TypeError("cell inputs should not be numpy array.")
-        orign_grad = []
+        origin_grad = []
         if self.requires_grad is True:
             _pynative_exec.set_grad_flag(True)
-            _pynative_exec.new_graph(self, *inputs)
+            _pynative_exec.new_graph(self, *inputs, **kwargs)
             for cell in self.cells():
-                orign_grad.append(cell.requires_grad)
+                origin_grad.append(cell.requires_grad)
                 cell.set_grad(True)
         else:
             _pynative_exec.set_grad_flag(False)
@@ -251,15 +260,15 @@ class Cell:
         else:
             cast_inputs = inputs
         if self.enable_hook:
-            output = self._hook_construct(*cast_inputs)
+            output = self._hook_construct(*cast_inputs, **kwargs)
         else:
-            output = self.construct(*cast_inputs)
+            output = self.construct(*cast_inputs, **kwargs)
         if isinstance(output, Parameter):
             output = output.data
         if self.requires_grad is True:
-            _pynative_exec.end_graph(self, output, *inputs)
+            _pynative_exec.end_graph(self, output, *inputs, **kwargs)
             for i, cell in enumerate(self.cells()):
-                cell.set_grad(orign_grad[i])
+                cell.set_grad(origin_grad[i])
         self._already_run = True
         return output
@@ -400,7 +409,6 @@ class Cell:
 
     def _get_construct_inputs_number_and_name(self):
         """Compute self._construct_inputs_names and self._construct_inputs_num"""
-        import inspect
         from mindspore._extends.parse.parser import get_parse_method_of_class
 
         fn = get_parse_method_of_class(self)
@@ -517,7 +525,7 @@ class Cell:
             raise TypeError("Child cell type is incorrect.")
type is incorrect.") self._cells[child_name] = child - def construct(self, *inputs): + def construct(self, *inputs, **kwargs): """ Defines the computation to be performed. @@ -878,7 +886,7 @@ class Cell: self.add_flags(auto_parallel=True) self._get_construct_inputs_number_and_name() - def _hook_construct(self, *inputs): + def _hook_construct(self, *inputs, **kwargs): """Hook construct method to replace original construct method when hook function enabled.""" inputs = self._backward_hook(*inputs) inputs = self.construct(inputs) diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 766bedc5d0..6c755646e7 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -116,7 +116,7 @@ class GradOperation(GradOperation_): self.fn = None self.need_forward = False - def _pynative_forward_run(self, args, fn): + def _pynative_forward_run(self, args, kwargs, fn): """ Pynative forward run to build grad graph. """ if self.sens_param: args = args[:-1] @@ -125,9 +125,9 @@ class GradOperation(GradOperation_): raise TypeError("grad inputs should be tensor in pynative mode") if isinstance(fn, FunctionType): _pynative_exec.set_grad_flag(True) - _pynative_exec.new_graph(fn, *args) - output = fn(*args) - _pynative_exec.end_graph(fn, output, *args) + _pynative_exec.new_graph(fn, *args, **kwargs) + output = fn(*args, **kwargs) + _pynative_exec.end_graph(fn, output, *args, **kwargs) else: if fn.already_run and not fn.requires_grad: raise ValueError("obj must set_grad.") @@ -135,7 +135,7 @@ class GradOperation(GradOperation_): self.need_forward = True if self.need_forward: fn.set_grad() - fn(*args) + fn(*args, **kwargs) fn.already_run = False def __call__(self, fn, weights=None): @@ -152,10 +152,10 @@ class GradOperation(GradOperation_): return grad_(fn)(*args) else: @_wrap_func - def after_grad(*args): - self._pynative_forward_run(args, fn) - _pynative_exec.grad(grad_, fn, weights, *args) - out = _pynative_exec(*args) + def after_grad(*args, **kwargs): + self._pynative_forward_run(args, kwargs, fn) + _pynative_exec.grad(grad_, fn, weights, *args, **kwargs) + out = _pynative_exec(*args, **kwargs) _pynative_exec.clear() return out self.grad_fn = after_grad diff --git a/tests/ut/python/dtype/test_list.py b/tests/ut/python/dtype/test_list.py index c63763e295..2553385a13 100644 --- a/tests/ut/python/dtype/test_list.py +++ b/tests/ut/python/dtype/test_list.py @@ -30,6 +30,7 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \ context.set_context(mode=context.GRAPH_MODE) + def test_list_equal(): class Net(nn.Cell): def __init__(self, z: list): @@ -156,8 +157,10 @@ def test_class_member_not_defined(): z = [[1, 2], 3] net = Net(z) + x = Tensor(np.ones([6, 8, 10], np.int32)) + y = Tensor(np.zeros([3, 4, 5], np.int32)) with pytest.raises(TypeError) as ex: - net() + net(x, y) assert "'self.x' was not defined in the class '__init__' function." 
@@ -181,7 +184,7 @@ def test_change_list_element():
 
 
 class ListOperate(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(ListOperate, self).__init__()
 
     def construct(self, t, l):
@@ -201,7 +204,7 @@ class ListOperate(nn.Cell):
 
 
 class InListNet(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(InListNet, self).__init__()
         self.list_ = [1, 2, 3, 4, 5, "ok"]
diff --git a/tests/ut/python/pynative_mode/test_kw_and_kwarg.py b/tests/ut/python/pynative_mode/test_kw_and_kwarg.py
new file mode 100644
index 0000000000..0100e0d0fc
--- /dev/null
+++ b/tests/ut/python/pynative_mode/test_kw_and_kwarg.py
@@ -0,0 +1,139 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test keyword arguments and variable keyword arguments"""
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore import dtype as mstype
+from mindspore.ops.composite import base as C
+
+
+def test_kw_nested():
+    class NetKeyValueArg(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y, *arg, w, **kwargs):
+            return x + y + arg[0] + w + kwargs['c']
+
+    class NetOut(nn.Cell):
+        def __init__(self, net):
+            super().__init__()
+            self.in_net = net
+
+        def construct(self, x, y, z):
+            ret = self.in_net(x, y, z, w=x, a=x, b=y, c=z) + x
+            return ret
+
+    in_net = NetKeyValueArg()
+    out_net = NetOut(in_net)
+    x = Tensor(np.ones([3, 4, 5], np.float32))
+    y = Tensor(np.zeros([3, 4, 5], np.int32))
+    z = Tensor(np.ones([3, 4, 5], np.float64))
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    ret = out_net(x, y, z)
+    assert ret.dtype == mstype.float64
+    assert ret.shape == (3, 4, 5)
+    assert (ret.asnumpy() == np.ones([3, 4, 5], np.float64) * 5).all()
+
+
+def test_kw_grad():
+    class KwNet(nn.Cell):
+        def __init__(self):
+            super(KwNet, self).__init__()
+
+        def construct(self, x, y, *arg, **kwargs):
+            return 2 * x + 3 * y + 4 * arg[0] + 5 * kwargs['v']
+
+    class GradKwNet(nn.Cell):
+        def __init__(self, net):
+            super(GradKwNet, self).__init__()
+            self.net = net
+            self.grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
+
+        def construct(self, x, y, *arg, **kwargs):
+            return self.grad_all_with_sens(self.net)(x, y, *arg, **kwargs)
+
+    kw_net = KwNet()
+    x = Tensor(np.ones([1, 2, 3], np.int32))
+    y = Tensor(np.ones([1, 2, 3], np.float32))
+    z = Tensor(np.ones([1, 2, 3], np.float64))
+    u = Tensor(np.ones([1, 2, 3], np.float16))
+    v = Tensor(np.ones([1, 2, 3], np.int32))
+    w = Tensor(np.ones([1, 2, 3], np.float64))
+    sens = Tensor(np.ones([1, 2, 3], np.float64))
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    kw_net.set_grad(True)
+    ret = kw_net(x, y, z, u=u, v=v, w=w)
+    assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 14).all()
+
+    grad_kw_net = GradKwNet(kw_net)
+    ret_grad = grad_kw_net(x, y, z, u=u, v=v, w=w, sens=sens)
+    assert len(ret_grad) == 6
+    assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all()
+    assert ret_grad[0].dtype == mstype.int32
+    assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all()
+    assert ret_grad[1].dtype == mstype.float32
+    assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all()
+    assert ret_grad[2].dtype == mstype.float64
+    assert (ret_grad[3].asnumpy() == np.zeros([1, 2, 3])).all()
+    assert ret_grad[3].dtype == mstype.float16
+    assert (ret_grad[4].asnumpy() == np.ones([1, 2, 3]) * 5).all()
+    assert ret_grad[4].dtype == mstype.int32
+    assert (ret_grad[5].asnumpy() == np.zeros([1, 2, 3])).all()
+    assert ret_grad[5].dtype == mstype.float64
+
+
+def test_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y, z):
+            return 2 * x + 3 * y + 4 * z
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
+
+        def construct(self, x, y, z, sens):
+            return self.grad_all_with_sens(self.net)(x, y, z, sens)
+
+    net = Net()
+    x = Tensor(np.ones([1, 2, 3], np.int32))
+    y = Tensor(np.ones([1, 2, 3], np.float32))
+    z = Tensor(np.ones([1, 2, 3], np.float16))
+    sens = Tensor(np.ones([1, 2, 3], np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    net.set_grad(True)
+    ret = net(x, y, z)
+    assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 9).all()
+
+    grad_net = GradNet(net)
+    ret_grad = grad_net(x, y, z, sens)
+    assert len(ret_grad) == 3
+    assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all()
+    assert ret_grad[0].dtype == mstype.int32
+    assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all()
+    assert ret_grad[1].dtype == mstype.float32
+    assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all()
+    assert ret_grad[2].dtype == mstype.float16
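
A minimal usage sketch of what this patch enables: in PyNative mode, a Cell whose construct declares keyword arguments can now be called with them directly, and GradOperation forwards them through the forward run and grad graph (Cell.__call__ normalizes mixed positional/keyword calls via inspect.signature(...).bind). This is a hedged illustration against the MindSpore API shown above; ScaleNet, the tensor shapes, and the 'get_all' name string are invented for the example and are not part of the patch.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops.composite import base as C

    context.set_context(mode=context.PYNATIVE_MODE)

    class ScaleNet(nn.Cell):
        """Hypothetical cell taking a keyword-only scale factor."""
        def construct(self, x, *, scale):
            return x * scale

    net = ScaleNet()
    x = Tensor(np.ones([2, 3], np.float32))
    s = Tensor(np.full([2, 3], 3.0, np.float32))

    # Keyword calls now work in PyNative mode; before this patch only
    # positional arguments reached _pynative_exec.
    net.set_grad(True)
    out = net(x, scale=s)

    # GradOperation's after_grad now accepts **kwargs and routes them into
    # _pynative_forward_run and the grad graph alongside positional args.
    grads = C.GradOperation('get_all', get_all=True)(net)(x, scale=s)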