|
|
|
@ -335,15 +335,16 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
|
|
|
|
|
AddressPtrList kernel_inputs = {input_address};
|
|
|
|
|
AddressPtrList kernel_outputs = {output_address};
|
|
|
|
|
AddressPtrList kernel_workspaces;
|
|
|
|
|
std::vector<DeviceAddressPtr> workspace_address_ptr(workspace_size_list.size());
|
|
|
|
|
if (!workspace_size_list.empty()) {
|
|
|
|
|
for (size_t i = 0; i < workspace_size_list.size(); ++i) {
|
|
|
|
|
auto workspace_size = GetCommonAlignSize(workspace_size_list[i]);
|
|
|
|
|
auto workspace_address_ptr = AssignLaunchMemory(workspace_size, "", kTypeUnknown);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(workspace_address_ptr);
|
|
|
|
|
workspace_address_ptr[i] = AssignLaunchMemory(workspace_size, "", kTypeUnknown);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(workspace_address_ptr[i]);
|
|
|
|
|
auto workspace_address = std::make_shared<kernel::Address>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(workspace_address);
|
|
|
|
|
workspace_address->addr = workspace_address_ptr->GetMutablePtr();
|
|
|
|
|
workspace_address->size = workspace_address_ptr->GetSize();
|
|
|
|
|
workspace_address->addr = workspace_address_ptr[i]->GetMutablePtr();
|
|
|
|
|
workspace_address->size = workspace_address_ptr[i]->GetSize();
|
|
|
|
|
kernel_workspaces.push_back(workspace_address);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -357,6 +358,7 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
|
|
|
|
|
if (!ret) {
|
|
|
|
|
MS_LOG(ERROR) << "Launch kernel failed.";
|
|
|
|
|
}
|
|
|
|
|
SyncStream();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const {
|
|
|
|
@ -421,7 +423,6 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
|
|
|
|
|
auto workspace_size_list = GetWorkspaceSizeList(kernel_json);
|
|
|
|
|
// launch
|
|
|
|
|
LaunchTransData(kernel_mod_ptr, output_address->GetMutablePtr(), output_address->GetSize(), workspace_size_list);
|
|
|
|
|
SyncStream();
|
|
|
|
|
if (type_id_ == type) {
|
|
|
|
|
SyncMemory(host_ptr, output_address->GetPtr(), host_size, RT_MEMCPY_DEVICE_TO_HOST);
|
|
|
|
|
} else {
|
|
|
|
|