@@ -22,7 +22,8 @@ namespace kernel {
 namespace {
 constexpr size_t kSparseApplyAdamInputSize = 11;
 
-void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeAdam(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto m = input_params->m_;
   auto m_t = input_params->m_t_;
@@ -34,8 +35,8 @@ void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t en
   const auto var_first_dim_size = input_params->var_first_dim_size_;
   const auto var_outer_dim_size = input_params->var_outer_dim_size_;
   for (size_t i = start; i < end; ++i) {
-    int index = unique_sparse_grad.indices_[i];
-    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+    T index = unique_sparse_grad.indices_[i];
+    if (index < 0 || LongToSize(index) >= var_first_dim_size) {
       MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
     }
     size_t start_index = var_outer_dim_size * index;
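Note on the hunk above: moving the check from `IntToSize` to `LongToSize` is what lets the same path accept `int64_t` indices without truncation, while the `index < 0` test short-circuits before any cast runs. A minimal standalone sketch of that validation pattern, with a hypothetical `LongToSizeSketch` helper standing in for MindSpore's utility:

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for MindSpore's LongToSize(): a checked
// signed -> size_t widening cast (the real helper lives in the framework).
inline size_t LongToSizeSketch(int64_t v) {
  if (v < 0) throw std::out_of_range("negative value");
  return static_cast<size_t>(v);
}

// Mirrors the bounds check in ComputeAdam above: the `index < 0` test
// short-circuits, so only non-negative values reach the widening cast.
template <typename T>
bool IndexInRange(T index, size_t var_first_dim_size) {
  return !(index < 0 || LongToSizeSketch(index) >= var_first_dim_size);
}
```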
@@ -51,7 +52,8 @@ void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t en
   }
 }
 
-void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeMomentum(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto m = input_params->m_;
   auto v = input_params->v_;
@@ -63,7 +65,8 @@ void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_
   }
 }
 
-void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeWeight(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto var = input_params->var_;
   const auto *m = input_params->m_;
@@ -76,16 +79,24 @@ void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t
 }
 }  // namespace
 
-void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
-  CPUKernel::InitInputOutputSize(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_node);
+template <typename T>
+void SparseApplyAdamCPUKernel::InitWorkspaceSize() {
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
   workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
 }
+
+void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  if (indices_data_type_ == kNumberTypeInt32) {
+    InitWorkspaceSize<int>();
+  } else {
+    InitWorkspaceSize<int64_t>();
+  }
+}
 
 void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
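The sizing logic above is why the kernel must know the index dtype before allocating workspaces: with `int64_t` indices the two index buffers double in size while the three float buffers are unchanged. A minimal sketch of the same `sizeof(T)`-driven sizing (function and variable names here are illustrative, not the kernel's API):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative re-creation of the sizeof(T)-driven workspace sizing;
// the five entries mirror the five emplace_back calls in InitWorkspaceSize:
// new_grad, new_indices, workspace_grad, workspace_indices, m_t.
template <typename T>
std::vector<size_t> WorkspaceBytes(size_t indices_size, size_t var_first_dim,
                                   size_t var_outer_dim) {
  return {indices_size * var_outer_dim * sizeof(float),
          indices_size * sizeof(T),
          indices_size * var_outer_dim * sizeof(float),
          indices_size * sizeof(T),
          var_first_dim * var_outer_dim * sizeof(float)};
}

int main() {
  // int32 vs int64 indices: only the two index buffers change size.
  for (size_t b : WorkspaceBytes<int>(8, 16, 4)) printf("%zu ", b);
  printf("\n");
  for (size_t b : WorkspaceBytes<int64_t>(8, 16, 4)) printf("%zu ", b);
  printf("\n");
  return 0;
}
```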
@@ -119,15 +130,12 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
   }
+  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 10);
 }
 
-bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                      const std::vector<kernel::AddressPtr> &workspace,
-                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
-  if (inputs.size() < kSparseApplyAdamInputSize) {
-    MS_LOG(EXCEPTION) << "Error input size!";
-  }
+template <typename T>
+void SparseApplyAdamCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                            const std::vector<kernel::AddressPtr> &workspace) const {
   auto var = reinterpret_cast<float *>(inputs[0]->addr);
   auto m = reinterpret_cast<float *>(inputs[1]->addr);
   auto v = reinterpret_cast<float *>(inputs[2]->addr);
@@ -141,17 +149,17 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   auto beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
   auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
   auto grad = reinterpret_cast<float *>(inputs[9]->addr);
-  auto indices = reinterpret_cast<int *>(inputs[10]->addr);
+  auto indices = reinterpret_cast<T *>(inputs[10]->addr);
   auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
-  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  auto new_indices = reinterpret_cast<T *>(workspace[1]->addr);
   auto workspace_grad = reinterpret_cast<float *>(workspace[2]->addr);
-  auto workspace_indices = reinterpret_cast<int *>(workspace[3]->addr);
+  auto workspace_indices = reinterpret_cast<T *>(workspace[3]->addr);
   auto m_t = reinterpret_cast<float *>(workspace[4]->addr);
 
-  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
-  SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
-  SparseGradient input_sparse_grad({grad, indices, indices_size_});
-  ReduceSparseGradientParam param;
+  SparseGradient<T> unique_sparse_grad({new_grad, new_indices, indices_size_});
+  SparseGradient<T> workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
+  SparseGradient<T> input_sparse_grad({grad, indices, indices_size_});
+  ReduceSparseGradientParam<T> param;
   param.input_grad_ = &input_sparse_grad;
   param.workspace_grad_ = &workspace_sparse_grad;
   param.output_grad_ = &unique_sparse_grad;
@@ -162,19 +170,19 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_;
   lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
 
-  MultiThreadComputeParams input_params;
+  MultiThreadComputeParams<T> input_params;
   input_params.m_ = m;
   input_params.v_ = v;
   input_params.beta1_ = beta1;
   input_params.beta2_ = beta2;
-  MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size);
+  MultiThreadCompute<T>(ComputeMomentum<T>, &input_params, total_dim_size);
 
   input_params.m_t_ = m_t;
   input_params.use_nesterov_ = use_nesterov_;
   input_params.sparse_grad_ = unique_sparse_grad;
   input_params.var_first_dim_size_ = var_first_dim_size_;
   input_params.var_outer_dim_size_ = var_outer_dim_size_;
-  MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_);
+  MultiThreadCompute<T>(ComputeAdam<T>, &input_params, unique_sparse_grad.indices_size_);
 
   if (use_nesterov_) {
     input_params.m_ = input_params.m_t_;
@@ -182,7 +190,20 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   input_params.var_ = var;
   input_params.lr_ = lr;
   input_params.epsilon_ = epsilon;
-  MultiThreadCompute(ComputeWeight, &input_params, total_dim_size);
+  MultiThreadCompute<T>(ComputeWeight<T>, &input_params, total_dim_size);
+}
+
+bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &workspace,
+                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyAdamInputSize) {
+    MS_LOG(EXCEPTION) << "Error input size!";
+  }
+  if (indices_data_type_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, workspace);
+  } else {
+    LaunchKernel<int64_t>(inputs, workspace);
+  }
   return true;
 }
 }  // namespace kernel
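For reviewers tracing the math: the context line `lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);` folds Adam's bias correction into the step size once, so the per-element passes can apply an uncorrected update. The bodies of `ComputeMomentum`/`ComputeAdam`/`ComputeWeight` fall outside these hunks, so the sketch below is the textbook single-element recurrence they jointly implement, not a copy of the kernel code:

```cpp
#include <cmath>
#include <cstdio>

// Textbook Adam step for one element, with bias correction folded into
// lr the same way the Launch code above rescales it before the passes.
void AdamStepSketch(float *var, float *m, float *v, float grad, float lr,
                    float beta1, float beta2, float epsilon,
                    float beta1_power, float beta2_power) {
  lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);  // as in the hunk
  *m = beta1 * *m + (1 - beta1) * grad;         // moment decay + accumulate
  *v = beta2 * *v + (1 - beta2) * grad * grad;  // second-moment update
  *var -= lr * *m / (std::sqrt(*v) + epsilon);  // weight apply
}

int main() {
  float var = 1.0f, m = 0.0f, v = 0.0f;
  AdamStepSketch(&var, &m, &v, /*grad=*/0.5f, /*lr=*/0.01f,
                 0.9f, 0.999f, 1e-8f, /*beta1_power=*/0.9f, /*beta2_power=*/0.999f);
  printf("var=%f m=%f v=%f\n", var, m, v);
  return 0;
}
```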