add-new-interface-forward-value-and-grad

pull/11649/head
lvliang 4 years ago
parent 2e71163539
commit dd36171976

@@ -738,27 +738,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
inputs.emplace_back(input_node);
}
}
auto const_input_index = prim->get_const_input_indexes();
bool have_const_input = !const_input_index.empty();
bool is_const_prim = prim->is_const_prim();
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
<< prim->is_const_prim();
bool is_const_input =
have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
if (abs == nullptr || is_const_prim || is_const_input) {
MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
ValuePtr input_value = PyAttrValue(obj);
abs = input_value->ToAbstract();
if (!is_const_prim && !is_const_input) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
abs = abs->Broaden(config);
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
}
node_abs_map_[id] = abs;
}
(*args_spec_list).emplace_back(abs);
(*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i));
}
CNodePtr cnode = nullptr;
@@ -770,6 +750,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
return cnode;
}
abstract::AbstractBasePtr PynativeExecutor::CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
const abstract::AbstractBasePtr &abs, const std::string &id,
size_t index) {
MS_EXCEPTION_IF_NULL(prim);
auto const_input_index = prim->get_const_input_indexes();
bool have_const_input = !const_input_index.empty();
bool is_const_prim = prim->is_const_prim();
auto new_abs = abs;
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
<< prim->is_const_prim();
bool is_const_input =
have_const_input && std::find(const_input_index.begin(), const_input_index.end(), index) != const_input_index.end();
if (abs == nullptr || is_const_prim || is_const_input) {
MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
ValuePtr input_value = PyAttrValue(obj);
MS_EXCEPTION_IF_NULL(input_value);
new_abs = input_value->ToAbstract();
if (!is_const_prim && !is_const_input) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
MS_EXCEPTION_IF_NULL(new_abs);
new_abs = new_abs->Broaden(config);
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
}
node_abs_map_[id] = new_abs;
}
return new_abs;
}
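For readers less familiar with the abstract machinery, a minimal plain-Python sketch of the control flow the new CheckConstValue helper factors out of MakeCNode. The names to_abstract, broaden and node_abs_map are illustrative stand-ins for the C++ counterparts, not MindSpore API: when the cached abstract is missing, or the primitive or input is marked const, the abstract is rebuilt from the concrete value; non-const cases are broadened so only shape and dtype are kept.

def check_const_value_sketch(abs_, obj, index, const_input_indexes, is_const_prim,
                             to_abstract, broaden, node_abs_map, obj_id):
    """Mirror of CheckConstValue: rebuild the abstract for const prims/inputs."""
    is_const_input = bool(const_input_indexes) and index in const_input_indexes
    if abs_ is None or is_const_prim or is_const_input:
        # The cached abstract cannot be reused: derive it from the concrete value.
        abs_ = to_abstract(obj)
        if not is_const_prim and not is_const_input:
            # Non-const inputs keep only shape/dtype (broadened), not the value.
            abs_ = broaden(abs_)
        node_abs_map[obj_id] = abs_
    return abs_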
void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
MS_EXCEPTION_IF_NULL(is_find);
@@ -1004,6 +1012,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
return free_param;
}
node = graph_info->node_map.at(obj_id).first;
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
return node;
}
@@ -2008,9 +2017,14 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
top_cell_id_ = cell_id;
in_grad_process_ = true;
// update forward already run flag with previous top cell
std::string input_args_id;
for (size_t i = 0; i < args.size(); ++i) {
input_args_id = input_args_id + GetId(args[i]) + "_";
}
auto pre_top_cell = GetTopCell(cell_id);
if (pre_top_cell != nullptr) {
pre_top_cell->forward_already_run = true;
pre_top_cell->input_args_id = input_args_id;
}
auto df_builder = std::make_shared<FuncGraph>();
auto graph_info = std::make_shared<GraphInfo>(cell_id);
@@ -2019,6 +2033,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
top_cell_info->forward_already_run = true;
top_cell_info->input_args_id = input_args_id;
if (!IsTopestGraph(cell_id)) {
top_cell_info->top_cell_index = cell_graph_list_.size();
top_cell_index_ = top_cell_info->top_cell_index;
@@ -2862,11 +2877,24 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
}
py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
bool forward_run = false;
const auto &cell_id = GetCellId(cell, args);
// Check whether the top cell has already run.
std::string input_args_id;
for (size_t i = 0; i < args.size(); ++i) {
input_args_id = input_args_id + GetId(args[i]) + "_";
}
auto top_cell = GetTopCell(cell_id);
bool forward_run = false;
if (top_cell != nullptr) {
forward_run = top_cell->forward_already_run;
if (!top_cell->input_args_id.empty() && top_cell->input_args_id != input_args_id && top_cell->forward_already_run &&
CheckDynamicCell(cell_id)) {
MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, "
"forward process will run again";
top_cell->forward_already_run = false;
top_cell->input_args_id = input_args_id;
} else {
forward_run = top_cell->forward_already_run;
}
if (forward_run) {
top_cell_index_ = top_cell->top_cell_index;
}
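The two hunks above build an input_args_id by concatenating the ids of the call arguments and use it to decide whether a dynamic cell must rerun its forward pass. Below is a minimal Python sketch of that check; get_id and the TopCellSketch fields are illustrative stand-ins for the C++ members (GetId, forward_already_run, input_args_id), not MindSpore API.

from dataclasses import dataclass

@dataclass
class TopCellSketch:
    forward_already_run: bool = False
    input_args_id: str = ""

def should_rerun_forward(top_cell, args, cell_is_dynamic, get_id=id):
    """Return True when a dynamic cell sees new inputs and must run forward again."""
    # Same "id1_id2_..._" key that MakeNewTopGraph and CheckAlreadyRun build.
    input_args_id = "".join(str(get_id(a)) + "_" for a in args)
    if (top_cell.input_args_id and top_cell.input_args_id != input_args_id
            and top_cell.forward_already_run and cell_is_dynamic):
        top_cell.forward_already_run = False
        top_cell.input_args_id = input_args_id
        return True
    return False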

@@ -107,6 +107,7 @@ class TopCellInfo {
std::string cell_id;
std::string sens_id;
std::string weights_id;
std::string input_args_id;
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;
@@ -209,6 +210,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
bool *is_find);
void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);

@@ -307,6 +307,23 @@ class Cell(Cell_):
res.append(cast(item, dst_type))
return tuple(res)
def do_parameter_broadcast(self):
if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
if not self.parameter_broadcast_done:
_pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
self.parameter_broadcast_done = True
def run_construct(self, cast_inputs, kwargs):
if self.enable_hook:
_pynative_exec.enter_construct(self)
output = self._hook_construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
else:
_pynative_exec.enter_construct(self)
output = self.construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
return output
def __call__(self, *inputs, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
@@ -324,11 +341,7 @@ class Cell(Cell_):
out = self.compile_and_run(*inputs)
return out
if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
if not self.parameter_broadcast_done:
_pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
self.parameter_broadcast_done = True
self.do_parameter_broadcast()
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.")
@@ -349,14 +362,7 @@ class Cell(Cell_):
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
if not cast_inputs:
cast_inputs = inputs
if self.enable_hook:
_pynative_exec.enter_construct(self)
output = self._hook_construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
else:
_pynative_exec.enter_construct(self)
output = self.construct(*cast_inputs, **kwargs)
_pynative_exec.leave_construct(self)
output = self.run_construct(cast_inputs, kwargs)
if isinstance(output, Parameter):
output = output.data
if self.requires_grad is True:
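For clarity, a standalone sketch of the dispatch that the new run_construct factors out of __call__. Here enter and leave are stand-ins for _pynative_exec.enter_construct and _pynative_exec.leave_construct; the try/finally is an illustration choice, while the real code simply brackets each branch with the enter/leave calls.

def run_construct_sketch(cell, cast_inputs, kwargs, enter, leave):
    """Hook-aware construct dispatch, mirroring Cell.run_construct above."""
    enter(cell)
    try:
        if cell.enable_hook:
            output = cell._hook_construct(*cast_inputs, **kwargs)
        else:
            output = cell.construct(*cast_inputs, **kwargs)
    finally:
        leave(cell)
    return output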

@@ -17,7 +17,7 @@ Wrap cells for networks.
Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer
@@ -25,6 +25,7 @@ from ..layer.timedistributed import TimeDistributed
__all__ = [
"TimeDistributed",
"ForwardValueAndGrad",
"TrainOneStepCell",
"WithLossCell",
"WithGradCell",

@@ -13,9 +13,12 @@
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from types import FunctionType, MethodType
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
from mindspore.context import ParallelMode
from ...common.tensor import Tensor
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
@@ -174,6 +177,107 @@ class WithGradCell(Cell):
return grads
class ForwardValueAndGrad(Cell):
r"""
Network training package class.
It includes the network and a gradient function. The resulting Cell is trained with the input '\*inputs'.
The backward graph will be created in the gradient function to calculate the gradient.
Args:
network (Cell): The training network. The network only supports single output.
weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
Default: None.
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
If get_all and get_by_list are both False, get the gradient with respect to the first input.
If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
at the same time, in the form of ((gradients with respect to inputs),
(gradients with respect to parameters)). Default: False.
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
If sens_param is True, the sensitivity (gradient with respect to output) needs to be passed in
as a positional argument or a keyword argument; if it is passed as a keyword argument,
the key must be 'sens'. Default: False.
Inputs:
- **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
- **sens** (Number) - The scaling number to be filled as the input of backpropagation. Default: 1.0.
Outputs:
- **forward value** (a scalar Tensor with shape :math:`()`) - The result of the network forward running.
- **gradients** (tuple(tensor)) - The gradients of the network parameters and inputs.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
>>> labels = Tensor(np.ones([32]).astype(np.int32))
>>> net = Net()
>>> weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> #1) Using the existing WithLossCell provided by MindSpore
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
>>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
>>>
>>> #2) Using user-defined WithLossCell
>>> class MyWithLossCell(Cell):
... def __init__(self, backbone, loss_fn):
... super(MyWithLossCell, self).__init__(auto_prefix=False)
... self._backbone = backbone
... self._loss_fn = loss_fn
...
... def construct(self, x, y, label):
... out = self._backbone(x, y)
... return self._loss_fn(out, label)
...
... @property
... def backbone_network(self):
... return self._backbone
...
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
>>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
"""
def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
if not isinstance(network, (Cell, FunctionType, MethodType)):
raise TypeError(f"The type of training network should be cell, function type or method type, "
f"but got '{type(network)}'")
if get_by_list and not isinstance(weights, ParameterTuple):
raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
f"ParameterTuple type, but got '{type(weights)}'")
if get_by_list is not True and weights is not None:
raise TypeError(f"When get_by_list is set to False, the parameters of training network should be "
f"NoneType, but got '{type(weights)}'")
self.network = network
self.network.set_grad()
self.weights = weights
self.get_all = get_all
self.get_by_list = get_by_list
self.sens_param = sens_param
self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
def construct(self, *inputs):
weights = self.weights
if self.sens_param:
sens = inputs[-1]
inputs = inputs[:-1]
else:
sens = None
loss = self.network(*inputs)
if self.sens_param:
if not isinstance(sens, Tensor):
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), sens)
grads = self.grad(self.network, weights)(*inputs, sens)
else:
grads = self.grad(self.network, weights)(*inputs)
return loss, grads
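A hedged usage sketch of the get_all path (the toy SquareSum cell and the shapes below are assumptions for illustration, not part of this change): with the default weights=None and get_all=True, the cell returns the forward value together with the gradients with respect to the inputs.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class SquareSum(nn.Cell):
    """Toy network: reduces x * x to a scalar."""
    def __init__(self):
        super(SquareSum, self).__init__()
        self.reduce_sum = P.ReduceSum()

    def construct(self, x):
        return self.reduce_sum(x * x)

net = SquareSum()
forward_value_and_grad = nn.ForwardValueAndGrad(net, get_all=True)
x = Tensor(np.full((2, 3), 2.0).astype(np.float32))
value, grads = forward_value_and_grad(x)
# value is sum(x * x) == 24.0; grads is a one-element tuple with grads[0] == 2 * x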
class TrainOneStepCell(Cell):
r"""
Network training package class.

@@ -22,10 +22,10 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ParameterTuple
from mindspore import amp
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn import TrainOneStepCell, WithLossCell, ForwardValueAndGrad
from mindspore.nn.cell import Cell
from mindspore.nn.layer.basic import Flatten
from mindspore.nn.layer.conv import Conv2d
@@ -33,6 +33,7 @@ from mindspore.nn.layer.normalization import BatchNorm2d
from mindspore.nn.layer.pooling import MaxPool2d
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import Add
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@@ -399,3 +400,53 @@ def test_trainTensor_amp(num_classes=10, epoch=18, batch_size=16):
assert (losses[-1][0].asnumpy() < 1)
assert not losses[-1][1].asnumpy()
assert (losses[-1][2].asnumpy() > 1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
net = resnet50(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
net_with_criterion.set_train()
weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
loss, grads = train_network(data, label, 1.0)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)
assert (losses[-1].asnumpy() < 0.8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=338):
net = resnet50(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
net_with_criterion.set_train()
weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
loss, grads = train_network(data, label, 1.0)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)
assert (losses[-1].asnumpy() < 0.8)

@@ -164,3 +164,40 @@ def test_ascend_pynative_lenet():
print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_lenet_with_new_interface():
context.set_context(mode=context.PYNATIVE_MODE)
epoch_size = 20
batch_size = 32
inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
labels = Tensor(np.ones([batch_size]).astype(np.int32))
net = LeNet()
criterion = CrossEntropyLoss()
net_with_criterion = WithLossCell(net, criterion)
net_with_criterion.set_train()
weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = Momentum(weights, 0.1, 0.9)
forward_value_and_grad = nn.ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True)
total_time = 0
for epoch in range(0, epoch_size):
start_time = time.time()
loss_output, grads = forward_value_and_grad(inputs, labels)
optimizer(grads)
end_time = time.time()
cost_time = end_time - start_time
total_time = total_time + cost_time
print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
assert loss_output.asnumpy() < 0.005
assert loss_output.asnumpy() > 0.003
