|
|
|
@ -48,6 +48,13 @@ bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|
|
|
|
MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (outputs[0]->size < inputs[0]->size) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
|
|
|
|
|
}
|
|
|
|
|
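// Case: input x -> memcpy_async -> AllReduce. A destination larger than the source is
// only warned about below and the copy still proceeds (presumably the AllReduce buffer
// can be bigger than x -- TODO confirm).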
|
|
|
|
|
if (outputs[0]->size > inputs[0]->size) {
|
|
|
|
|
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
|
|
|
|
|
}
|
|
|
|
|
rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
|
|
|
|
|
RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
|
|
|
|
|
if (status != RT_ERROR_NONE) {
|
|
|
|
@ -70,7 +77,7 @@ void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) {
|
|
|
|
|
if (input_size != 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
|
|
|
|
|
}
|
|
|
|
|
input_type_id_ = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, 0);
|
|
|
|
|
input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
|
|
|
|
@ -102,6 +109,14 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr
|
|
|
|
|
MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (outputs[0]->size < inputs[0]->size) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
|
|
|
|
|
}
|
|
|
|
|
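// Case: input x -> memcpy_async -> AllReduce. A destination larger than the source is
// only warned about below and the task is still generated (presumably the AllReduce
// buffer can be bigger than x -- TODO confirm).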
|
|
|
|
|
if (outputs[0]->size > inputs[0]->size) {
|
|
|
|
|
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
stream_id_ = stream_id;
|
|
|
|
|
std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr = std::make_shared<MemcpyAsyncTaskInfo>(
|
|
|
|
|
stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE);
|
|
|
|
|