add recompute nodes

pull/10474/head
yujianfeng 4 years ago
parent 237faca57e
commit 7b412d7cb2

@@ -38,9 +38,6 @@ AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kerne
  if (value->isa<Scalar>()) {
    tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
  } else if (value->isa<ValueTuple>()) {
    if (!AnfAlgo::IsRealCNodeKernel(node)) {
      return nullptr;
    }
    tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
  } else {
    MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
@@ -89,7 +86,11 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
  MS_EXCEPTION_IF_NULL(func_graph);
  auto new_cnode = func_graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
    new_cnode->set_abstract(new_inputs[1]->abstract());
  } else {
    new_cnode->set_abstract(cnode->abstract());
  }
  new_cnode->set_scope(cnode->scope());
  AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
  if (kernel_graph != nullptr) {
@@ -123,7 +124,8 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                         const EquivPtr &) const {
  if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
  if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
      AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
    return nullptr;
  }
  if (!node->isa<CNode>()) {

(File diff suppressed because it is too large.)

@@ -0,0 +1,28 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_

#include "ir/anf.h"

namespace mindspore {
namespace opt {
// Automatically insert duplicated recomputed nodes.
void InsertRecomputedNodes(const FuncGraphPtr &graph);
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_

@@ -37,6 +37,7 @@
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
@@ -383,6 +384,12 @@ bool AddControlDependPass(const ResourcePtr &res) {
  return true;
}

bool AddRecomputationPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  opt::InsertRecomputedNodes(res->func_graph());
  return true;
}

bool MergeDupGraphPass(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
@@ -474,7 +481,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"tuple_transform", OptPassTransformGraphGroup},
                                   {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                   {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                   {"add_control_depend", AddControlDependPass}};
                                   {"add_control_depend", AddControlDependPass},
                                   {"add_recomputation", AddRecomputationPass}};

std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_a", OptPassAGroup},

@@ -913,6 +913,8 @@ class Cell(Cell_):
        """Sets the name on the first time."""
        if self._scope is None:
            self._scope = name
        elif self._scope == 'recomputed':
            self._scope = self._scope + "_" + name

    def _children_scope_recursive(self, parent_prefix='Default'):
        """Generates the scope of each layer of the network recursively."""
@@ -1093,6 +1095,15 @@ class Cell(Cell_):
            param.comm_fusion = fusion_type
        return self

    def recompute(self):
        """
        Set the cell recomputed. All the primitives in the cell will be set recomputed. If a primitive that is set
        recomputed feeds into a grad node, it will be computed again for that grad node after the forward computation.
        """
        self._set_scope('recomputed')
        for cell in self.cells():
            cell.recompute()


class GraphKernel(Cell):
    """

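A minimal usage sketch of the new recompute() API described by the docstring above (the Block cell below is made up for illustration; only pieces that appear elsewhere in this commit and standard MindSpore cells are used):

import mindspore.nn as nn
from mindspore.ops import operations as P


class Block(nn.Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.dense = nn.Dense(16, 16)
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(self.dense(x))


block = Block()
# Mark every primitive in the cell (and its sub-cells) as recomputed: results
# needed by grad nodes are computed again in the backward pass rather than
# kept from the forward pass.
block.recompute()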
@@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool.recompute()
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


def train(net, data, label):
    learning_rate = 0.01
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    res = train_network(data, label)
    print("+++++++++Loss+++++++++++++")
    print(res)
    print("+++++++++++++++++++++++++++")
    diff = res.asnumpy() - 2.302585
    assert np.all(diff < 1.e-6)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_lenet():
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    train(net, data, label)