!14186 Support while bprop

From: @liangzelang
Reviewed-by: @kisnwang,@jjfeing
Signed-off-by: @jjfeing
pull/14186/MERGE
Authored by mindspore-ci-bot, 4 years ago; committed by Gitee
commit 7f4994af7c
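
For context, this change wires while-loop backpropagation through on Ascend: it adds AICPU stack ops (StackInit/StackPush/StackPop/StackDestroy), teaches the kernel-metadata and cast-fusion passes about them, and replaces the graph-vs-pynative while tests with gradient tests. Presumably the stack ops buffer forward loop values for the backward pass. A minimal usage sketch of what the PR enables, mirroring the new tests below (SimpleWhileNet and GradNet are illustrative names, not new API):

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)

class SimpleWhileNet(nn.Cell):
    # a data-dependent while loop whose output depends on both inputs
    def construct(self, x, y):
        while x < y:
            x = x * x + 1
        return x

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        return grad_all(self.net)(*inputs)

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = GradNet(SimpleWhileNet())
# grads is (d out / d x, d out / d y), taken through the while loop in graph mode
grads = net(Tensor([1.1], dtype=ms.float32), Tensor([8.0], dtype=ms.float32))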

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -92,7 +92,9 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
    return false;
  }
  kernel_mod_ptr->SetInputSizeList(input_size_list);
  if (output_num == 1 && HasAbstractMonad(anf_node)) {
    output_num = 0;
  }
  for (size_t i = 0; i < output_num; i++) {
    std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
    TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
@@ -229,6 +231,9 @@ void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef
  MS_EXCEPTION_IF_NULL(proto);
  MS_EXCEPTION_IF_NULL(anf_node);
  size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
  if (output_num == 1 && HasAbstractMonad(anf_node)) {
    output_num = 0;
  }
  if (output_num == 0) {
    MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. ";
    return;

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -38,32 +38,10 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
    return;
  }
  // For compatibility with the current framework
  if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid) {
    std::vector<std::string> inputs_format{};
    std::vector<TypeId> inputs_type{};
    if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
      size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
      for (size_t input_index = 0; input_index < input_num; ++input_index) {
        inputs_format.emplace_back(kOpFormat_DEFAULT);
        inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
      }
    }
    std::vector<std::string> outputs_format;
    std::vector<TypeId> outputs_type;
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      outputs_format.emplace_back(kOpFormat_DEFAULT);
      outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
    }
    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
    builder.SetInputsFormat(inputs_format);
    builder.SetInputsDeviceType(inputs_type);
    builder.SetOutputsFormat(outputs_format);
    builder.SetOutputsDeviceType(outputs_type);
    builder.SetProcessor(AICPU);
    builder.SetKernelType(AICPU_KERNEL);
    builder.SetFusionType(OPAQUE);
    kernel_info_list->push_back(builder.Build());
  if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid ||
      op_name == kStackInitOpName || op_name == kStackDestroyOpName || op_name == kStackPushOpName ||
      op_name == kStackPopOpName) {
    AicpuMetadataInfoForSpecialNodes(kernel_node, kernel_info_list);
    return;
  }
  if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) {
@@ -71,5 +49,37 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
    return;
  }
}

void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
                                      std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
  std::vector<std::string> inputs_format{};
  std::vector<TypeId> inputs_type{};
  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
  if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid || op_name == kStackInitOpName ||
      op_name == kStackDestroyOpName || op_name == kStackPushOpName || op_name == kStackPopOpName) {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      inputs_format.emplace_back(kOpFormat_DEFAULT);
      inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
    }
  }
  std::vector<std::string> outputs_format;
  std::vector<TypeId> outputs_type;
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    outputs_format.emplace_back(kOpFormat_DEFAULT);
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
  }
  auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
  builder.SetInputsFormat(inputs_format);
  builder.SetInputsDeviceType(inputs_type);
  builder.SetOutputsFormat(outputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetProcessor(AICPU);
  builder.SetKernelType(AICPU_KERNEL);
  builder.SetFusionType(OPAQUE);
  kernel_info_list->push_back(builder.Build());
  return;
}
} // namespace kernel
} // namespace mindspore

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,6 +25,8 @@
namespace mindspore {
namespace kernel {
void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
                                      std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_META_DATA_H_

@@ -154,7 +154,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
    return nullptr;
  }
  auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
  if (next_op_name == prim::kPrimSend->name()) {
  if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
    return nullptr;
  }
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
@@ -229,7 +229,8 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
  }
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name()) {
  if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||
      AnfAlgo::GetCNodeName(prior_op) == kStackPopOpName) {
    return nullptr;
  }
  kernel_query->Query(prior_op, &kernel_info_list);

@@ -106,6 +106,43 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
enum ShapeType { kMaxShape, kMinShape };
} // namespace

AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
  return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad);
}

// Convert:
// a = former(xxx)
// b = latter(x, xxx)
// To:
// a = former(xxx)
// d1 = Depend(x, a)
// b = latter(d1, xxx)
// ...
// out = Depend(out, latter)
void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) {
  if (latter->isa<CNode>()) {
    auto latter_cnode = latter->cast<CNodePtr>();
    constexpr size_t inputsize = 2;
    constexpr size_t kFirstDataInputIndex = 1;
    if (latter_cnode->inputs().size() < inputsize) {
      return;
    }
    auto latter_input = latter_cnode->input(kFirstDataInputIndex);
    auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former});
    depend1->set_abstract(latter_input->abstract());
    latter_cnode->set_input(kFirstDataInputIndex, depend1);
    auto return_node = kg->get_return();
    MS_EXCEPTION_IF_NULL(return_node);
    auto depend2 = kg->NewCNode(
      {NewValueNode(prim::kPrimDepend), return_node->cast<CNodePtr>()->input(kFirstDataInputIndex), latter});
    depend2->set_abstract(return_node->cast<CNodePtr>()->input(kFirstDataInputIndex)->abstract());
    kg->set_output(depend2);
    MS_LOG(DEBUG) << "former: " << former->DebugString() << ", latter: " << latter->DebugString()
                  << ", depend1: " << depend1->DebugString() << ", depend2: " << depend2->DebugString();
  }
}

AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  if (tuple_get_item->size() != kTupleGetItemInputSize) {
@@ -1529,6 +1566,13 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
    return false;
  }
  // aicpu stack ops are not independent nodes.
  if (AnfAlgo::GetCNodeName(node) == kStackInitOpName || AnfAlgo::GetCNodeName(node) == kStackDestroyOpName ||
      AnfAlgo::GetCNodeName(node) == kStackPopOpName || AnfAlgo::GetCNodeName(node) == kStackPushOpName) {
    MS_LOG(INFO) << "AICPU stack ops should not be independent node";
    return false;
  }
  size_t input_nums = AnfAlgo::GetInputTensorNum(node);
  if (input_nums == 0) {
    return true;

@@ -43,6 +43,8 @@ using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
class AnfRuntimeAlgorithm {
 public:
  static AnfNodePtr MakeMonadValueNode(const KernelGraphPtr &kg);
  static void KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter);
  // get real input node of tuple_get_item
  static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
  static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);

File diff suppressed because it is too large.

@@ -116,6 +116,10 @@ constexpr auto kApplyProximalAdagradOpName = "ApplyProximalAdagrad ";
constexpr auto kApplyProximalGradientDescentOpName = "ApplyProximalGradientDescent";
constexpr auto kApplyRMSPropOpName = "ApplyRMSProp";
constexpr auto kTransDataOpName = "TransData";
constexpr auto kStackInitOpName = "StackInit";
constexpr auto kStackPushOpName = "StackPush";
constexpr auto kStackPopOpName = "StackPop";
constexpr auto kStackDestroyOpName = "StackDestroy";
constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad";
constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad";
constexpr auto kSquareSumV1OpName = "SquareSumV1";
@@ -381,6 +385,7 @@ constexpr auto kAttrRankSize = "rank_size";
constexpr auto kAttrPadDimSize = "pad_dim_size";
constexpr auto kAttrPaddings = "paddings";
constexpr auto kAttrNumSegments = "num_segments";
constexpr auto kAttrStackOpName = "stack_op_name";
constexpr auto kAttrBegin = "begin";
constexpr auto kAttrSize = "size";
constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";

@@ -105,6 +105,12 @@ inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGot
inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
// Stack ops
inline const PrimitivePtr kPrimStackInit = std::make_shared<Primitive>("StackInit");
inline const PrimitivePtr kPrimStackDestroy = std::make_shared<Primitive>("StackDestroy");
inline const PrimitivePtr kPrimStackPush = std::make_shared<Primitive>("StackPush");
inline const PrimitivePtr kPrimStackPop = std::make_shared<Primitive>("StackPop");
// Arrays
inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)


def test_while_forward():
def test_while_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
@@ -46,31 +46,71 @@ def test_while_forward():
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    #pynative mode
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)

def test_while_with_const_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, 1)
            return x

def test_while_grad():
    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    expect_one = np.array([1.14433983e+02], dtype=np.float32)
    expect_two = np.array([0], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
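
For reference, a hand-unrolled check of the expected values in test_while_with_const_param_grad (plain Python, not part of the test file): starting from x = 1.1 with y = 8.0, the body x <- x*x + 1 runs three times (1.1, 2.21, 5.8841, then ~35.62 which exits), so by the chain rule d(x_out)/d(x_in) = (2*1.1)*(2*2.21)*(2*5.8841) which is roughly 114.434, matching expect_one, while d(x_out)/dy = 0 because y only appears in the loop condition, matching expect_two.

x, y = 1.1, 8.0
grad_x = 1.0
while x < y:
    grad_x *= 2 * x  # derivative of x*x + 1 w.r.t. x, chained per iteration
    x = x * x + 1
print(grad_x)  # ~114.434, i.e. expect_one; the gradient w.r.t. y is 0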

def test_while_with_variable_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, y)
            return x

    class GradNet(nn.Cell):
@@ -80,20 +120,16 @@ def test_while_grad():
        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    expect_one = np.array([2.20000005e+00], dtype=np.float32)
    expect_two = np.array([1.00000000e+00], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
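
Likewise for test_while_with_variable_grad, a quick check of the expected values (plain Python, not part of the test file): with x = 1.1 and y = 8.0 the body x <- x*x + y executes exactly once (1.21 + 8 = 9.21, which is not less than 8), so d(x_out)/dx = 2*1.1 = 2.2 (expect_one) and d(x_out)/dy = 1 (expect_two).

x, y = 1.1, 8.0
grad_x, grad_y = 1.0, 0.0        # d x / d x0, d x / d y
while x < y:
    grad_y = 2 * x * grad_y + 1  # chain rule for x <- x*x + y w.r.t. y
    grad_x = 2 * x * grad_x      # chain rule w.r.t. the initial x
    x = x * x + y
print(grad_x, grad_y)            # 2.2 1.0, i.e. expect_one and expect_two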

def test_while_with_param_forward():
    class MyWhileNet(nn.Cell):
@@ -153,7 +189,6 @@ def test_while_endless_case():
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_while_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
@@ -180,7 +215,6 @@ def test_while_with_param_grad():
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
@@ -188,10 +222,8 @@ def test_while_with_param_grad():
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)


def test_while_with_param_forward_with_const_branch():
    class MyWhileNet(nn.Cell):
