|
|
|
@ -155,15 +155,14 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
input_params.v_ = v;
|
|
|
|
|
input_params.beta1_ = beta1;
|
|
|
|
|
input_params.beta2_ = beta2;
|
|
|
|
|
const size_t kThreadNum = 16;
|
|
|
|
|
MultiThreadCompute(ComputeMomentum, &input_params, kThreadNum, total_dim_size);
|
|
|
|
|
MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size);
|
|
|
|
|
|
|
|
|
|
input_params.m_t_ = m_t;
|
|
|
|
|
input_params.use_nesterov_ = use_nesterov_;
|
|
|
|
|
input_params.sparse_grad_ = unique_sparse_grad;
|
|
|
|
|
input_params.var_first_dim_size_ = var_first_dim_size_;
|
|
|
|
|
input_params.var_outer_dim_size_ = var_outer_dim_size_;
|
|
|
|
|
MultiThreadCompute(ComputeAdam, &input_params, kThreadNum, unique_sparse_grad.indices_size_);
|
|
|
|
|
MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_);
|
|
|
|
|
|
|
|
|
|
if (use_nesterov_) {
|
|
|
|
|
input_params.m_ = input_params.m_t_;
|
|
|
|
@ -171,7 +170,7 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
input_params.var_ = var;
|
|
|
|
|
input_params.lr_ = lr;
|
|
|
|
|
input_params.epsilon_ = epsilon;
|
|
|
|
|
MultiThreadCompute(ComputeWeight, &input_params, kThreadNum, total_dim_size);
|
|
|
|
|
MultiThreadCompute(ComputeWeight, &input_params, total_dim_size);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace kernel
|
|
|
|
|