From 1364b5c301f96e1f088af7dbb7d6b9df2ebede39 Mon Sep 17 00:00:00 2001 From: buxue Date: Fri, 30 Oct 2020 16:48:04 +0800 Subject: [PATCH] support key ward way to pass arg for outermost net in graph mode --- mindspore/nn/cell.py | 10 +++---- .../parameter_feature/test_parameter.py | 29 ++++++++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 662ee24464..18788e8574 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -282,19 +282,19 @@ class Cell(Cell_): return tuple(res) def __call__(self, *inputs, **kwargs): + if kwargs: + bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) + inputs = bound_args.args + kwargs = bound_args.kwargs if context.get_context("mode") == context.GRAPH_MODE: if kwargs: raise ValueError("For 'graph' mode, the outermost network does not support passing " - "key-value pair parameters and variable key-value pair parameters.") + "variable key-value pair parameters.") if self.enable_hook: raise ValueError("The graph mode does not support hook function.") out = self.compile_and_run(*inputs) return out - if kwargs: - bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) - inputs = bound_args.args - kwargs = bound_args.kwargs for item in inputs: if isinstance(item, numpy.ndarray): raise TypeError("cell inputs should not be numpy array.") diff --git a/tests/ut/python/parameter_feature/test_parameter.py b/tests/ut/python/parameter_feature/test_parameter.py index 551f175dbe..61127d1075 100644 --- a/tests/ut/python/parameter_feature/test_parameter.py +++ b/tests/ut/python/parameter_feature/test_parameter.py @@ -22,7 +22,6 @@ from mindspore.ops import operations as P context.set_context(mode=context.GRAPH_MODE, save_graphs=True) - grad_all = C.GradOperation(get_all=True) grad_all_with_sens = C.GradOperation(sens_param=True) @@ -285,3 +284,31 @@ def test_mixed_precision_const_parameter(): y = Tensor(np.ones((1, 3, 14, 14), np.float32)) z = Tensor(np.ones((1, 3, 28, 28), np.float32)) _ = net(x, y, z) + + +def test_pass_args_by_key_ward_way(): + class KeyWardNet(Cell): + def __init__(self): + super(KeyWardNet, self).__init__() + + def construct(self, x, y, z): + return x + y - z + + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.net = net + self.sens = Tensor(np.ones((3, 3, 4), np.float32)) + + def construct(self, x, y, z, sens): + return self.grad(self.net)(x, y, z, sens) + + x = Tensor(np.ones((1, 3, 4), np.float32)) + y = Tensor(np.ones((1, 3, 4), np.float32)) + z = Tensor(np.ones((3, 3, 4), np.float32)) + net = KeyWardNet() + net(x, z=z, y=y) + grad_net = GradNet(net) + sens = Tensor(np.ones((3, 3, 4), np.float32)) + grad_net(x, y=y, z=z, sens=sens)