|
|
|
@ -92,6 +92,7 @@ void Bucket::CalculateMean() {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parallel_context);
|
|
|
|
|
auto grad_mean = parallel_context->gradients_mean();
|
|
|
|
|
if (!grad_mean) {
|
|
|
|
|
UpdateTensorOutputAddr(ar_output_addr_);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (launch_mul_ == nullptr) {
|
|
|
|
@ -102,12 +103,16 @@ void Bucket::CalculateMean() {
|
|
|
|
|
launch_mul_->SetInputAddr(ar_output_addr_);
|
|
|
|
|
// launch mean
|
|
|
|
|
launch_mul_->LaunchOpKernel();
|
|
|
|
|
// store output tensor addr
|
|
|
|
|
// store tensor output addr
|
|
|
|
|
auto launch_output = launch_mul_->GetKernelOutputAddr();
|
|
|
|
|
if (launch_output.size() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "launch mul outputs should have one output";
|
|
|
|
|
MS_LOG(EXCEPTION) << "launch mul outputs should have one output";
|
|
|
|
|
}
|
|
|
|
|
uint8_t *tensor_output = launch_output[0];
|
|
|
|
|
UpdateTensorOutputAddr(launch_output[0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Bucket::UpdateTensorOutputAddr(uint8_t *addr) {
|
|
|
|
|
uint8_t *tensor_output = addr;
|
|
|
|
|
for (size_t i = 0; i < bucket_size_; ++i) {
|
|
|
|
|
new_tensor_output_addrs_.emplace_back(tensor_output);
|
|
|
|
|
tensor_output += align_size_list_[i];
|
|
|
|
|