!14409 fix a bug in launch allreduce

From: @lvchangquan
Reviewed-by: @chujinjin,@jjfeing
Signed-off-by: @jjfeing
pull/14409/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 08189a1428

@@ -92,6 +92,7 @@ void Bucket::CalculateMean() {
MS_EXCEPTION_IF_NULL(parallel_context); MS_EXCEPTION_IF_NULL(parallel_context);
auto grad_mean = parallel_context->gradients_mean(); auto grad_mean = parallel_context->gradients_mean();
if (!grad_mean) { if (!grad_mean) {
UpdateTensorOutputAddr(ar_output_addr_);
return; return;
} }
if (launch_mul_ == nullptr) { if (launch_mul_ == nullptr) {
@@ -102,12 +103,16 @@ void Bucket::CalculateMean() {
launch_mul_->SetInputAddr(ar_output_addr_); launch_mul_->SetInputAddr(ar_output_addr_);
// launch mean // launch mean
launch_mul_->LaunchOpKernel(); launch_mul_->LaunchOpKernel();
// store output tensor addr // store tensor output addr
auto launch_output = launch_mul_->GetKernelOutputAddr(); auto launch_output = launch_mul_->GetKernelOutputAddr();
if (launch_output.size() != 1) { if (launch_output.size() != 1) {
MS_LOG(ERROR) << "launch mul outputs should have one output"; MS_LOG(EXCEPTION) << "launch mul outputs should have one output";
} }
uint8_t *tensor_output = launch_output[0]; UpdateTensorOutputAddr(launch_output[0]);
}
void Bucket::UpdateTensorOutputAddr(uint8_t *addr) {
uint8_t *tensor_output = addr;
for (size_t i = 0; i < bucket_size_; ++i) { for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output); new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list_[i]; tensor_output += align_size_list_[i];

@@ -84,6 +84,7 @@ class Bucket {
virtual void FreeAllDeviceMem() = 0; virtual void FreeAllDeviceMem() = 0;
virtual void FreeDeviceMem(void *dev_ptr) = 0; virtual void FreeDeviceMem(void *dev_ptr) = 0;
virtual void CopyTensorToContiguousMemory() = 0; virtual void CopyTensorToContiguousMemory() = 0;
void UpdateTensorOutputAddr(uint8_t *addr);
void LazyDeleteOldAddr(); void LazyDeleteOldAddr();
}; };
} // namespace mindspore::device } // namespace mindspore::device

Loading…
Cancel
Save