|
|
|
@ -14,12 +14,66 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "kernel/cpu/sparse_apply_adam_cpu_kernel.h"
|
|
|
|
|
#include "kernel/common_utils.h"
|
|
|
|
|
#include "device/cpu/cpu_device_address.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
namespace {
|
|
|
|
|
// NOTE(review): presumably the number of input tensors this kernel expects; the use site is
// not visible in this part of the file -- confirm against InitKernel.
constexpr size_t kSparseApplyAdamInputSize = 11;
|
|
|
|
|
|
|
|
|
|
// Applies the sparse gradient contribution to the moment buffers for the unique-index
// range [start, end).
//
// For every position `row` in that range, the target row `indices_[row]` of `m_`/`v_` is
// updated from the matching row of `value_`:
//   m[j] += (1 - beta1) * g
//   v[j] += (1 - beta2) * g * g
// and, when `use_nesterov_` is set, `m_t[j]` is overwritten with
//   m[j] * beta1 + (1 - beta1) * g.
//
// Raises an exception (via MS_LOG(EXCEPTION)) when an index is negative or not smaller than
// `var_first_dim_size_`, and via MS_EXCEPTION_IF_NULL when `input_params` is null.
void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) {
  MS_EXCEPTION_IF_NULL(input_params);
  const auto &sparse = input_params->sparse_grad_;
  auto first_moment = input_params->m_;
  auto nesterov_moment = input_params->m_t_;
  auto second_moment = input_params->v_;
  const auto beta1 = input_params->beta1_;
  const auto beta2 = input_params->beta2_;
  const auto use_nesterov = input_params->use_nesterov_;
  const auto row_count = input_params->var_first_dim_size_;
  const auto row_length = input_params->var_outer_dim_size_;

  for (size_t row = start; row < end; ++row) {
    const int index = sparse.indices_[row];
    const bool in_range = index >= 0 && IntToSize(index) < row_count;
    if (!in_range) {
      MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
    }
    // Offsets of the first element of the destination row and of the gradient row.
    const size_t dst_base = row_length * static_cast<size_t>(index);
    const size_t src_base = row_length * row;
    for (size_t col = 0; col < row_length; ++col) {
      const size_t j = dst_base + col;
      const auto summed_grad = sparse.value_[src_base + col];
      first_moment[j] += (1 - beta1) * summed_grad;
      second_moment[j] += (1 - beta2) * summed_grad * summed_grad;
      if (use_nesterov) {
        nesterov_moment[j] = first_moment[j] * beta1 + (1 - beta1) * summed_grad;
      }
    }
  }
}
|
|
|
|
|
|
|
|
|
|
// Decays both moment buffers in place over the flat element range [start, end):
//   m[i] *= beta1
//   v[i] *= beta2
// Raises via MS_EXCEPTION_IF_NULL when `input_params` is null.
void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_t end) {
  MS_EXCEPTION_IF_NULL(input_params);
  const auto first_decay = input_params->beta1_;
  const auto second_decay = input_params->beta2_;
  auto first_moment = input_params->m_;
  auto second_moment = input_params->v_;
  for (size_t pos = start; pos < end; ++pos) {
    first_moment[pos] *= first_decay;
    second_moment[pos] *= second_decay;
  }
}
|
|
|
|
|
|
|
|
|
|
// Applies the weight update in place over the flat element range [start, end):
//   var[i] -= lr * m[i] / (sqrt(v[i]) + epsilon)
// Reads `var_`, `m_`, `v_`, `lr_` and `epsilon_` from `input_params`.
// Raises via MS_EXCEPTION_IF_NULL when `input_params` is null.
void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) {
  MS_EXCEPTION_IF_NULL(input_params);
  const auto learning_rate = input_params->lr_;
  const auto eps = input_params->epsilon_;
  auto weights = input_params->var_;
  auto first_moment = input_params->m_;
  auto second_moment = input_params->v_;
  for (size_t pos = start; pos < end; ++pos) {
    // Same evaluation order as the single-expression form: (lr * m) / (sqrt(v) + eps).
    const auto scaled_moment = learning_rate * first_moment[pos];
    const auto denominator = std::sqrt(second_moment[pos]) + eps;
    weights[pos] -= scaled_moment / denominator;
  }
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
|
|
|
|
@ -64,29 +118,6 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t,
|
|
|
|
|
float *v, float beta1, float beta2) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(m);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(m_t);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(v);
|
|
|
|
|
for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) {
|
|
|
|
|
int index = unique_sparse_grad.indices_[i];
|
|
|
|
|
if (index < 0 || IntToSize(index) >= var_first_dim_size_) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
|
|
|
|
|
}
|
|
|
|
|
size_t start_index = var_outer_dim_size_ * index;
|
|
|
|
|
size_t end_index = start_index + var_outer_dim_size_;
|
|
|
|
|
for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) {
|
|
|
|
|
auto summed_grad = unique_sparse_grad.value_[k];
|
|
|
|
|
m[j] += (1 - beta1) * summed_grad;
|
|
|
|
|
v[j] += (1 - beta2) * summed_grad * summed_grad;
|
|
|
|
|
if (use_nesterov_) {
|
|
|
|
|
m_t[j] = m[j] * beta1 + (1 - beta1) * summed_grad;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
const std::vector<kernel::AddressPtr> &workspace,
|
|
|
|
|
const std::vector<kernel::AddressPtr> & /*outputs*/) {
|
|
|
|
@ -115,21 +146,31 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
|
|
|
|
|
var_outer_dim_size_);
|
|
|
|
|
size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_;
|
|
|
|
|
// Update momentum
|
|
|
|
|
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
|
|
|
|
|
for (size_t i = 0; i < total_dim_size; ++i) {
|
|
|
|
|
m[i] *= beta1;
|
|
|
|
|
v[i] *= beta2;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MultiThreadComputeParams input_params;
|
|
|
|
|
input_params.m_ = m;
|
|
|
|
|
input_params.v_ = v;
|
|
|
|
|
input_params.beta1_ = beta1;
|
|
|
|
|
input_params.beta2_ = beta2;
|
|
|
|
|
const size_t kThreadNum = 16;
|
|
|
|
|
MultiThreadCompute(ComputeMomentum, &input_params, kThreadNum, total_dim_size);
|
|
|
|
|
|
|
|
|
|
std::vector<float> m_t(m, m + total_dim_size);
|
|
|
|
|
UpdateSparseMomentum(unique_sparse_grad, m, m_t.data(), v, beta1, beta2);
|
|
|
|
|
// Update weight
|
|
|
|
|
input_params.m_t_ = m_t.data();
|
|
|
|
|
input_params.use_nesterov_ = use_nesterov_;
|
|
|
|
|
input_params.sparse_grad_ = unique_sparse_grad;
|
|
|
|
|
input_params.var_first_dim_size_ = var_first_dim_size_;
|
|
|
|
|
input_params.var_outer_dim_size_ = var_outer_dim_size_;
|
|
|
|
|
MultiThreadCompute(ComputeAdam, &input_params, kThreadNum, unique_sparse_grad.indices_size_);
|
|
|
|
|
|
|
|
|
|
if (use_nesterov_) {
|
|
|
|
|
m = m_t.data();
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < total_dim_size; ++i) {
|
|
|
|
|
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
|
|
|
|
|
input_params.m_ = input_params.m_t_;
|
|
|
|
|
}
|
|
|
|
|
input_params.var_ = var;
|
|
|
|
|
input_params.lr_ = lr;
|
|
|
|
|
input_params.epsilon_ = epsilon;
|
|
|
|
|
MultiThreadCompute(ComputeWeight, &input_params, kThreadNum, total_dim_size);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace kernel
|
|
|
|
|