parent
d8d502bfb6
commit
960da5cbed
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Server functions.
|
||||||
|
|
||||||
|
Python functions that will be called in the c++ client part of MindSpore.
|
||||||
|
"""
|
@ -0,0 +1,235 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""kernel build server"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_compiler, op_select_format, check_supported
|
||||||
|
|
||||||
|
class TbeBuilder:
|
||||||
|
"""Tbe building wrapper"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.tbe_builder = create_tbe_parallel_compiler()
|
||||||
|
|
||||||
|
def start(self, json):
|
||||||
|
return self.tbe_builder.start_compile_op(json)
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
return self.tbe_builder.wait_one()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.tbe_builder.reset_task_info()
|
||||||
|
|
||||||
|
def exit(self):
|
||||||
|
self.tbe_builder.exit()
|
||||||
|
|
||||||
|
class Messager:
|
||||||
|
'''Messager'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
logger.info('[TRACE]', 'Messager init...')
|
||||||
|
self.message = ''
|
||||||
|
self.tbe_builder = TbeBuilder()
|
||||||
|
|
||||||
|
def get_message(self):
|
||||||
|
"""
|
||||||
|
Get message from remote
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
message
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Not read by input() anymore
|
||||||
|
res = self.fin.readline()
|
||||||
|
if not res:
|
||||||
|
logger.info('[TRACE]', "read <empty>")
|
||||||
|
self.exit()
|
||||||
|
if res[len(res) - 1] == '\n':
|
||||||
|
res = res[0:len(res)-1]
|
||||||
|
self.message = res
|
||||||
|
logger.debug('[IN]', self.message)
|
||||||
|
except EOFError:
|
||||||
|
self.exit()
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
if self.message == '' or self.message == 'FIN':
|
||||||
|
self.send_ack()
|
||||||
|
self.exit()
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
def send_res(self, res, keep_format=True):
|
||||||
|
"""
|
||||||
|
Send result to remote
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keep_format: True or False
|
||||||
|
"""
|
||||||
|
logger.debug('[OUT]', str(res))
|
||||||
|
if keep_format:
|
||||||
|
res_str = str(res).replace('\n', '[LF]').replace('\r', '[CR]').replace(' ', '[SP]')
|
||||||
|
else:
|
||||||
|
res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '')
|
||||||
|
tag = '[~]' # The same as client kTAG
|
||||||
|
|
||||||
|
# Not write by print(tag + res_str, flush=True) any more
|
||||||
|
try:
|
||||||
|
self.fout.write(tag + res_str + "\n")
|
||||||
|
self.fout.flush()
|
||||||
|
except BrokenPipeError as err:
|
||||||
|
logger.info('[TRACE]', 'Write, ' + str(err))
|
||||||
|
self.exit()
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def send_ack(self, success=True):
|
||||||
|
"""
|
||||||
|
Send ack to remote
|
||||||
|
|
||||||
|
Args:
|
||||||
|
success: True or False
|
||||||
|
"""
|
||||||
|
if success:
|
||||||
|
self.send_res('ACK')
|
||||||
|
else:
|
||||||
|
self.send_res('ERR')
|
||||||
|
|
||||||
|
def handle(self):
|
||||||
|
"""
|
||||||
|
Communicate with remote
|
||||||
|
"""
|
||||||
|
arg = self.get_message()
|
||||||
|
if arg == 'START':
|
||||||
|
self.send_ack()
|
||||||
|
json = self.get_message()
|
||||||
|
res = self.tbe_builder.start(json)
|
||||||
|
self.send_res(res)
|
||||||
|
elif arg == 'WAIT':
|
||||||
|
self.send_ack()
|
||||||
|
task_id, res, pre = self.tbe_builder.wait()
|
||||||
|
logger.debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre))
|
||||||
|
if self.get_message() != 'CONT':
|
||||||
|
self.send_ack(False)
|
||||||
|
self.exit()
|
||||||
|
self.send_res(task_id)
|
||||||
|
if self.get_message() != 'CONT':
|
||||||
|
self.send_ack(False)
|
||||||
|
self.exit()
|
||||||
|
self.send_res(res)
|
||||||
|
if self.get_message() != 'CONT':
|
||||||
|
self.send_ack(False)
|
||||||
|
self.exit()
|
||||||
|
self.send_res(pre)
|
||||||
|
elif arg == 'RESET':
|
||||||
|
self.tbe_builder.reset()
|
||||||
|
self.send_ack()
|
||||||
|
elif arg == 'FORMAT':
|
||||||
|
self.send_ack()
|
||||||
|
json = self.get_message()
|
||||||
|
self.send_res(op_select_format(json))
|
||||||
|
elif arg == 'SUPPORT':
|
||||||
|
self.send_ack()
|
||||||
|
json = self.get_message()
|
||||||
|
logger.debug('[SUPPORT]', json)
|
||||||
|
try:
|
||||||
|
res = check_supported(json)
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
self.send_ack(False)
|
||||||
|
self.exit()
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
self.send_res(res)
|
||||||
|
else:
|
||||||
|
self.send_ack(False)
|
||||||
|
self.exit()
|
||||||
|
|
||||||
|
def loop(self):
|
||||||
|
"""
|
||||||
|
Messaging loop
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
self.handle()
|
||||||
|
|
||||||
|
def exit(self):
|
||||||
|
os.close(self.fdin)
|
||||||
|
os.close(self.fdout)
|
||||||
|
self.tbe_builder.reset()
|
||||||
|
self.tbe_builder.exit()
|
||||||
|
logger.info('[TRACE]', 'Messager Exit...')
|
||||||
|
exit()
|
||||||
|
|
||||||
|
def run(self, fdin, fdout):
|
||||||
|
self.fdin = fdin
|
||||||
|
self.fdout = fdout
|
||||||
|
self.fin = os.fdopen(fdin, "r")
|
||||||
|
self.fout = os.fdopen(fdout, "w")
|
||||||
|
self.loop()
|
||||||
|
|
||||||
|
class Logger:
|
||||||
|
"""
|
||||||
|
Replace dummy 'logger' to output log as below:
|
||||||
|
logger = Logger("remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log")
|
||||||
|
"""
|
||||||
|
def __init__(self, level=1, dumpfile=False, filename='Logger.log'):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
level: 0 for debug and info, 1 for info
|
||||||
|
dumpfile: if dump log into file
|
||||||
|
"""
|
||||||
|
self.level = level
|
||||||
|
self.dumpfile = dumpfile
|
||||||
|
if self.dumpfile:
|
||||||
|
self.log = open(filename, "a")
|
||||||
|
|
||||||
|
def write(self, msg):
|
||||||
|
self.log.write(msg)
|
||||||
|
self.flush()
|
||||||
|
|
||||||
|
def writeline(self, tag, msg):
|
||||||
|
prefix = tag + ' REMOTE(' + str(os.getpid()) + ',python)'
|
||||||
|
line = prefix + '\t' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ':\t' + msg
|
||||||
|
print(line, flush=True)
|
||||||
|
if self.dumpfile:
|
||||||
|
self.write(line + '\n')
|
||||||
|
|
||||||
|
def debug(self, tag, msg):
|
||||||
|
if self.level == 0:
|
||||||
|
self.writeline('[DEBUG]' + tag, msg)
|
||||||
|
|
||||||
|
def info(self, tag, msg):
|
||||||
|
self.writeline('[INFO]' + tag, msg)
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
self.log.flush()
|
||||||
|
|
||||||
|
class DummyLogger:
|
||||||
|
"""DummyLogger"""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def debug(self, tag, msg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def info(self, tag, msg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = Logger()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
if len(sys.argv) != 3:
|
||||||
|
raise Exception('Incorrect argv: {}'.format(sys.argv))
|
||||||
|
logger.debug('[TRACE]', 'argv: ' + str(sys.argv))
|
||||||
|
messager = Messager()
|
||||||
|
messager.run(int(sys.argv[1]), int(sys.argv[2]))
|
@ -1,198 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "backend/kernel_compiler/tbe/tbe_python_funcs.h"
|
|
||||||
#include "backend/kernel_compiler/tbe/tbe_utils.h"
|
|
||||||
#include "common/utils.h"
|
|
||||||
#include "utils/context/ms_context.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace kernel {
|
|
||||||
using mindspore::kernel::tbe::TbeUtils;
|
|
||||||
constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process";
|
|
||||||
constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler";
|
|
||||||
constexpr auto kOpSelectFormatFunc = "op_select_format";
|
|
||||||
constexpr auto kCheckSupportedFunc = "check_supported";
|
|
||||||
constexpr auto kTBEException = "TBEException";
|
|
||||||
|
|
||||||
PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr;
|
|
||||||
PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr;
|
|
||||||
PyObject *TbePythonFuncs::pOpSelectFormatFunc_ = nullptr;
|
|
||||||
PyObject *TbePythonFuncs::pCheckSupportedFunc_ = nullptr;
|
|
||||||
bool TbePythonFuncs::Init() {
|
|
||||||
static bool initialized = false;
|
|
||||||
if (initialized) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// Initialize cache
|
|
||||||
TbeUtils::LoadCache();
|
|
||||||
|
|
||||||
// tbe_process
|
|
||||||
PyObject *pTbeProcessModule = nullptr;
|
|
||||||
pTbeProcessModule = PyImport_ImportModule(kTbeProcessModule);
|
|
||||||
if (pTbeProcessModule == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Failed to import [" << kTbeProcessModule << "] module.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
pCreateTbeParallelCompilerFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCreateTbeParallelCompilerFunc);
|
|
||||||
if (pCreateTbeParallelCompilerFunc_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule
|
|
||||||
<< "], FuncName:[" << kCreateTbeParallelCompilerFunc << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
pTbeCompiler_ = PyEval_CallObject(pCreateTbeParallelCompilerFunc_, nullptr);
|
|
||||||
if (pTbeCompiler_ == nullptr) {
|
|
||||||
PyErr_Print();
|
|
||||||
MS_EXCEPTION(ArgumentError) << "Failed to call function : create_parallel_compiler.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
pOpSelectFormatFunc_ = PyObject_GetAttrString(pTbeProcessModule, kOpSelectFormatFunc);
|
|
||||||
if (pOpSelectFormatFunc_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule
|
|
||||||
<< "], FuncName:[" << kOpSelectFormatFunc << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
pCheckSupportedFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCheckSupportedFunc);
|
|
||||||
if (pCheckSupportedFunc_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule
|
|
||||||
<< "], FuncName:[" << kCheckSupportedFunc << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
initialized = true;
|
|
||||||
MS_LOG(INFO) << "TbePythonFuncs initialized Success.";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string TbePythonFuncs::PyObjectToStr(PyObject *PyObj) {
|
|
||||||
char *pChar = nullptr;
|
|
||||||
std::string str_res;
|
|
||||||
if (PyObj == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Input parameter is nullptr.";
|
|
||||||
return str_res;
|
|
||||||
}
|
|
||||||
PyObject *strArgs = PyObject_Str(PyObj);
|
|
||||||
if (strArgs != nullptr) {
|
|
||||||
(void)PyArg_Parse(strArgs, "s", &pChar);
|
|
||||||
}
|
|
||||||
if (pChar == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "pChar is nullptr.";
|
|
||||||
return str_res;
|
|
||||||
}
|
|
||||||
str_res = pChar;
|
|
||||||
return str_res;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) {
|
|
||||||
PyObject *pArg = nullptr;
|
|
||||||
PyObject *pRet = nullptr;
|
|
||||||
std::string res_json_str;
|
|
||||||
|
|
||||||
if (!Init()) {
|
|
||||||
MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !";
|
|
||||||
return res_json_str;
|
|
||||||
}
|
|
||||||
|
|
||||||
// assembly Args
|
|
||||||
pArg = PyTuple_New(1);
|
|
||||||
std::string json_str = kernel_json.dump();
|
|
||||||
(void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", json_str.c_str()));
|
|
||||||
if (pArg == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject.";
|
|
||||||
return res_json_str;
|
|
||||||
}
|
|
||||||
|
|
||||||
// call functions
|
|
||||||
if (pOpSelectFormatFunc_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "function is nullptr.";
|
|
||||||
return res_json_str;
|
|
||||||
}
|
|
||||||
|
|
||||||
pRet = PyEval_CallObject(pOpSelectFormatFunc_, pArg);
|
|
||||||
if (pRet == nullptr) {
|
|
||||||
PyErr_Print();
|
|
||||||
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc
|
|
||||||
<< "], function args:" << PyObjectToStr(pArg);
|
|
||||||
}
|
|
||||||
|
|
||||||
char *pstr = nullptr;
|
|
||||||
(void)PyArg_Parse(pRet, "s", &pstr);
|
|
||||||
res_json_str = pstr;
|
|
||||||
if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
|
|
||||||
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str
|
|
||||||
<< " ,function args:" << PyObjectToStr(pArg);
|
|
||||||
}
|
|
||||||
return res_json_str;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) {
|
|
||||||
PyObject *pArg = nullptr;
|
|
||||||
PyObject *pRes = nullptr;
|
|
||||||
bool ret = false;
|
|
||||||
|
|
||||||
if (!Init()) {
|
|
||||||
MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
// assembly Args
|
|
||||||
pArg = PyTuple_New(1);
|
|
||||||
std::string json_str = kernel_json.dump();
|
|
||||||
PyObject *arg1 = Py_BuildValue("s", json_str.c_str());
|
|
||||||
(void)PyTuple_SetItem(pArg, 0, arg1);
|
|
||||||
if (pArg == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
// call functions
|
|
||||||
if (pCheckSupportedFunc_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "function is nullptr.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
pRes = PyEval_CallObject(pCheckSupportedFunc_, pArg);
|
|
||||||
if (pRes == nullptr) {
|
|
||||||
PyErr_Print();
|
|
||||||
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc
|
|
||||||
<< "], function args: " << PyObjectToStr(pArg);
|
|
||||||
}
|
|
||||||
if (PyBool_Check(pRes)) {
|
|
||||||
ret = PyObject_IsTrue(pRes) != 0;
|
|
||||||
} else {
|
|
||||||
char *pstr = nullptr;
|
|
||||||
(void)PyArg_Parse(pRes, "s", &pstr);
|
|
||||||
std::string res_str = pstr;
|
|
||||||
if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
|
|
||||||
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str
|
|
||||||
<< ", function args: " << PyObjectToStr(pArg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
PyObject *TbePythonFuncs::TbeParallelCompiler() {
|
|
||||||
if (!Init()) {
|
|
||||||
MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return pTbeCompiler_;
|
|
||||||
}
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace mindspore
|
|
@ -1,45 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_PYTHON_FUNCS_H_
|
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_PYTHON_FUNCS_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
#include "pybind11/stl.h"
|
|
||||||
#include "utils/log_adapter.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace kernel {
|
|
||||||
class TbePythonFuncs {
|
|
||||||
public:
|
|
||||||
TbePythonFuncs() = default;
|
|
||||||
~TbePythonFuncs() = default;
|
|
||||||
static std::string OpSelectFormat(const nlohmann::json &kernel_json);
|
|
||||||
static bool CheckSupported(const nlohmann::json &kernel_json);
|
|
||||||
static PyObject *TbeParallelCompiler();
|
|
||||||
|
|
||||||
private:
|
|
||||||
static bool Init();
|
|
||||||
static std::string PyObjectToStr(_object *PyObj);
|
|
||||||
static PyObject *pCreateTbeParallelCompilerFunc_;
|
|
||||||
static PyObject *pTbeCompiler_;
|
|
||||||
static PyObject *pOpSelectFormatFunc_;
|
|
||||||
static PyObject *pCheckSupportedFunc_;
|
|
||||||
};
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_PYTHON_FUNCS_H_
|
|
@ -0,0 +1,105 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/session/kernel_build_client.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
|
||||||
|
std::string::size_type start = 0;
|
||||||
|
while ((start = (*dest).find(replace, start)) != std::string::npos) {
|
||||||
|
(*dest).replace(start, replace.size(), 1, new_char);
|
||||||
|
start++; // Replaced 1 charactor.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int KernelBuildClient::Start(const std::string &json) {
|
||||||
|
// Start compiling..
|
||||||
|
std::string res = SendRequest(kSTART);
|
||||||
|
if (res != kACK) {
|
||||||
|
MS_LOG(ERROR) << "START failed, res: " << res;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
// Send the json data.
|
||||||
|
res = SendRequest(json);
|
||||||
|
if (res == kFAILED) {
|
||||||
|
MS_LOG(ERROR) << "START send data failed, res: " << res;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
// Return task id.
|
||||||
|
return std::stoi(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool KernelBuildClient::Wait(int *task_id, std::string *task_result, std::string *pre_build_result) {
|
||||||
|
// Start waiting..
|
||||||
|
std::string res = SendRequest(kWAIT);
|
||||||
|
if (res != kACK) {
|
||||||
|
MS_LOG(ERROR) << "WAIT failed, res: " << res;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Request task id.
|
||||||
|
*task_id = std::stoi(SendRequest(kCONT));
|
||||||
|
// Requst task result.
|
||||||
|
*task_result = SendRequest(kCONT);
|
||||||
|
// Request prebuild result.
|
||||||
|
*pre_build_result = SendRequest(kCONT);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KernelBuildClient::Reset() {
|
||||||
|
// Start compiling..
|
||||||
|
std::string res = SendRequest(kRESET);
|
||||||
|
if (res != kACK) {
|
||||||
|
MS_LOG(EXCEPTION) << "RESET response is: " << res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string KernelBuildClient::SelectFormat(const std::string &json) {
|
||||||
|
// Start compiling..
|
||||||
|
std::string res = SendRequest(kFORMAT);
|
||||||
|
if (res != kACK) {
|
||||||
|
MS_LOG(ERROR) << "FORMAT failed, res: " << res;
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
// Send the json data.
|
||||||
|
res = SendRequest(json);
|
||||||
|
if (res == kERR) {
|
||||||
|
MS_LOG(ERROR) << "FORMAT send data failed, res: " << res;
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool KernelBuildClient::CheckSupported(const std::string &json) {
|
||||||
|
// Checking support..
|
||||||
|
std::string res = SendRequest(kSUPPORT);
|
||||||
|
if (res != kACK) {
|
||||||
|
MS_LOG(ERROR) << "SUPPORT failed, res: " << res;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Send the json data.
|
||||||
|
res = SendRequest(json);
|
||||||
|
if (res != kTRUE) {
|
||||||
|
MS_LOG(ERROR) << "SUPPORT send data failed, res: " << res;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,188 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <cstring>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "common/duplex_pipe.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
void ReplaceStr(std::string *dest, const std::string &replace, char new_char);
|
||||||
|
class KernelBuildClient {
|
||||||
|
public:
|
||||||
|
// Server configure
|
||||||
|
constexpr inline static auto kEnv = "python";
|
||||||
|
constexpr inline static auto kGetPathScript =
|
||||||
|
"-c "
|
||||||
|
"\""
|
||||||
|
"import pkgutil;"
|
||||||
|
"path = pkgutil"
|
||||||
|
".get_loader(\\\"mindspore._extends.remote.kernel_build_server\\\")" // Server module name
|
||||||
|
".get_filename();"
|
||||||
|
"print('[~]' + path)"
|
||||||
|
"\"";
|
||||||
|
|
||||||
|
// Receive the response from server
|
||||||
|
constexpr inline static auto kACK = "ACK";
|
||||||
|
constexpr inline static auto kERR = "ERR";
|
||||||
|
constexpr inline static auto kFAILED = "-1";
|
||||||
|
// Send Finish request to server
|
||||||
|
constexpr inline static auto kFIN = "FIN";
|
||||||
|
|
||||||
|
// Send building request to server
|
||||||
|
constexpr inline static auto kSTART = "START";
|
||||||
|
constexpr inline static auto kWAIT = "WAIT";
|
||||||
|
constexpr inline static auto kCONT = "CONT";
|
||||||
|
constexpr inline static auto kSUCCESS = "Success";
|
||||||
|
constexpr inline static auto kRESET = "RESET";
|
||||||
|
|
||||||
|
// Send server info. query to server
|
||||||
|
constexpr inline static auto kFORMAT = "FORMAT";
|
||||||
|
constexpr inline static auto kSUPPORT = "SUPPORT";
|
||||||
|
constexpr inline static auto kTRUE = "True";
|
||||||
|
|
||||||
|
// Revert \n, \r, [space].
|
||||||
|
constexpr inline static auto kLF = "[LF]";
|
||||||
|
constexpr inline static auto kCR = "[CR]";
|
||||||
|
constexpr inline static auto kSP = "[SP]";
|
||||||
|
|
||||||
|
// The TAG as prefix of real command from remote.
|
||||||
|
constexpr inline static auto kTAG = "[~]";
|
||||||
|
|
||||||
|
constexpr inline static int kBufferSize = 4096;
|
||||||
|
constexpr inline static unsigned int kTimeOutSeconds = 20;
|
||||||
|
|
||||||
|
static KernelBuildClient &Instance() {
|
||||||
|
static KernelBuildClient instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
std::string GetScriptPath() {
|
||||||
|
std::string cmd = kEnv;
|
||||||
|
(void)cmd.append(1, ' ').append(kGetPathScript);
|
||||||
|
FILE *fpipe = popen(cmd.c_str(), "r");
|
||||||
|
if (fpipe == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "popen failed, " << strerror(errno) << "(" << errno << ")";
|
||||||
|
}
|
||||||
|
bool start = false;
|
||||||
|
std::string result;
|
||||||
|
char buf[kBufferSize];
|
||||||
|
while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) {
|
||||||
|
if (std::strncmp(buf, kTAG, std::strlen(kTAG)) == 0) {
|
||||||
|
start = true;
|
||||||
|
}
|
||||||
|
// Filter with 'kTAG' and '\n'
|
||||||
|
if (start) {
|
||||||
|
auto size = std::strlen(buf);
|
||||||
|
bool line_end = buf[size - 1] == '\n';
|
||||||
|
result.append(buf, line_end ? size - 1 : size);
|
||||||
|
if (line_end) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pclose(fpipe);
|
||||||
|
const std::string py_suffix = ".py";
|
||||||
|
if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) {
|
||||||
|
MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}";
|
||||||
|
}
|
||||||
|
result = result.substr(strlen(kTAG));
|
||||||
|
MS_LOG(DEBUG) << "result: " << result;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Open() {
|
||||||
|
if (!init_) {
|
||||||
|
// Exception's thrown if open failed
|
||||||
|
if (dp_->Open({kEnv, GetScriptPath()}, true) != -1) {
|
||||||
|
dp_->SetTimeOutSeconds(kTimeOutSeconds);
|
||||||
|
dp_->SetTimeOutCallback([this]() { SendRequest(kFIN); });
|
||||||
|
init_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Close() {
|
||||||
|
if (init_) {
|
||||||
|
dp_->Close();
|
||||||
|
init_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a request and fetch its response
|
||||||
|
std::string SendRequest(std::string data) {
|
||||||
|
Request(data);
|
||||||
|
return Response();
|
||||||
|
}
|
||||||
|
void Request(std::string req) {
|
||||||
|
if (!init_) {
|
||||||
|
MS_LOG(EXCEPTION) << "Try to send request before Open()";
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "\t[" << req << "]";
|
||||||
|
*dp_ << req;
|
||||||
|
}
|
||||||
|
std::string Response() {
|
||||||
|
if (!init_) {
|
||||||
|
MS_LOG(EXCEPTION) << "Try to get response before Open()";
|
||||||
|
}
|
||||||
|
std::string res;
|
||||||
|
*dp_ >> res;
|
||||||
|
// Filter out the interference
|
||||||
|
auto start = res.find(kTAG);
|
||||||
|
if (start == std::string::npos) {
|
||||||
|
MS_LOG(EXCEPTION) << "Response seems incorrect, res: " << res;
|
||||||
|
}
|
||||||
|
res = res.substr(start + std::strlen(kTAG), res.size() - start);
|
||||||
|
// Revert the line feed and space
|
||||||
|
if (res != kSUCCESS && res != kACK && res != kERR && res != kTRUE) {
|
||||||
|
ReplaceStr(&res, kLF, '\n');
|
||||||
|
ReplaceStr(&res, kSP, ' ');
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "\t[" << res << "]";
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Before building.
|
||||||
|
std::string SelectFormat(const std::string &json);
|
||||||
|
bool CheckSupported(const std::string &json);
|
||||||
|
|
||||||
|
// Run building.
|
||||||
|
int Start(const std::string &json);
|
||||||
|
bool Wait(int *task_id, std::string *task_result, std::string *pre_build_result);
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
KernelBuildClient(const KernelBuildClient &) = delete;
|
||||||
|
KernelBuildClient &operator=(const KernelBuildClient &) = delete;
|
||||||
|
|
||||||
|
KernelBuildClient(KernelBuildClient &&) = delete;
|
||||||
|
KernelBuildClient &operator=(KernelBuildClient &&) = delete;
|
||||||
|
|
||||||
|
private:
|
||||||
|
KernelBuildClient() : init_(false), dp_(std::make_shared<DuplexPipe>()) { Open(); }
|
||||||
|
~KernelBuildClient() { Close(); }
|
||||||
|
|
||||||
|
bool init_;
|
||||||
|
std::shared_ptr<DuplexPipe> dp_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
|
@ -1,3 +1,16 @@
|
|||||||
file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
|
file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
"trans.cc"
|
||||||
|
"utils.cc"
|
||||||
|
"duplex_pipe_win.cc"
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
"trans.cc"
|
||||||
|
"utils.cc"
|
||||||
|
"duplex_pipe.cc"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
set_property(SOURCE ${_COMMON_ALL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_COMMON)
|
set_property(SOURCE ${_COMMON_ALL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_COMMON)
|
||||||
add_library(_mindspore_common_obj OBJECT ${_COMMON_ALL_SRC_FILES})
|
add_library(_mindspore_common_obj OBJECT ${_COMMON_ALL_SRC_FILES})
|
||||||
|
@ -0,0 +1,160 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common/duplex_pipe.h"
|
||||||
|
|
||||||
|
#include <signal.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
int DuplexPipe::Open(std::initializer_list<std::string> arg_list, bool append_fds) {
|
||||||
|
if (pipe(fd1_) == -1) {
|
||||||
|
DP_EXCEPTION << "pipe 1 failed, " << strerror(errno) << "(" << errno << ")";
|
||||||
|
}
|
||||||
|
if (pipe(fd2_) == -1) {
|
||||||
|
close(fd1_[0]);
|
||||||
|
close(fd1_[1]);
|
||||||
|
DP_EXCEPTION << "pipe 2 failed, " << strerror(errno) << "(" << errno << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
pid_ = fork();
|
||||||
|
if (pid_ < 0) {
|
||||||
|
close(fd1_[0]);
|
||||||
|
close(fd1_[1]);
|
||||||
|
close(fd2_[0]);
|
||||||
|
close(fd2_[1]);
|
||||||
|
DP_EXCEPTION << "fork failed, " << strerror(errno) << "(" << errno << ")";
|
||||||
|
} else if (pid_ == 0) { // Remote process
|
||||||
|
DP_INFO << "Remote process, pid: " << getpid() << ", " << fd1_[0] << "/" << fd2_[1];
|
||||||
|
remote_stdout_ = dup(STDOUT_FILENO);
|
||||||
|
remote_stdin_ = dup(STDIN_FILENO);
|
||||||
|
remote_stderr_ = dup(STDERR_FILENO);
|
||||||
|
close(fd1_[1]);
|
||||||
|
close(fd2_[0]);
|
||||||
|
if (!append_fds) {
|
||||||
|
dup2(fd1_[0], STDIN_FILENO);
|
||||||
|
dup2(fd2_[1], STDOUT_FILENO);
|
||||||
|
}
|
||||||
|
std::vector<const char *> args;
|
||||||
|
std::transform(arg_list.begin(), arg_list.end(), std::back_inserter(args),
|
||||||
|
[](const std::string &arg) -> const char * { return arg.c_str(); });
|
||||||
|
if (append_fds) {
|
||||||
|
std::string fd10 = std::to_string(fd1_[0]).c_str();
|
||||||
|
args.emplace_back(fd10.c_str());
|
||||||
|
std::string fd21 = std::to_string(fd2_[1]).c_str();
|
||||||
|
args.emplace_back(fd21.c_str());
|
||||||
|
}
|
||||||
|
args.emplace_back(nullptr);
|
||||||
|
if (execvp(args[0], const_cast<char *const *>(&args[0])) == -1) {
|
||||||
|
DP_EXCEPTION << "execute " << args[0] << " failed, " << strerror(errno) << "(" << errno << ")";
|
||||||
|
}
|
||||||
|
} else { // Local process
|
||||||
|
DP_INFO << "Local process, id: " << getpid() << ", " << fd2_[0] << "/" << fd1_[1];
|
||||||
|
local_stdout_ = dup(STDOUT_FILENO);
|
||||||
|
local_stdin_ = dup(STDIN_FILENO);
|
||||||
|
local_stderr_ = dup(STDERR_FILENO);
|
||||||
|
close(fd1_[0]);
|
||||||
|
close(fd2_[1]);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DuplexPipe::Write(const std::string &buf, bool flush) {
|
||||||
|
// Write the string into pipe
|
||||||
|
if (write(fd1_[1], buf.data(), buf.size()) == -1) {
|
||||||
|
DP_ERROR << "write failed, error: " << strerror(errno) << "(" << errno << ")";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (flush) {
|
||||||
|
// Flush into the pipe
|
||||||
|
if (write(fd1_[1], "\n", 1) == -1) {
|
||||||
|
DP_ERROR << "write failed, error: " << strerror(errno) << "(" << errno << ")";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DP_DEBUG << "<< [" << buf << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DuplexPipe::Read() {
|
||||||
|
// Read the string from pipe
|
||||||
|
std::string buf;
|
||||||
|
ssize_t size;
|
||||||
|
// MAYBE BLOCKED
|
||||||
|
// Read one line or multiple lines
|
||||||
|
while (SetTimeOut(), (size = read(fd2_[0], c_buf_, kBufferSize)) > 0) { // Till reading something
|
||||||
|
CancelTimeOut();
|
||||||
|
DP_DEBUG << ">> [" << c_buf_ << "]";
|
||||||
|
bool line_end = c_buf_[size - 1] == '\n';
|
||||||
|
buf.append(c_buf_, line_end ? size - 1 : size); // Copy without the last '\n'
|
||||||
|
if (line_end) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DuplexPipe::WriteWithStdout(const std::string &buf, bool flush) {
|
||||||
|
dup2(fd1_[1], STDOUT_FILENO);
|
||||||
|
// Write the string into pipe
|
||||||
|
std::cout << buf;
|
||||||
|
if (flush) {
|
||||||
|
// Flush into the pipe
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
dup2(local_stdout_, STDOUT_FILENO);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DuplexPipe::ReadWithStdin() {
|
||||||
|
std::string buf;
|
||||||
|
dup2(fd2_[0], STDIN_FILENO);
|
||||||
|
// Maybe blocked
|
||||||
|
SetTimeOut();
|
||||||
|
std::getline(std::cin, buf); // Not use 'std::cin >>' to include space
|
||||||
|
CancelTimeOut();
|
||||||
|
dup2(local_stdin_, STDIN_FILENO);
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
DuplexPipe &DuplexPipe::operator<<(const std::string &buf) {
|
||||||
|
Write(buf);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
DuplexPipe &DuplexPipe::operator>>(std::string &buf) {
|
||||||
|
buf = Read();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DuplexPipe::Close() {
|
||||||
|
close(fd1_[0]);
|
||||||
|
close(fd1_[1]);
|
||||||
|
close(fd2_[0]);
|
||||||
|
close(fd2_[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DuplexPipe::Alarm::Set(std::shared_ptr<DuplexPipe> dp, unsigned int interval_secs) {
|
||||||
|
dp_ = dp;
|
||||||
|
signal(SIGALRM, SigHandler);
|
||||||
|
alarm(interval_secs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DuplexPipe::Alarm::Cancel() {
|
||||||
|
alarm(0);
|
||||||
|
dp_.reset();
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,123 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_COMMON_DUPLEX_PIPE_H_
|
||||||
|
#define MINDSPORE_CCSRC_COMMON_DUPLEX_PIPE_H_
|
||||||
|
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#define DP_DEBUG MS_LOG(DEBUG) << "[DuplexPipe] "
|
||||||
|
#define DP_INFO MS_LOG(INFO) << "[DuplexPipe] "
|
||||||
|
#define DP_ERROR MS_LOG(ERROR) << "[DuplexPipe] "
|
||||||
|
#define DP_EXCEPTION MS_LOG(EXCEPTION) << "[DuplexPipe] "
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
// A tool to run a command as child process and build a duplex pipe between them.
|
||||||
|
// Similar to 'popen()', but use duplex not simplex pipe, more like 'socketpair'.
|
||||||
|
class DuplexPipe : public std::enable_shared_from_this<mindspore::DuplexPipe> {
|
||||||
|
public:
|
||||||
|
constexpr inline static int kBufferSize = 4096;
|
||||||
|
constexpr inline static unsigned int kTimeOutSeconds = 5;
|
||||||
|
|
||||||
|
DuplexPipe() = default;
|
||||||
|
~DuplexPipe() = default;
|
||||||
|
|
||||||
|
// Create a subprocess and open a duplex pipe between local and remote
|
||||||
|
int Open(std::initializer_list<std::string> arg_list, bool append_fds = false);
|
||||||
|
void Close();
|
||||||
|
void SetTimeOutSeconds(unsigned int secs) { time_out_secs_ = secs; }
|
||||||
|
void SetTimeOutCallback(const std::function<void()> &cb) {
|
||||||
|
has_time_out_callback_ = true;
|
||||||
|
time_out_callback_ = cb;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the 'buf' to remote stdin
|
||||||
|
void Write(const std::string &buf, bool flush = true);
|
||||||
|
// Read from remote stdout/stderr into 'c_buf_'
|
||||||
|
std::string Read();
|
||||||
|
|
||||||
|
void WriteWithStdout(const std::string &buf, bool flush);
|
||||||
|
std::string ReadWithStdin();
|
||||||
|
|
||||||
|
DuplexPipe &operator<<(const std::string &buf);
|
||||||
|
DuplexPipe &operator>>(std::string &buf);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void SetTimeOut() { alarm_.Set(shared_from_this(), time_out_secs_); }
|
||||||
|
void CancelTimeOut() { alarm_.Cancel(); }
|
||||||
|
void TimeOut() {
|
||||||
|
if (has_time_out_callback_) {
|
||||||
|
time_out_callback_();
|
||||||
|
}
|
||||||
|
Close();
|
||||||
|
DP_EXCEPTION << "Time out when read from pipe";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subprocess id in parent process,
|
||||||
|
// otherwise zero in child process.
|
||||||
|
pid_t pid_;
|
||||||
|
|
||||||
|
// Pipe: { Local:fd1_[1] --> Remote:fd1_[0] }
|
||||||
|
// Remote:fd1_[0] would be redirected by subprocess's stdin.
|
||||||
|
// Local:fd1_[1] would be used by 'Write()' as output.
|
||||||
|
int fd1_[2];
|
||||||
|
|
||||||
|
// Pipe: { Remote:fd2_[1] --> Local:fd2_[0] }
|
||||||
|
// Remote:fd2_[1] would be redirected by subprocess's stdout.
|
||||||
|
// Local:fd2_[0] would be used by 'Read()' as input.
|
||||||
|
int fd2_[2];
|
||||||
|
|
||||||
|
// // Used and returned by 'Read()'.
|
||||||
|
// std::string buf_;
|
||||||
|
char c_buf_[kBufferSize];
|
||||||
|
|
||||||
|
int local_stdin_;
|
||||||
|
int local_stdout_;
|
||||||
|
int local_stderr_;
|
||||||
|
int remote_stdin_;
|
||||||
|
int remote_stdout_;
|
||||||
|
int remote_stderr_;
|
||||||
|
|
||||||
|
class Alarm {
|
||||||
|
public:
|
||||||
|
Alarm() = default;
|
||||||
|
~Alarm() = default;
|
||||||
|
|
||||||
|
void Set(std::shared_ptr<DuplexPipe> dp, unsigned int interval_secs);
|
||||||
|
void Cancel();
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void SigHandler(int sig) {
|
||||||
|
DP_INFO << "Signal: " << sig;
|
||||||
|
dp_->TimeOut();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline static std::shared_ptr<DuplexPipe> dp_;
|
||||||
|
};
|
||||||
|
|
||||||
|
unsigned int time_out_secs_ = kTimeOutSeconds;
|
||||||
|
bool has_time_out_callback_ = false;
|
||||||
|
std::function<void()> time_out_callback_;
|
||||||
|
Alarm alarm_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_COMMON_DUPLEX_PIPE_H_
|
@ -0,0 +1,48 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common/duplex_pipe.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
int DuplexPipe::Open(std::initializer_list<std::string> arg_list, bool append_fds) {
|
||||||
|
DP_EXCEPTION << "Not support for Windows by now.";
|
||||||
|
}
|
||||||
|
|
||||||
|
void DuplexPipe::Write(const std::string &buf, bool flush) { DP_EXCEPTION << "Not support for Windows by now."; }
|
||||||
|
|
||||||
|
std::string DuplexPipe::Read() { DP_EXCEPTION << "Not support for Windows by now."; }
|
||||||
|
|
||||||
|
void DuplexPipe::WriteWithStdout(const std::string &buf, bool flush) {
|
||||||
|
DP_EXCEPTION << "Not support for Windows by now.";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DuplexPipe::ReadWithStdin() { DP_EXCEPTION << "Not support for Windows by now."; }
|
||||||
|
|
||||||
|
DuplexPipe &DuplexPipe::operator<<(const std::string &buf) { DP_EXCEPTION << "Not support for Windows by now."; }
|
||||||
|
|
||||||
|
DuplexPipe &DuplexPipe::operator>>(std::string &buf) { DP_EXCEPTION << "Not support for Windows by now."; }
|
||||||
|
|
||||||
|
void DuplexPipe::Close() { DP_EXCEPTION << "Not support for Windows by now."; }
|
||||||
|
|
||||||
|
void DuplexPipe::Alarm::Set(std::shared_ptr<DuplexPipe> dp, unsigned int interval_secs) {
|
||||||
|
DP_EXCEPTION << "Not support for Windows by now.";
|
||||||
|
}
|
||||||
|
|
||||||
|
void DuplexPipe::Alarm::Cancel() { DP_EXCEPTION << "Not support for Windows by now."; }
|
||||||
|
} // namespace mindspore
|
Loading…
Reference in new issue