!8039 support key ward way to pass arg for outermost net in graph mode

Merge pull request !8039 from zhangbuxue/support_key_ward_way_to_pass_arg_for_outermost_net_in_graph_mode
pull/8039/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7092a33e87

@ -282,19 +282,19 @@ class Cell(Cell_):
return tuple(res)
def __call__(self, *inputs, **kwargs):
if kwargs:
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
inputs = bound_args.args
kwargs = bound_args.kwargs
if context.get_context("mode") == context.GRAPH_MODE:
if kwargs:
raise ValueError("For 'graph' mode, the outermost network does not support passing "
"key-value pair parameters and variable key-value pair parameters.")
"variable key-value pair parameters.")
if self.enable_hook:
raise ValueError("The graph mode does not support hook function.")
out = self.compile_and_run(*inputs)
return out
if kwargs:
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
inputs = bound_args.args
kwargs = bound_args.kwargs
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.")

@ -22,7 +22,6 @@ from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation(sens_param=True)
@ -285,3 +284,31 @@ def test_mixed_precision_const_parameter():
y = Tensor(np.ones((1, 3, 14, 14), np.float32))
z = Tensor(np.ones((1, 3, 28, 28), np.float32))
_ = net(x, y, z)
def test_pass_args_by_key_ward_way():
class KeyWardNet(Cell):
def __init__(self):
super(KeyWardNet, self).__init__()
def construct(self, x, y, z):
return x + y - z
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.net = net
self.sens = Tensor(np.ones((3, 3, 4), np.float32))
def construct(self, x, y, z, sens):
return self.grad(self.net)(x, y, z, sens)
x = Tensor(np.ones((1, 3, 4), np.float32))
y = Tensor(np.ones((1, 3, 4), np.float32))
z = Tensor(np.ones((3, 3, 4), np.float32))
net = KeyWardNet()
net(x, z=z, y=y)
grad_net = GradNet(net)
sens = Tensor(np.ones((3, 3, 4), np.float32))
grad_net(x, y=y, z=z, sens=sens)

Loading…
Cancel
Save