GraphKernel supports GPU

1. Update the akg submodule.
2. Refactor akg_kernel_build, akg_ascend_kernel_build and akg_gpu_kernel_build.
3. Add akg_kernel_json_decoder to support converting kernel json to AnfNode.
4. Add a GraphKernel cost model (mindspore/_extends/graph_kernel).
5. Add some GraphKernel passes to GpuSession and move these passes to backend/optimizer/graph_kernel.
6. Add a global id for IR files.
7. Fix a bug in ConstInputToAttr.
Branch: pull/5783/head
Author: dayschan (5 years ago)
Parent: c415e8ceda
Commit: 37a48f6aac

akg (submodule)

@@ -1 +1 @@
Subproject commit 3bb6264188d0b1d6ff776a35a571bc7190df0800
Subproject commit d237aa7d8e9d3fb709bda9f30205b02129bc2b59

@@ -0,0 +1,17 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init"""
from .splitter import split_with_json
from .expander import get_op_expander

@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate json desc for graph kernel ops"""
import json
import json.decoder as jd
import traceback
from mindspore import log as logger
import mindspore._extends.graph_kernel.expanders as expanders
def get_op_expander(json_str: str):
"""get op expander by json info"""
try:
kernel_info = json.loads(json_str)
expand_info = kernel_info['expand_info']
if 'name' not in expand_info:
logger.error("expand info have no op name")
return None
if 'process' not in expand_info:
logger.error("expand info have no processor info")
return None
processor = expand_info['process']
op_name = str(expand_info['name']).lower()
expand_op_func_name = 'expand_' + op_name
if not hasattr(expanders, expand_op_func_name):
logger.error("Generator do not support op: {}".format(op_name))
return None
expand_op_func = getattr(expanders, expand_op_func_name)
# generate graph desc.
graph = expand_op_func(expand_info)
if graph is None:
logger.error("Failed to generate graph of: {}".format(op_name))
return None
graph.set_processor(processor)
# dump graph to json desc.
desc = graph.dump()
return json.dumps(desc)
except jd.JSONDecodeError:
logger.error("Failed to generate graph kernel op")
logger.error(traceback.format_exc())
return None
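
For reference, the incoming json_str carries an expand_info block with the op name, target processor, input descriptors and attributes. A minimal sketch of such a request (not part of this commit; the field values are illustrative):

import json

request = json.dumps({
    "expand_info": {
        "name": "Square",
        "process": "cuda",
        "input_desc": [{"shape": [16, 16], "data_type": "float32", "format": "DefaultFormat"}],
        "attr": [],
    }
})
# get_op_expander(request) returns a graph-desc JSON string, or None on failure.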

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""expanders init"""
from .gelu import expand_gelu
from .layernorm import expand_layernorm
from .softmax import expand_softmax
from .square import expand_square

@@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for gelu"""
from mindspore._extends.graph_kernel.model import model_builder as builder
CSVALUE = 0.044715
CSVALUE_A = 1.5957691 # 2*np.sqrt(2/np.pi)
def expand_gelu(expand_info):
"""Gelu expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
dtype = input_x.dtype
if dtype == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
# calculate the tanh approximation of GELU in its sigmoid form.
mul_0 = graph_builder.emit('Mul', [input_x, input_x])
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
const_csvalue_a = graph_builder.value(tanh_res.dtype, CSVALUE_A, input_desc['format'])
mul_0 = graph_builder.emit('Mul', [tanh_res, const_csvalue_a])
const_zero = graph_builder.value(mul_0.dtype, 0.0, input_desc['format'])
mul_0_min = graph_builder.emit('Minimum', [mul_0, const_zero])
right_mul = graph_builder.emit('Exp', [mul_0_min])
mul_0_abs = graph_builder.emit('Abs', [mul_0])
const_neg_one = graph_builder.value(mul_0_abs.dtype, -1.0, input_desc['format'])
mul_0_abs_neg = graph_builder.emit('Mul', [mul_0_abs, const_neg_one])
mul_0_abs_neg_exp = graph_builder.emit('Exp', [mul_0_abs_neg])
const_one = graph_builder.value(mul_0_abs_neg_exp.dtype, 1.0, input_desc['format'])
mul_0_abs_neg_exp_add = graph_builder.emit('TensorAdd', [mul_0_abs_neg_exp, const_one])
left_mul = graph_builder.emit('RealDiv', [input_x, mul_0_abs_neg_exp_add])
result = graph_builder.emit('Mul', [left_mul, right_mul])
if dtype == 'float16':
result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph
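
The graph above evaluates the tanh approximation of GELU in a numerically stable sigmoid form: with y = 2*sqrt(2/pi) * (x + 0.044715*x^3), the identity 0.5*(1 + tanh(y/2)) == sigmoid(y) gives gelu(x) ~= x * sigmoid(y), and the Minimum/Abs/Exp sequence computes sigmoid(y) without overflow. A NumPy sketch of the same computation (an illustration, not part of this commit):

import numpy as np

def gelu_reference(x):
    y = 1.5957691 * (x + 0.044715 * x ** 3)  # CSVALUE_A * (x + CSVALUE * x^3)
    # exp(min(y, 0)) / (1 + exp(-|y|)) == sigmoid(y), stable for large |y|
    return x * np.exp(np.minimum(y, 0.0)) / (1.0 + np.exp(-np.abs(y)))

print(gelu_reference(np.array([-2.0, 0.0, 2.0], dtype=np.float32)))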

@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for LayerNorm"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_layernorm(expand_info):
"""LayerNorm expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
attrs = expand_info['attr']
begin_norm_axis = None
epsilon = None
for item in attrs:
if 'begin_norm_axis' in item:
begin_norm_axis = item['begin_norm_axis']
if 'epsilon' in item:
epsilon = item['epsilon']
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_gamma = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_beta = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
# Calculate the reduction axes and the averaging coefficient
shape_x = input_desc_0['shape']
if begin_norm_axis < 0:
begin_norm_axis += len(shape_x)
reduce_axis = ()
for i, _ in enumerate(shape_x):
if i >= begin_norm_axis:
reduce_axis = reduce_axis + (i,)
reduce_elts = 1.0
for i in reduce_axis:
reduce_elts *= shape_x[i]
mean_cof = 1.0 / reduce_elts
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof, input_x.data_format)
# Calculate mean
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
# Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean])
variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub])
variance_red = graph_builder.emit('ReduceSum', [variance_mul],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
# Calculate normalize
normalize_sub = graph_builder.emit('Sub', [input_x, mean])
epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v])
normalize_log = graph_builder.emit('Log', [normalize_add])
input_y = graph_builder.value(input_x.dtype, -0.5, input_x.data_format)
normalize_log_mul = graph_builder.emit('Mul', [normalize_log, input_y])
normalize_exp = graph_builder.emit('Exp', [normalize_log_mul])
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normalize_exp])
# Calculate scale and translate
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
res = graph_builder.emit('TensorAdd', [scale_mul, input_beta])
# set graph output.
graph_scope.set_output(res, mean, variance)
graph = graph_builder.get()[0]
return graph
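
Note that the expanded graph realizes rsqrt(variance + epsilon) as exp(-0.5 * log(variance + epsilon)), since no Rsqrt primitive is emitted. A NumPy sketch of the whole expansion (an illustration, not part of this commit; the default arguments are assumptions):

import numpy as np

def layernorm_reference(x, gamma, beta, begin_norm_axis=-1, epsilon=1e-7):
    axes = tuple(range(begin_norm_axis % x.ndim, x.ndim))  # normalize over trailing axes
    mean = x.mean(axis=axes, keepdims=True)
    variance = np.square(x - mean).mean(axis=axes, keepdims=True)
    rstd = np.exp(-0.5 * np.log(variance + epsilon))  # rsqrt via Log/Mul/Exp, as above
    return gamma * (x - mean) * rstd + beta, mean, variance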

@@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for softmax"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_softmax(expand_info):
"""Softmax expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
attrs = expand_info['attr']
axis = None
for item in attrs:
if 'axis' in item:
axis = item['axis']
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# calculate softmax, subtracting the row max for numerical stability.
if input_x.dtype == 'float32':
input_x_cast = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
max_x = graph_builder.emit('ReduceMax', [input_x_cast], attrs={'reduce_axis': axis, 'keep_dims': True})
max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': 'float32'})
else:
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [input_x, max_x])
data_exp = graph_builder.emit('Exp', [data_sub])
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
result = graph_builder.emit('RealDiv', [data_exp, data_expsum])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph
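
The expansion is the standard max-subtracted softmax; the extra float32 -> float16 -> float32 casts around ReduceMax are a GPU-specific detail. A NumPy sketch of the underlying math (an illustration, not part of this commit):

import numpy as np

def softmax_reference(x, axis=-1):
    m = x.max(axis=axis, keepdims=True)           # ReduceMax, keep_dims=True
    e = np.exp(x - m)                             # Sub + Exp
    return e / e.sum(axis=axis, keepdims=True)    # ReduceSum + RealDiv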

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for square"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_square(expand_info):
"""Square expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# create op.
result = graph_builder.emit('Mul', [input_x, input_x])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

@@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel cost model init"""
from .graph_split import split
from .model_builder import GraphBuilder, load_composite

@@ -0,0 +1,153 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Cost model splitter"""
from .model import PrimLib, Graph
class GraphSplitByPattern:
"""Graph split by pattern"""
def __init__(self, graph):
self.graph = graph
self.groups = []
self.op_group = {}
for op in self.graph.ops:
g = [op]
self.groups.append(g)
self.op_group[op] = g
self.ids = {}
for i, op in enumerate(graph.ops):
self.ids[op] = i
self.doms = self.post_dom(graph.ops)
_, outputs = graph.deduce_parameters()
self.outputs = set(outputs)
def post_dom(self, ops):
"""Post dom"""
doms, i_doms = {}, {}
for i in range(len(ops) - 1, -1, -1):
op = ops[i]
doms[op] = {op}
i_dom = None
if op.output.to_ops:
suc_dom = set(doms[op.output.to_ops[0]])
for to in op.output.to_ops[1:]:
suc_dom.intersection_update(doms[to])
doms[op].update(suc_dom)
for dom in suc_dom:
if i_dom is None or self.ids[dom] < self.ids[i_dom]:
i_dom = dom
i_doms[op] = i_dom
return i_doms
def get_pattern(self, op, i):
"""Get pattern"""
pattern = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
for pat in elem_relation:
if pat and pat > pattern:
pattern = pat
return pattern
def fuse(self, check_fun):
"""Fuse ops"""
def _get_path(op, dom):
path_ops, visited = [], set()
def _get_path_depth(p):
visited.add(p)
if self.op_group[p][0] == p:
path_ops.append(p)
for to in p.output.to_ops:
if to != dom and to not in visited:
_get_path_depth(to)
_get_path_depth(op)
return path_ops
changed = True
while changed:
for group in self.groups:
op = group[0]
dom = self.doms[op]
if dom is None or op.output in self.outputs:
continue
ops = _get_path(op, dom)
if check_fun(op, dom, ops):
dom_group = self.op_group[dom]
fused = []
for fop in ops:
f_group = self.op_group[fop]
for p in f_group:
self.op_group[p] = dom_group
fused.append(f_group)
dom_group += f_group
for g in fused:
self.groups.remove(g)
break
else:
changed = False
def to_subgraphs(self):
"""Transform op groups to subgraphs"""
subgraphs = []
for i, group in enumerate(self.groups):
group.sort(key=lambda op: self.ids[op])
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group))
return subgraphs
def split(self):
"""Split graph"""
def _buddy(op, dom, path_ops):
"""Fuse buddy together"""
# pylint: disable=unused-argument
group = self.op_group[op]
for p in group:
# p is buddy
if p.output.buddy is not None and p.output.buddy.members[0].op not in group:
return True
# p's output is buddy
for to in p.output.to_ops:
if to.output.buddy is not None and to not in group:
return True
return False
def _injective(pattern, limit):
def _checker(op, dom, path_ops):
# pylint: disable=unused-argument
for p in op.output.to_ops:
if p not in self.op_group[dom]:
return False
if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
for i, t in enumerate(dom.inputs):
if t == op.output:
return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit
return False
return _checker
def _diamond(op, dom, path_ops):
if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM):
return False
return len(path_ops) == 1 and op.output not in dom.inputs
self.fuse(_buddy)
self.fuse(_injective(PrimLib.ELEMWISE, 100))
self.fuse(_injective(PrimLib.BROADCAST, 6))
self.fuse(_injective(PrimLib.REDUCE, 6))
self.fuse(_diamond)
return self.to_subgraphs()
def split(graph):
return GraphSplitByPattern(graph).split()
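
The post_dom pass computes immediate post-dominators in a single reverse-topological sweep: an op's post-dominator set is itself plus the intersection of its successors' sets, and the immediate post-dominator is the member of that intersection with the smallest topological id. A toy sketch of the same computation on a diamond-shaped DAG, using plain dicts (names are mine, not part of this commit):

succ = {'a': ['b', 'c'], 'b': ['d'], 'c': ['d'], 'd': []}  # a fans out to b and c; both rejoin at d
ids = {'a': 0, 'b': 1, 'c': 2, 'd': 3}                     # topological order

doms, i_doms = {}, {}
for op in sorted(succ, key=ids.get, reverse=True):         # reverse topological sweep
    dom = {op}
    if succ[op]:
        common = set(doms[succ[op][0]])
        for to in succ[op][1:]:
            common &= doms[to]
        dom |= common
        i_doms[op] = min(common, key=ids.get)
    else:
        i_doms[op] = None
    doms[op] = dom

print(i_doms)  # {'d': None, 'c': 'd', 'b': 'd', 'a': 'd'}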

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GraphKernel splitter"""
import json
import json.decoder as jd
import traceback
from mindspore import log as logger
from . import model
def split_with_json(json_str: str):
"""Call costmodel to split GraphKernel"""
try:
graph_desc = json.loads(json_str)
comp = model.load_composite(graph_desc)
graph_split = model.split(comp.graph)
is_multi_graph = len(graph_split) > 1
graph_list = list(map(comp.dump, graph_split))
result = {"multi_graph": is_multi_graph, "graph_desc": graph_list}
return json.dumps(result)
except jd.JSONDecodeError:
logger.error(traceback.format_exc())
return None
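
On success the returned string encodes an envelope of this shape (a sketch of the structure; each graph_desc entry is whatever comp.dump produces for a subgraph):

# json.loads(split_with_json(graph_desc_str)) ==
#     {"multi_graph": <bool>, "graph_desc": [<subgraph desc>, ...]}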

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
PYTHONPATH="$(pwd)/..:${PYTHONPATH}"
export PYTHONPATH

@@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""graph kernel split"""
import json
import getopt
import sys
import model
def print_usage():
print('Usage: graph_kernel_split.py [OPTION] <JSON_FILE>')
print('Options:')
print(' -s <config/auto>\tsplit graph with config')
print(' -e \t\testimate graph')
print(' -i \t\tnaive estimate')
print(' -o <prefix>\toutput split graphs')
print(' -v \t\tverbose mode')
print(' -h \t\tprint this help')
print('Report bugs to xiong.gao@huawei.com')
class Option:
"""Options"""
def __init__(self):
self.split = None
self.estimate = False
self.estimate_naive = False
self.output = None
self.verbose = False
self.help = False
def parse(self, options):
"""parse options"""
for name, val in options:
if name == '-h':
self.help = True
elif name == '-v':
self.verbose = True
elif name == '-o':
self.output = val
elif name == '-e':
self.estimate = True
elif name == '-s':
self.split = val
elif name == '-i':
self.estimate_naive = True
opt = Option()
def estimate(graph_in, parts_in, naive):
"""estimate graphs costs"""
def _print_cost(name, c):
print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
(name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
main_cost, _ = model.estimate(graph_in, naive)
split_cost, sub_costs = model.estimate(parts_in, naive) if parts_in else (None, None)
_print_cost("MainGraph:", main_cost)
if parts_in:
_print_cost("Subgraphs:", split_cost)
if opt.verbose:
for i, sub_cost in enumerate(sub_costs):
_print_cost(" |_%d:\t" % (i), sub_cost)
def split_graph(graph_in, config):
"""split graph"""
if config == 'auto':
return model.split(graph_in)
subgraphs = []
all_tensors = []
subgraph_idx = 0
config_parts = config.split('|')
for part in config_parts:
tensor_names = part.split(',')
graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
g = graph_in.extract_subgraph(graph_name, tensor_names)
assert len(g.ops) == len(tensor_names)
subgraphs.append(g)
all_tensors += tensor_names
subgraph_idx += 1
if len(all_tensors) < len(graph_in.ops):
graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
g = graph_in.extract_subgraph(graph_name, all_tensors, True)
subgraphs.append(g)
return subgraphs
def main():
opts, args = getopt.getopt(sys.argv[1:], 'heivo:s:')
opt.parse(opts)
if len(args) != 1 or opt.help:
print_usage()
sys.exit(0)
in_file = args[0]
with open(in_file, 'r') as f:
desc = json.loads(f.read())
comp = model.load_composite(desc)
graph = comp.graph
parts = []
# 1. split sub-graphs
if opt.split is not None:
parts = split_graph(graph, opt.split)
if opt.verbose:
print('----------- main graph --------------')
print(graph)
for i, _ in enumerate(parts):
print('---------------- sub graph %d ---------------' % (i))
print(parts[i])
# 2. estimate cost
if opt.estimate:
print('------------- cost --------------')
estimate(graph, parts, False)
if opt.estimate_naive:
print('------------- naive cost --------------')
estimate(graph, parts, True)
# 3. output parts
if opt.output is not None:
for graph_part in parts:
desc = comp.dump(graph_part)
s_desc = json.dumps(desc)
fname = "%s_%s.json" % (opt.output, graph_part.name)
with open(fname, 'w', encoding='utf-8') as of:
of.write(s_desc)
if __name__ == '__main__':
main()
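
Typical invocations of this developer tool, given the options above (the JSON file name is hypothetical):

python graph_kernel_split.py -v -s auto fused_op.json         # auto split via the cost model, print graphs
python graph_kernel_split.py -e -i fused_op.json              # print estimated and naive costs
python graph_kernel_split.py -s "b,c|d" -o out fused_op.json  # manual split by tensor names, dump out_*.json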

@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""test split"""
import model
def graph_1():
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a = gb.tensor([1024, 16], "float32", name="a")
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("Abs", c, 'd')
gb.emit("TensorAdd", [b, d], "e")
return gb.get()[0]
def graph_2():
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a = gb.tensor([1024, 16], "float32", name="a")
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)})
gb.emit("Sqrt", d, 'e')
return gb.get()[0]
def test_split_by_pattern():
def _test(graph):
print("***************** main graph ***************")
print(graph)
subgraphs = model.split(graph)
for i, g in enumerate(subgraphs):
print('------------- subgraph {} --------------'.format(i))
print(g)
_test(graph_2())
if __name__ == '__main__':
test_split_by_pattern()
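
To run this test standalone, the model package must be importable; sourcing the env.sh above prepends the parent directory to PYTHONPATH for exactly this purpose (assuming both files live in the same tests directory):

source env.sh && python test_split.py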

@@ -44,7 +44,8 @@ if(ENABLE_GPU)
"runtime/device/gpu/*.cu"
"backend/kernel_compiler/gpu/*.cu"
"backend/kernel_compiler/akg/gpu/*.cc"
"backend/kernel_compiler/akg/akg_kernel_build.cc"
"backend/kernel_compiler/akg/akg_kernel_json_generator.cc"
"backend/kernel_compiler/akg/akg_kernel_json_decoder.cc"
"backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
)

@@ -10,7 +10,8 @@ if (ENABLE_D)
"kernel_query.cc"
"kernel_fusion.cc"
"akg/ascend/*.cc"
"akg/akg_kernel_build.cc"
"akg/akg_kernel_json_generator.cc"
"akg/akg_kernel_json_decoder.cc"
"akg/akg_kernel_attrs_process.cc"
"akg/akg_kernel_metadata.cc"
"tbe/*.cc"
@@ -49,7 +50,8 @@ if (ENABLE_GPU)
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"gpu/*.cu"
"akg/gpu/*.cc"
"akg/akg_kernel_build.cc"
"akg/akg_kernel_json_generator.cc"
"akg/akg_kernel_json_decoder.cc"
"akg/akg_kernel_attrs_process.cc"
)

@@ -24,7 +24,6 @@
#include <climits>
#include "runtime/device/kernel_runtime.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
#include "proto/tensor.pb.h"
#include "proto/tensor_shape.pb.h"
#include "proto/attr.pb.h"
@@ -33,6 +32,7 @@
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
namespace mindspore {
namespace kernel {

@@ -15,13 +15,20 @@
*/
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h"
#include "base/core_ops.h"
#include "utils/utils.h"
namespace mindspore {
namespace kernel {
namespace {
void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
// The x and output are akg op input and output param.
@@ -169,5 +176,29 @@ void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) {
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node);
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node);
}
const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
{kFour2FiveOpName, SetAkgAttrsForFour2Five},
{kFive2FourOpName, SetAkgAttrsForFive2Four},
{kCastOpName, SetAkgAttrsForCast},
{kBNGrad1OpName, SetAkgAttrsForBNGrad1},
{kBNGrad2OpName, SetAkgAttrsForBNGrad2},
{kBNGrad3OpName, SetAkgAttrsForBNGrad3},
{kFusedBN1OpName, SetAkgAttrsForFusedBN1},
{kFusedBN2OpName, SetAkgAttrsForFusedBN2},
{kFusedBN3OpName, SetAkgAttrsForFusedBN3},
{kConvBN1OpName, SetAkgAttrsForConvBN1},
{kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
{kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
};
} // namespace
void SetAkgKernelAttrs(const AnfNodePtr &anf_node) {
auto it = kAkgKernelAttrsProcessMap.find(AnfAlgo::GetCNodeName(anf_node));
if (it != kAkgKernelAttrsProcessMap.end()) {
it->second(anf_node);
}
}
} // namespace kernel
} // namespace mindspore

@@ -16,43 +16,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include "ir/anf.h"
#include "utils/utils.h"
#include "base/core_ops.h"
namespace mindspore {
namespace kernel {
void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node);
void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node);
void SetAkgAttrsForCast(const AnfNodePtr &anf_node);
void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node);
void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node);
void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node);
void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node);
void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node);
void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node);
void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node);
void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node);
void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node);
const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
{kFour2FiveOpName, SetAkgAttrsForFour2Five},
{kFive2FourOpName, SetAkgAttrsForFive2Four},
{"Cast", SetAkgAttrsForCast},
{kBNGrad1OpName, SetAkgAttrsForBNGrad1},
{kBNGrad2OpName, SetAkgAttrsForBNGrad2},
{kBNGrad3OpName, SetAkgAttrsForBNGrad3},
{kFusedBN1OpName, SetAkgAttrsForFusedBN1},
{kFusedBN2OpName, SetAkgAttrsForFusedBN2},
{kFusedBN3OpName, SetAkgAttrsForFusedBN3},
{kConvBN1OpName, SetAkgAttrsForConvBN1},
{kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
{kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
};
void SetAkgKernelAttrs(const AnfNodePtr &anf_node);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H

@@ -1,76 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_
#include <unordered_map>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include "backend/kernel_compiler/kernel.h"
#include "ir/dtype.h"
#include "ir/primitive.h"
#include <nlohmann/json.hpp>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
namespace mindspore {
namespace kernel {
class AkgKernelBuild {
public:
AkgKernelBuild() {
input_tensor_idx_ = {};
output_tensor_idx_ = 0;
}
~AkgKernelBuild() = default;
KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
static std::string GetProcessor(const AnfNodePtr &anf_node);
protected:
bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json);
bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json);
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json);
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
int GetOpCntInc();
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
size_t GetOutputTensorIdxInc();
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
nlohmann::json *const node_json);
static int op_cnt_;
// lock for variable fusionOpCnt in singleton mode
static std::mutex op_cnt_mtx_;
std::string json_name_;
std::string json_info_;
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
size_t output_tensor_idx_;
};
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json);
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
#include <string>
#include <vector>
#include <map>
#include <nlohmann/json.hpp>
#include "ir/scalar.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace kernel {
class AkgKernelJsonDecoder {
public:
AkgKernelJsonDecoder() { nodes_map_.clear(); }
~AkgKernelJsonDecoder() = default;
FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json);
FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str);
bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs);
private:
ScalarPtr DecodeScalar(const nlohmann::json &scalar_json);
ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph);
ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
std::map<std::string, AnfNodePtr> nodes_map_{};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_

Some files were not shown because too many files have changed in this diff.
