!13001 Auto_tune add sync fusion env.

From: @linqingke
Reviewed-by: @jjfeing,@xu-yfei
Signed-off-by: @xu-yfei
pull/13001/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 8de7fbccd7

@ -22,7 +22,7 @@ from te.platform.cce_conf import te_set_version
from te.platform.fusion_manager import set_current_op_name
from te.platform.fusion_util import fusion_op, dump_fusion_json
from te.platform.parallel_compilation import init_multi_process_env, get_finished_compilation_task, \
deinit_multi_process_env, dispatch_autotune_task, start_ga_multi_process
deinit_multi_process_env, dispatch_autotune_task, start_ga_multi_process, import_py_module
import auto_tune
from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \
rl_tune_deinit
@ -48,6 +48,8 @@ class TbeTuner:
if os.environ.get("TUNE_DUMP_PATH") is not None:
self.offline_dump_path = os.getenv("TUNE_DUMP_PATH", "")
self._creating_custom_path(tune_mode)
self.fusion_need_sync = 0
self.module_list = {}
def init_tune_interface(self, json_str, process_num):
"""
@ -222,6 +224,24 @@ class TbeTuner:
log.info("GA Tune init success.")
return True
def sync_fusion_env(self):
"""
Sync fusion env
:return: None
"""
if self.fusion_need_sync == 0:
return
module_using = []
for key, value in self.module_list.items():
if value > 0:
module_using.append(str(key))
self.module_list[key] = 0
module_str = ",".join(module_using)
import_py_module(module_str)
self.fusion_need_sync = 0
def rl_tune(self, task_id, op_json):
"""
RL tune for single op and fusion op
@ -231,6 +251,7 @@ class TbeTuner:
"""
json_info = json.loads(op_json)
if "fusion_op" in json_info:
self.sync_fusion_env()
ret = self.fusion_rl_tune(task_id, json_info)
else:
ret = self.single_rl_tune(task_id, json_info)
@ -244,6 +265,7 @@ class TbeTuner:
"""
json_info = json.loads(op_json)
if "fusion_op" in json_info:
self.sync_fusion_env()
self.fusion_ga_tune(task_id, json_info)
else:
self.single_ga_tune(task_id, json_info)
@ -289,6 +311,9 @@ class TbeTuner:
l1size = 0 # todo need to verify
ret = dispatch_single_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, op_module_name,
op_module_name + "@" + op_module_name, op_type, op_type, op_args)
self.module_list[op_module_name] = 1
self.fusion_need_sync += 1
return ret, job_type
def get_op_module_names(self, json_info):

@ -20,6 +20,7 @@
#include <algorithm>
#include <string>
#include <vector>
#include <map>
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
@ -28,6 +29,8 @@
namespace mindspore {
namespace kernel {
const std::map<std::string, size_t> kFormatIndexMap = {{"NCHW", 2}, {"HWCN", 0}, {"NHWC", 1}};
template <typename T>
class ConvGradInputGpuBkwKernel : public GpuKernel {
public:
@ -339,7 +342,16 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
void SetStrideAndDilation(const CNodePtr &kernel_node) {
std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "stride");
std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilation");
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
std::string format_me = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "format");
auto iter = kFormatIndexMap.find(format_me);
if (iter == kFormatIndexMap.end()) {
MS_LOG(EXCEPTION) << "OriFormat is " << format_me << ", Please confirm that in {NCHW, HWCN, NHWC}.";
}
size_t h_index = iter->second;
if (stride_me.size() < h_index + 2) {
MS_LOG(EXCEPTION) << "Strides should greater than " << h_index + 1 << ", but got " << stride_me.size();
}
(void)std::transform(stride_me.begin() + h_index, stride_me.begin() + h_index + 2, std::back_inserter(stride_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
[](const int64_t &value) { return static_cast<int>(value); });

@ -1981,7 +1981,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation)

Loading…
Cancel
Save