!14499 [GraphKernel] Split UMonad out of op inputs

From: @wenfangpei
Reviewed-by: @dayschan, @ckey_dou, @gaoxiong1
Signed-off-by: @gaoxiong1
pull/14499/MERGE
Committed by mindspore-ci-bot via Gitee
commit 8634675e2d
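Summary: the old split_assign pass is generalized into split_umonad. Ops that carry a trailing UMonad input (Assign, AssignAdd, AssignSub, LambApplyOptimizerAssign) now have that UMonad split off before graph kernel expansion: the written parameter input is wrapped in a Depend on the UMonad, the UMonad is dropped from the op's inputs, and the op is expanded through the new OpUMonadExpander. AssignAdd also moves out of the #if ENABLE_D block, so it is expanded on GPU as well. A minimal sketch of how to exercise the new path from Python, mirroring the added test (device_target, shapes, and values are illustrative; assumes a MindSpore build with GPU support):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class AssignAdd(nn.Cell):
    def __init__(self, value):
        super(AssignAdd, self).__init__()
        self.var = Parameter(value, name="var")
        self.add = P.AssignAdd()

    def construct(self, y):
        self.add(self.var, y)
        return self.var

# Compare results with graph kernel fusion on and off; they should match.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = Tensor(np.arange(27).reshape(1, 3, 3, 3).astype(np.float32))
y = Tensor(np.arange(27).reshape(1, 3, 3, 3).astype(np.float32))

context.set_context(enable_graph_kernel=True)
out_on = AssignAdd(x)(y)

context.set_context(enable_graph_kernel=False)
out_off = AssignAdd(x)(y)

assert np.allclose(out_on.asnumpy(), out_off.asnumpy())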

@@ -26,6 +26,7 @@
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
 #include "backend/optimizer/graph_kernel/substitute_dropout.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "mindspore/core/ir/graph_utils.h"
@@ -37,10 +38,14 @@
 namespace mindspore {
 namespace opt {
 namespace {
+constexpr size_t kAssignInputIdx = 1;
+constexpr size_t kLambInputIdx = 12;
 std::vector<PrimitivePtr> GetExpandOps() {
   std::vector<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
     prim::kPrimGeLUGrad,
+    prim::kPrimAssignAdd,
 #if ENABLE_D
     prim::kPrimTile,
     prim::kPrimSqrtGrad,
@@ -69,7 +74,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimSigmoidCrossEntropyWithLogits,
     prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
     prim::kPrimSoftmaxCrossEntropyWithLogits,
-    prim::kPrimAssignAdd,
 #endif
   };
   const auto &flags = context::GraphKernelFlags::GetInstance();
@@ -167,6 +171,22 @@ AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
   return graph_kernel_node;
 }
+
+ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
+  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
+    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
+    {prim::kPrimAssignAdd, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
+    {prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
+    {prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambInputIdx)},
+  };
+  for (auto &e : expanders) {
+    if (IsPrimitiveCNode(node, e.first)) {
+      return e.second;
+    }
+  }
+  return std::make_shared<DefaultExpander>();
+}
+
 bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
   bool changed = false;
   auto todos = TopoSort(func_graph->get_return());
@@ -192,18 +212,6 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
   return changed;
 }
-
-ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
-  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
-    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
-  };
-  for (auto &e : expanders) {
-    if (IsPrimitiveCNode(node, e.first)) {
-      return e.second;
-    }
-  }
-  return std::make_shared<DefaultExpander>();
-}
 bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
   expand_ops_ = GetExpandOps();
   return DoExpand(func_graph);

@@ -37,7 +37,7 @@
 #include "backend/optimizer/graph_kernel/value_graph_binder.h"
 #include "backend/optimizer/graph_kernel/parallel_fusion.h"
 #include "backend/optimizer/graph_kernel/optimize_assign.h"
-#include "backend/optimizer/graph_kernel/split_assign.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
 #include "backend/optimizer/graph_kernel/reorder_ops.h"
 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
 #include "backend/optimizer/graph_kernel/axis_normalizer.h"

@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "backend/optimizer/graph_kernel/split_assign.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
 #include <vector>
 #include <string>
@@ -35,31 +35,63 @@ const BaseRef SplitAssign::DefinePattern() const {
   return VectorRef({v, Xs, Us, UMonad});
 }
-bool CanSplit(const AnfNodePtr &node) {
-  return IsPrimitiveCNode(node, prim::kPrimAssignAdd) || IsPrimitiveCNode(node, prim::kPrimAssign) ||
-         IsPrimitiveCNode(node, prim::kPrimAssignSub);
-}
+bool CanSplit(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimAssign); }
-const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
+AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int input_idx) {
   MS_EXCEPTION_IF_NULL(node);
-  if (!CanSplit(node)) return node;
   CNodePtr cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  CheckCNodeInputSize(cnode, kAssignInputTensorNum);
-  // Get original assign op's abstract and inputs
+  // Get original op's abstract and inputs
   AbstractBasePtr original_abstract = cnode->abstract()->Clone();
   auto original_inputs = cnode->inputs();
+  int input_node_size = cnode->size() - 1;
   // Create depend node
-  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]};
+  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx],
+                                  original_inputs[input_node_size]};
   auto depend_cnode = func_graph->NewCNode(depend_inputs);
-  depend_cnode->set_abstract(original_inputs[1]->abstract());
+  depend_cnode->set_abstract(original_inputs[input_idx]->abstract());
   depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
-  // Create new assign node, delete U from inputs.
-  AnfNodePtrList new_assign_inputs = {cnode->input(0), depend_cnode, original_inputs[2]};
-  auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs);
-  new_assign_cnode->set_abstract(original_abstract);
-  new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr());
-  return new_assign_cnode;
-}
+  // Create new node, delete U from inputs.
+  AnfNodePtrList new_inputs = {cnode->input(0)};
+  for (int i = 1; i < input_node_size; i++) {
+    if (i == input_idx) {
+      new_inputs.push_back(depend_cnode);
+    } else {
+      new_inputs.push_back(cnode->input(i));
+    }
+  }
+  auto new_cnode = func_graph->NewCNode(new_inputs);
+  new_cnode->set_abstract(original_abstract);
+  new_cnode->set_kernel_info(cnode->kernel_info_ptr());
+  return new_cnode;
+}
+
+const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!CanSplit(node)) return node;
+  return ProcessNode(node->func_graph(), node, 1);
+}
+
+AnfNodePtr OpUMonadExpander::Run(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  bool has_umonad = false;
+  for (unsigned int i = 1; i < cnode->size(); i++) {
+    if (HasAbstractUMonad(cnode->input(i))) {
+      has_umonad = true;
+      break;
+    }
+  }
+  if (has_umonad) {
+    auto new_node = ProcessNode(node->func_graph(), node, input_idx_);
+    return DefaultExpander::Run(new_node);
+  }
+  return DefaultExpander::Run(node);
+}
 } // namespace opt
 } // namespace mindspore
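The core rewrite is ProcessNode above: for an op whose last input is a UMonad, that UMonad input is dropped and the input at input_idx (the parameter being written: index 1 for Assign/AssignAdd/AssignSub, 12 for LambApplyOptimizerAssign) is replaced by a Depend on the UMonad, so the ordering the monad encoded is preserved as a Depend edge. A toy sketch of the same input reshuffle on plain Python lists (names are illustrative only, not MindSpore APIs):

def split_umonad(inputs, input_idx):
    """Toy model of ProcessNode: inputs[0] is the primitive, the last input is the UMonad."""
    umonad = inputs[-1]
    depend = ("Depend", inputs[input_idx], umonad)  # stands in for the new Depend CNode
    new_inputs = list(inputs[:-1])                  # drop the UMonad from the inputs
    new_inputs[input_idx] = depend                  # written parameter now depends on the UMonad
    return new_inputs

# AssignAdd(param, value, U)  ->  AssignAdd(Depend(param, U), value)
print(split_umonad(["AssignAdd", "param", "value", "U"], 1))
# ['AssignAdd', ('Depend', 'param', 'U'), 'value']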

@@ -13,11 +13,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_
 #include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 namespace mindspore {
 namespace opt {
 class SplitAssign : public PatternProcessPass {
@@ -27,6 +27,16 @@ class SplitAssign : public PatternProcessPass {
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 };
+
+class OpUMonadExpander : public DefaultExpander {
+ public:
+  explicit OpUMonadExpander(int input_idx) : input_idx_(input_idx) {}
+  ~OpUMonadExpander() = default;
+  AnfNodePtr Run(const AnfNodePtr &node) override;
+
+ private:
+  int input_idx_;
+};
 } // namespace opt
 } // namespace mindspore
-#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
+#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_

@@ -219,7 +219,9 @@ bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, co
   auto mng = func_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
   for (auto user : mng->node_users()[getitems_[index]]) {
-    user.first->cast<CNodePtr>()->set_input(user.second, new_node);
+    if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
+      user.first->cast<CNodePtr>()->set_input(user.second, new_node);
+    }
   }
   return true;
 }

@@ -32,26 +32,38 @@ class AssignAdd(nn.Cell):
         self.add(self.var, y)
         return self.var
 
+def get_output(x2, y2, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    add = AssignAdd(x2)
+    result_gk_on_1 = add(y2)
+    add_2 = AssignAdd(result_gk_on_1)
+    result_gk_on_2 = add_2(y2)
+    output = [result_gk_on_1, result_gk_on_2]
+    return output
+
+def assign_add():
+    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    expect = get_output(x2, y2, False)
+    output = get_output(x2, y2, True)
+    e1, e2 = list(expect)
+    o1, o2 = list(output)
+    assert np.allclose(o1.asnumpy(), e1.asnumpy())
+    assert np.allclose(o2.asnumpy(), e2.asnumpy())
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_assign_add():
-    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
-    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
-    context.set_context(mode=context.GRAPH_MODE,
-                        enable_graph_kernel=True, device_target="GPU")
-    add = AssignAdd(x2)
-    result_gk_on_1 = add(y2)
-    add_2 = AssignAdd(result_gk_on_1)
-    result_gk_on_2 = add_2(y2)
-    context.set_context(mode=context.GRAPH_MODE,
-                        enable_graph_kernel=False, device_target="GPU")
-    add_beta = AssignAdd(x2)
-    result_gk_off_1 = add_beta(y2)
-    add_beta_2 = AssignAdd(result_gk_off_1)
-    result_gk_off_2 = add_beta_2(y2)
-    assert (result_gk_on_1.asnumpy() == result_gk_off_1.asnumpy()).all()
-    assert (result_gk_on_2.asnumpy() == result_gk_off_2.asnumpy()).all()
+def test_assign_add_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    assign_add()
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_assign_add_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    assign_add()

@@ -14,6 +14,7 @@
 # ============================================================================
 import numpy as np
+import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -67,6 +68,10 @@ def lamb_apply_optimizer_assign():
     assert np.allclose(o2.asnumpy(), e2.asnumpy())
     assert np.allclose(o3.asnumpy(), e3.asnumpy())
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_lamb_apply_optimizer_assign_ascend():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     lamb_apply_optimizer_assign()
