!4370 support kw and kwargs for cell in Pynative

Merge pull request !4370 from zhangbuxue/support_kw_and_kwargs_for_cell_in_pynative
Committed by mindspore-ci-bot via Gitee
commit b4b6e5c8ed
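Summary: this change lets PyNative mode accept keyword and variable keyword arguments end to end. Cell.__call__ normalizes *inputs/**kwargs with inspect.signature(...).bind(...) and forwards them to construct, _PynativeExecutor flattens keyword values into the positional tuple handed to the C++ executor, and GradOperation threads kwargs through its forward run and grad build. Graph mode explicitly rejects keyword arguments at the outermost network. New tests cover nested keyword-only arguments and gradients taken through **kwargs.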

@@ -291,14 +291,14 @@ class _PynativeExecutor:
     def __init__(self):
         self._executor = PynativeExecutor_.get_instance()

-    def new_graph(self, obj, *args):
-        self._executor.new_graph(obj, *args)
+    def new_graph(self, obj, *args, **kwargs):
+        self._executor.new_graph(obj, *args, *(kwargs.values()))

-    def end_graph(self, obj, output, *args):
-        self._executor.end_graph(obj, output, *args)
+    def end_graph(self, obj, output, *args, **kwargs):
+        self._executor.end_graph(obj, output, *args, *(kwargs.values()))

-    def grad(self, grad, obj, weights, *args):
-        self._executor.grad_net(grad, obj, weights, *args)
+    def grad(self, grad, obj, weights, *args, **kwargs):
+        self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))

     def clear(self, flag=""):
         self._executor.clear(flag)
@@ -306,7 +306,8 @@ class _PynativeExecutor:
     def set_grad_flag(self, flag):
         self._executor.set_grad_flag(flag)

-    def __call__(self, *args):
+    def __call__(self, *args, **kwargs):
+        args = args + tuple(kwargs.values())
         return self._executor(args, "")
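The executor bridge only understands positional inputs, so every entry point appends the keyword values to the positional tuple before crossing into C++ (dicts preserve insertion order in Python 3.7+, so the values arrive in the order the caller passed them). A minimal standalone sketch of that flattening, with hypothetical names:

    def flatten_call(*args, **kwargs):
        # keyword values are appended in the order the caller passed them
        return args + tuple(kwargs.values())

    assert flatten_call(1, 2, c=3, d=4) == (1, 2, 3, 4)

This only works if the argument order is already canonical when it reaches the executor; the Cell.__call__ hunk below takes care of that with inspect.signature.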

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """cell"""
+import inspect
 import time
 import gc
 from collections import OrderedDict
@@ -222,19 +223,27 @@ class Cell:
         else:
             object.__delattr__(self, name)

-    def __call__(self, *inputs):
+    def __call__(self, *inputs, **kwargs):
         if context.get_context("mode") == context.GRAPH_MODE:
+            if kwargs:
+                raise ValueError("For 'graph' mode, the outermost network does not support passing "
+                                 "key-value pair parameters and variable key-value pair parameters.")
             out = self.compile_and_run(*inputs)
             return out

+        if kwargs:
+            bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
+            inputs = bound_args.args
+            kwargs = bound_args.kwargs
+
         for item in inputs:
             if isinstance(item, numpy.ndarray):
                 raise TypeError("cell inputs should not be numpy array.")
-        orign_grad = []
+        origin_grad = []
         if self.requires_grad is True:
             _pynative_exec.set_grad_flag(True)
-            _pynative_exec.new_graph(self, *inputs)
+            _pynative_exec.new_graph(self, *inputs, **kwargs)
             for cell in self.cells():
-                orign_grad.append(cell.requires_grad)
+                origin_grad.append(cell.requires_grad)
                 cell.set_grad(True)
         else:
             _pynative_exec.set_grad_flag(False)
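The inspect.signature(...).bind(...) call re-sorts whatever the caller passed into canonical order against construct's signature: values that can be positional land in bound_args.args, while only keyword-only parameters and true **kwargs entries remain in bound_args.kwargs. A standalone illustration (function and names hypothetical, signature shaped like the new test below):

    import inspect

    def construct(x, y, *arg, w, **kwargs):
        ...

    bound = inspect.signature(construct).bind(1, 2, 3, w=4, c=5)
    assert bound.args == (1, 2, 3)           # positional-capable values
    assert bound.kwargs == {'w': 4, 'c': 5}  # keyword-only w plus **kwargs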
@@ -251,15 +260,15 @@ class Cell:
         else:
             cast_inputs = inputs
         if self.enable_hook:
-            output = self._hook_construct(*cast_inputs)
+            output = self._hook_construct(*cast_inputs, **kwargs)
         else:
-            output = self.construct(*cast_inputs)
+            output = self.construct(*cast_inputs, **kwargs)
         if isinstance(output, Parameter):
             output = output.data
         if self.requires_grad is True:
-            _pynative_exec.end_graph(self, output, *inputs)
+            _pynative_exec.end_graph(self, output, *inputs, **kwargs)
             for i, cell in enumerate(self.cells()):
-                cell.set_grad(orign_grad[i])
+                cell.set_grad(origin_grad[i])
         self._already_run = True
         return output
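Taken together, a cell can now be invoked with keyword arguments in PyNative mode, while graph mode rejects them at the outermost network. A hedged sketch of the new calling convention (Net and shapes illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    class Net(nn.Cell):
        def construct(self, x, y):
            return x + y

    net = Net()
    x = Tensor(np.ones([2], np.float32))

    context.set_context(mode=context.PYNATIVE_MODE)
    out = net(x, y=x)   # keyword form is accepted and re-bound positionally

    context.set_context(mode=context.GRAPH_MODE)
    net(x, y=x)         # raises ValueError per the check added above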
@@ -400,7 +409,6 @@ class Cell:
     def _get_construct_inputs_number_and_name(self):
         """Compute self._construct_inputs_names and self._construct_inputs_num"""
-        import inspect
         from mindspore._extends.parse.parser import get_parse_method_of_class

         fn = get_parse_method_of_class(self)
@@ -517,7 +525,7 @@ class Cell:
             raise TypeError("Child cell type is incorrect.")
         self._cells[child_name] = child

-    def construct(self, *inputs):
+    def construct(self, *inputs, **kwargs):
         """
         Defines the computation to be performed.
@@ -878,7 +886,7 @@ class Cell:
         self.add_flags(auto_parallel=True)
         self._get_construct_inputs_number_and_name()

-    def _hook_construct(self, *inputs):
+    def _hook_construct(self, *inputs, **kwargs):
         """Hook construct method to replace original construct method when hook function enabled."""
         inputs = self._backward_hook(*inputs)
         inputs = self.construct(inputs)

@@ -116,7 +116,7 @@ class GradOperation(GradOperation_):
         self.fn = None
         self.need_forward = False

-    def _pynative_forward_run(self, args, fn):
+    def _pynative_forward_run(self, args, kwargs, fn):
         """ Pynative forward run to build grad graph. """
         if self.sens_param:
             args = args[:-1]
@@ -125,9 +125,9 @@ class GradOperation(GradOperation_):
                 raise TypeError("grad inputs should be tensor in pynative mode")
         if isinstance(fn, FunctionType):
             _pynative_exec.set_grad_flag(True)
-            _pynative_exec.new_graph(fn, *args)
-            output = fn(*args)
-            _pynative_exec.end_graph(fn, output, *args)
+            _pynative_exec.new_graph(fn, *args, **kwargs)
+            output = fn(*args, **kwargs)
+            _pynative_exec.end_graph(fn, output, *args, **kwargs)
         else:
             if fn.already_run and not fn.requires_grad:
                 raise ValueError("obj must set_grad.")
@@ -135,7 +135,7 @@ class GradOperation(GradOperation_):
                 self.need_forward = True
             if self.need_forward:
                 fn.set_grad()
-                fn(*args)
+                fn(*args, **kwargs)
                 fn.already_run = False

     def __call__(self, fn, weights=None):
@@ -152,10 +152,10 @@ class GradOperation(GradOperation_):
                 return grad_(fn)(*args)
             else:
                 @_wrap_func
-                def after_grad(*args):
-                    self._pynative_forward_run(args, fn)
-                    _pynative_exec.grad(grad_, fn, weights, *args)
-                    out = _pynative_exec(*args)
+                def after_grad(*args, **kwargs):
+                    self._pynative_forward_run(args, kwargs, fn)
+                    _pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
+                    out = _pynative_exec(*args, **kwargs)
                     _pynative_exec.clear()
                     return out
                 self.grad_fn = after_grad
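With kwargs threaded through after_grad, a gradient can now be taken of a callable invoked with keyword arguments in PyNative mode. A hedged usage sketch modeled on the tests added below (fn, scale, and the tensor shapes are illustrative; the GradOperation name-string signature matches the tests in this PR):

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.ops.composite import base as C

    context.set_context(mode=context.PYNATIVE_MODE)

    def fn(x, y, scale):
        return x * scale + y

    grad_all = C.GradOperation('grad_all', get_all=True)
    x = Tensor(np.ones([2], np.float32))
    y = Tensor(np.ones([2], np.float32))
    scale = Tensor(np.ones([2], np.float32) * 2)
    # keyword arguments reach _pynative_forward_run and the grad build;
    # grads cover every input, including the keyword-passed scale
    grads = grad_all(fn)(x, y, scale=scale)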

@@ -30,6 +30,7 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
 context.set_context(mode=context.GRAPH_MODE)

+
 def test_list_equal():
     class Net(nn.Cell):
         def __init__(self, z: list):
@@ -156,8 +157,10 @@ def test_class_member_not_defined():
     z = [[1, 2], 3]
     net = Net(z)
+    x = Tensor(np.ones([6, 8, 10], np.int32))
+    y = Tensor(np.zeros([3, 4, 5], np.int32))
     with pytest.raises(TypeError) as ex:
-        net()
+        net(x, y)
     assert "'self.x' was not defined in the class '__init__' function." in str(ex.value)
@@ -181,7 +184,7 @@ def test_change_list_element():
 class ListOperate(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(ListOperate, self).__init__()

     def construct(self, t, l):
@@ -201,7 +204,7 @@ class ListOperate(nn.Cell):
 class InListNet(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(InListNet, self).__init__()
         self.list_ = [1, 2, 3, 4, 5, "ok"]

@@ -0,0 +1,139 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test dtype and shape as attr"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.ops.composite import base as C


def test_kw_nested():
    class NetKeyValueArg(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y, *arg, w, **kwargs):
            return x + y + arg[0] + w + kwargs['c']

    class NetOut(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.in_net = net

        def construct(self, x, y, z):
            ret = self.in_net(x, y, z, w=x, a=x, b=y, c=z) + x
            return ret

    in_net = NetKeyValueArg()
    out_net = NetOut(in_net)
    x = Tensor(np.ones([3, 4, 5], np.float32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = Tensor(np.ones([3, 4, 5], np.float64))
    context.set_context(mode=context.PYNATIVE_MODE)
    ret = out_net(x, y, z)
    assert ret.dtype == mstype.float64
    assert ret.shape == (3, 4, 5)
    assert (ret.asnumpy() == np.ones([3, 4, 5], np.float64) * 5).all()


def test_kw_grad():
    class KwNet(nn.Cell):
        def __init__(self):
            super(KwNet, self).__init__()

        def construct(self, x, y, *arg, **kwargs):
            return 2 * x + 3 * y + 4 * arg[0] + 5 * kwargs['v']

    class GradKwNet(nn.Cell):
        def __init__(self, net):
            super(GradKwNet, self).__init__()
            self.net = net
            self.grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

        def construct(self, x, y, *arg, **kwargs):
            return self.grad_all_with_sens(self.net)(x, y, *arg, **kwargs)

    kw_net = KwNet()
    x = Tensor(np.ones([1, 2, 3], np.int32))
    y = Tensor(np.ones([1, 2, 3], np.float32))
    z = Tensor(np.ones([1, 2, 3], np.float64))
    u = Tensor(np.ones([1, 2, 3], np.float16))
    v = Tensor(np.ones([1, 2, 3], np.int32))
    w = Tensor(np.ones([1, 2, 3], np.float64))
    sens = Tensor(np.ones([1, 2, 3], np.float64))
    context.set_context(mode=context.PYNATIVE_MODE)
    kw_net.set_grad(True)
    ret = kw_net(x, y, z, u=u, v=v, w=w)
    assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 14).all()
    grad_kw_net = GradKwNet(kw_net)
    ret_grad = grad_kw_net(x, y, z, u=u, v=v, w=w, sens=sens)
    assert len(ret_grad) == 6
    assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all()
    assert ret_grad[0].dtype == mstype.int32
    assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all()
    assert ret_grad[1].dtype == mstype.float32
    assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all()
    assert ret_grad[2].dtype == mstype.float64
    assert (ret_grad[3].asnumpy() == np.zeros([1, 2, 3])).all()
    assert ret_grad[3].dtype == mstype.float16
    assert (ret_grad[4].asnumpy() == np.ones([1, 2, 3]) * 5).all()
    assert ret_grad[4].dtype == mstype.int32
    assert (ret_grad[5].asnumpy() == np.zeros([1, 2, 3])).all()
    assert ret_grad[5].dtype == mstype.float64
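Note on the expectations above: u and w never appear in KwNet's output expression, so their gradients come back as all-zeros tensors in the corresponding input dtypes, while v contributes through the coefficient 5. The sens tensor is consumed as the sensitivity input rather than differentiated, which is why only six gradients are returned.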


def test_grad():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y, z):
            return 2 * x + 3 * y + 4 * z

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

        def construct(self, x, y, z, sens):
            return self.grad_all_with_sens(self.net)(x, y, z, sens)

    net = Net()
    x = Tensor(np.ones([1, 2, 3], np.int32))
    y = Tensor(np.ones([1, 2, 3], np.float32))
    z = Tensor(np.ones([1, 2, 3], np.float16))
    sens = Tensor(np.ones([1, 2, 3], np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net.set_grad(True)
    ret = net(x, y, z)
    assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 9).all()
    grad_net = GradNet(net)
    ret_grad = grad_net(x, y, z, sens)
    assert len(ret_grad) == 3
    assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all()
    assert ret_grad[0].dtype == mstype.int32
    assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all()
    assert ret_grad[1].dtype == mstype.float32
    assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all()
    assert ret_grad[2].dtype == mstype.float16