insert trans_data to reduce time in print process

pull/3174/head
lvchangquan 5 years ago
parent 4e0cfafcf9
commit 7b48a122dd

@ -20,9 +20,11 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <nlohmann/json.hpp>
#include "runtime/device/device_address.h" #include "runtime/device/device_address.h"
#include "runtime/device/ascend/ascend_memory_pool.h" #include "runtime/device/ascend/ascend_memory_pool.h"
#include "ir/dtype.h" #include "ir/dtype.h"
#include "backend/kernel_compiler/kernel.h"
namespace mindspore { namespace mindspore {
#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER
@ -53,7 +55,16 @@ class AscendDeviceAddress : public DeviceAddress {
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const; bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
const void *host_ptr) const; const void *host_ptr) const;
bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape,
const std::vector<size_t> &device_shape, size_t size,
mindspore::TypeId type, void *host_ptr) const;
void SyncStream() const; void SyncStream() const;
void LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr, size_t output_size,
const std::vector<size_t> &workspace_size_list) const;
std::vector<size_t> GetDeviceShape(std::vector<size_t> *host_shape) const;
std::vector<size_t> GetWorkspaceSizeList(const nlohmann::json &kernel_json) const;
kernel::KernelModPtr CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const;
}; };
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>; using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
} // namespace ascend } // namespace ascend

@ -757,6 +757,18 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
} }
bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
AddressPtrList kernel_outputs,
AddressPtrList kernel_workspaces) const {
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
}
return true;
}
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
bool KernelRuntime::SetDumpConf() { bool KernelRuntime::SetDumpConf() {
dump_conf_ptr_ = std::make_shared<Dump>(); dump_conf_ptr_ = std::make_shared<Dump>();

@ -61,6 +61,8 @@ class KernelRuntime {
virtual bool RunTask(const session::KernelGraph *graph); virtual bool RunTask(const session::KernelGraph *graph);
virtual bool GenTask(const session::KernelGraph *graph); virtual bool GenTask(const session::KernelGraph *graph);
bool LaunchKernel(const session::KernelGraph *graph); bool LaunchKernel(const session::KernelGraph *graph);
bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
AddressPtrList kernel_outputs, AddressPtrList kernel_workspaces) const;
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id); virtual void ClearGraphRuntimeResource(uint32_t graph_id);

Loading…
Cancel
Save