!9570 Modifications for GraphKernel

From: @dayschan
Reviewed-by: @gaoxiong1, @ckey_dou
Signed-off-by: @ckey_dou
pull/9570/MERGE
Committed by mindspore-ci-bot via Gitee
commit 7b311f7d2a

@@ -13,11 +13,12 @@
 # limitations under the License.
 # ===========================================================================
 """Cost model splitter"""
+from functools import reduce
 from .model import PrimLib, Graph, Tensor
 use_poly_reduce = True
 class GraphSplitByPattern:
     """Graph splitter"""
     class Area:
@@ -34,6 +35,8 @@ class GraphSplitByPattern:
             if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \
                     (use_poly_reduce and self.pattern == PrimLib.REDUCE):
                 self.mode = self.MODE_COMPOSITE
+            if init_op.prim == "AddN":
+                self.mode = self.MODE_COMPOSITE
             self.is_output = is_output
             self.output_excluded = set()
             if self.pattern == PrimLib.REDUCE:
@@ -197,7 +200,7 @@ class GraphSplitByPattern:
             min_area, forward_fuse = None, False
             for a, _ in dom.out_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
                         (min_area is None or a.pattern < min_area.pattern):
                     min_area = a
             for a, _ in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
@@ -211,7 +214,7 @@ class GraphSplitByPattern:
                 return None
             a, r = list(dom.in_relations.items())[0]
             if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \
                     a.dom_op().output.shape != dom.dom_op().output.shape:
                 return None
             return [a], True
@@ -221,7 +224,7 @@ class GraphSplitByPattern:
             fused = []
             for a, r in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
                         a.dom_op().output.shape == dom.dom_op().output.shape:
                     fused.append(a)
             return fused, True
@@ -232,7 +235,7 @@ class GraphSplitByPattern:
         def _broadcast_depth(dom):
             if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
                     dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
                 return None
             a, r = list(dom.out_relations.items())[0]
             if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
@@ -241,12 +244,12 @@ class GraphSplitByPattern:
         def _broadcast_width(dom):
             if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
                     dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
                 return None
             fused = []
             for a, r in dom.out_relations.items():
                 if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
                         (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
                     return None
                 fused.append(a)
             return fused, False
@@ -302,8 +305,19 @@ class GraphSplitByPattern:
             return size
         def _reduce_output(dom):
+            def _is_atomic_add_available(dom):
+                if any(["Reduce" in x.prim for x in dom.ops[1:]]):
+                    return False
+                op = dom.ops[0]
+                reduce_axis = op.attrs["reduce_axis"]
+                if len(op.inputs[0].shape) - 1 in reduce_axis:
+                    reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
+                    return reduce_size >= 1024
+                return True
             if dom.pattern != PrimLib.REDUCE:
                 return None
+            if _is_atomic_add_available(dom):
+                return None
             is_all_reduce = _tensor_size(dom.ops[0].output) == 1
             # excluded large size all reduce
             if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
@@ -311,7 +325,7 @@ class GraphSplitByPattern:
             fused = []
             for a, r in dom.out_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
                         dom.check_circle(a) and not dom.reduce_out_exclude(a):
                     fused.append(a)
             return fused, False

@@ -208,7 +208,7 @@ class CompositeGraph:
         def _get_axis_while_none(input_shape, output_shape):
             red_axis = []
             if len(output_shape) == len(input_shape):
-                for s, i in enumerate(output_shape):
+                for i, s in enumerate(output_shape):
                     if s == 1 and input_shape[i] > 1:
                         red_axis.append(i)
             else:
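Note: the `for s, i` → `for i, s` swap above fixes a real bug: Python's `enumerate` yields `(index, value)` pairs, so the old code used the dimension size as an index. A minimal reconstruction of the fixed axis inference:

```python
def get_reduce_axis(input_shape, output_shape):
    # enumerate yields (index, value), so i is the axis and s its size.
    red_axis = []
    if len(output_shape) == len(input_shape):
        for i, s in enumerate(output_shape):
            if s == 1 and input_shape[i] > 1:  # axis collapsed by the reduce
                red_axis.append(i)
    return red_axis

print(get_reduce_axis([8, 16, 32], [8, 1, 32]))  # [1]: axis 1 was reduced
```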

@@ -158,7 +158,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
     }
     auto fuse_nodes = FindFuseCNodes(node, depend_prior);
-    if (fuse_nodes.size() <= 1) {
+    if (fuse_nodes.empty() || (fuse_nodes.size() == 1 && AnfAlgo::IsGraphKernel(fuse_nodes[0]))) {
       continue;
     }
     changed = true;
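Note: the relaxed guard above changes which candidates are skipped: previously any result of size <= 1 was dropped, while now a singleton is only dropped when it is already a graph kernel, so a lone basic op can still proceed and be wrapped into a composite. The predicate restated as a sketch (hypothetical `is_graph_kernel` callback):

```python
def should_skip(fuse_nodes, is_graph_kernel):
    # Skip empty results, and singletons that are already graph kernels.
    return not fuse_nodes or (len(fuse_nodes) == 1 and is_graph_kernel(fuse_nodes[0]))

print(should_skip([], lambda n: False))       # True: nothing to fuse
print(should_skip(["add"], lambda n: False))  # False: lone basic op proceeds
print(should_skip(["gk"], lambda n: True))    # True: already a graph kernel
```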
@@ -173,17 +173,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
   }
 }  // namespace
-bool FuseBasicOps(const FuncGraphPtr &kernel_graph) {
-  MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto mng = kernel_graph->manager();
-  if (mng == nullptr) {
-    mng = Manage(kernel_graph, true);
-    kernel_graph->set_manager(mng);
-  }
+bool FuseBasicOps(const FuncGraphPtr &func_graph) {
   std::unordered_set<AnfNodePtr> fused_ops;
-  auto todos = TopoSort(kernel_graph->get_return());
+  auto todos = TopoSort(func_graph->get_return());
   std::reverse(todos.begin(), todos.end());
-  return FuseBasicOps(kernel_graph, todos, &fused_ops);
+  return FuseBasicOps(func_graph, todos, &fused_ops);
 }
 void EliminateGetitem(const FuncGraphPtr &func_graph) {
@@ -197,9 +191,16 @@ void EliminateGetitem(const FuncGraphPtr &func_graph) {
 }
 bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) {
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
   bool changed = FuseBasicOps(func_graph);
   if (changed) {
     EliminateGetitem(func_graph);
+    mng->RemoveRoots();
+    mng->KeepRoots({func_graph});
   }
   return changed;
 }

@@ -192,7 +192,7 @@ class EliminateGetitemForControlDepend : public Pass {
     MS_EXCEPTION_IF_NULL(maketuple);
     std::vector<size_t> result;
     for (auto i : indexes_) {
-      auto real_output = maketuple->input(i);
+      auto real_output = maketuple->input(i + 1);
       if (users[real_output].size() > 1) {
         result.push_back(i);
       }
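Note: the `input(i)` → `input(i + 1)` fix above is an off-by-one: in MindSpore IR a CNode's input 0 is the primitive itself, so the i-th element of a `make_tuple` sits at input i + 1. A toy model of that convention (illustrative, not the real AnfNode API):

```python
class CNode:
    def __init__(self, prim, *args):
        self.inputs = [prim, *args]  # input 0 is the primitive

    def input(self, idx):
        return self.inputs[idx]

maketuple = CNode("make_tuple", "out0", "out1", "out2")
i = 1
print(maketuple.input(i + 1))  # 'out1': the i-th tuple element
print(maketuple.input(i))      # 'out0': what the old code wrongly fetched
```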

@@ -711,11 +711,11 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimGelu,
     prim::kPrimFusedAdam,
     prim::kPrimFusedAdamWeightDecay,
     prim::kPrimTanhGrad,
     prim::kPrimReduceMean,
     prim::kPrimMaximumGrad,
     prim::kPrimMinimumGrad,
-    prim::kPrimGkDropout
+    prim::kPrimGkDropout,
+    prim::kPrimDropoutGrad,
 #endif
   };
   return expand_ops;

@@ -544,7 +544,8 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
     }
     func_graph_ = func_graph;
     this->Run();
-    return split_plan_.size() > 1;
+    if (split_plan_.empty()) return false;
+    return split_plan_.size() > 1 || NeedInline(0);
   }
   bool NeedInline(size_t group_id) const override {
@@ -629,7 +630,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
     }
     GetValidKernelNodes();
     // call CostModel to get a split plan.
-    if (!SplitByCostModel() || split_plan_.size() <= 1) {
+    if (!SplitByCostModel()) {
       split_plan_.clear();
       need_inline_.clear();
       return;
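Note: the two hunks above work together: a `SplitByCostModel` failure still clears the plan, but a single-group plan is no longer discarded, and the pass now reports a change when that lone group must be inlined back into the main graph. The corrected decision, sketched with hypothetical names:

```python
def split_changed(split_plan, need_inline):
    # A failed or empty split changes nothing.
    if not split_plan:
        return False
    # More than one group, or one group that needs inlining, is a change.
    return len(split_plan) > 1 or bool(need_inline[0])

print(split_changed([], []))                  # False: no plan
print(split_changed([["a", "b"]], [1]))       # True: single group, inlined
print(split_changed([["a"], ["b"]], [0, 0]))  # True: real split
```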

@@ -77,8 +77,8 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons
   ShapeVector shape_i64;
   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });
   // Create new tensor
-  AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal)};
+  // The primitive should use a clone, otherwise the attr seed will be overridden.
+  AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal->Clone())};
   auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())),
                                                  static_cast<void *>(&shape[0]), kNumberTypeInt64);
   uniform_input.push_back(NewValueNode(tensor));
@@ -98,8 +98,8 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons
   // create new uniform_real_node
   auto uniform_real_node = func_graph->NewCNode(uniform_input);
-  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(rand_r(&seed_))));
-  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(rand_r(&seed_))));
+  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(seed_++)));
+  AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(seed_++)));
   auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64);
   uniform_real_node->set_abstract(uniform_abstract);
   uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>());
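Note: both dropout fixes above guard against shared state: `kPrimCudnnUniformReal` is a process-wide primitive object, so setting `seed` on it without `Clone()` lets later substitutions overwrite earlier ones, and switching from `rand_r` to `seed_++` makes the per-node seeds deterministic. A toy reproduction of the aliasing pitfall (illustrative objects, not the MindSpore API):

```python
import copy

class Prim:
    def __init__(self, name):
        self.name = name
        self.attrs = {}

KPRIM_UNIFORM = Prim("CudnnUniformReal")  # shared, process-wide primitive

def make_node(prim, seed, clone):
    p = copy.deepcopy(prim) if clone else prim  # Clone() vs. shared object
    p.attrs["seed"] = seed
    return p

# Without cloning, both nodes alias one primitive: the second set_attr
# overwrites the first, and both nodes end up with seed == 1.
a = make_node(KPRIM_UNIFORM, 0, clone=False)
b = make_node(KPRIM_UNIFORM, 1, clone=False)
print(a.attrs["seed"], b.attrs["seed"])  # 1 1

# With a clone per node, each keeps its own deterministic seed.
seed_ = 0
c = make_node(KPRIM_UNIFORM, seed_, clone=True); seed_ += 1
d = make_node(KPRIM_UNIFORM, seed_, clone=True); seed_ += 1
print(c.attrs["seed"], d.attrs["seed"])  # 0 1
```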
