|
|
|
@ -20,9 +20,11 @@
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <nlohmann/json.hpp>
|
|
|
|
|
#include "runtime/device/device_address.h"
|
|
|
|
|
#include "runtime/device/ascend/ascend_memory_pool.h"
|
|
|
|
|
#include "ir/dtype.h"
|
|
|
|
|
#include "backend/kernel_compiler/kernel.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
#ifdef ENABLE_DEBUGGER
|
|
|
|
@ -53,7 +55,16 @@ class AscendDeviceAddress : public DeviceAddress {
|
|
|
|
|
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
|
|
|
|
|
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
|
|
|
|
|
const void *host_ptr) const;
|
|
|
|
|
bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape,
|
|
|
|
|
const std::vector<size_t> &device_shape, size_t size,
|
|
|
|
|
mindspore::TypeId type, void *host_ptr) const;
|
|
|
|
|
void SyncStream() const;
|
|
|
|
|
|
|
|
|
|
void LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr, size_t output_size,
|
|
|
|
|
const std::vector<size_t> &workspace_size_list) const;
|
|
|
|
|
std::vector<size_t> GetDeviceShape(std::vector<size_t> *host_shape) const;
|
|
|
|
|
std::vector<size_t> GetWorkspaceSizeList(const nlohmann::json &kernel_json) const;
|
|
|
|
|
kernel::KernelModPtr CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const;
|
|
|
|
|
};
|
|
|
|
|
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
|
|
|
|
|
} // namespace ascend
|
|
|
|
|