|
|
|
@ -28,7 +28,7 @@ inline float get_period_sparcity(const std::vector<float>& sparsity,
|
|
|
|
|
|
|
|
|
|
size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
|
|
|
|
|
if (idx >= sparsity.size()) {
|
|
|
|
|
return 0.999;
|
|
|
|
|
idx = sparsity.size() - 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LT(idx, sparsity.size());
|
|
|
|
@ -102,8 +102,9 @@ class DGCOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float ratio =
|
|
|
|
|
1 - get_period_sparcity(sparsity, static_cast<float>(*current_step),
|
|
|
|
|
rampup_step);
|
|
|
|
|
1 - get_period_sparcity(
|
|
|
|
|
sparsity, static_cast<float>(*current_step - rampup_begin_step),
|
|
|
|
|
rampup_step);
|
|
|
|
|
PADDLE_ENFORCE_GE(ratio, 0.0);
|
|
|
|
|
PADDLE_ENFORCE_LT(ratio, 1.0);
|
|
|
|
|
int k = static_cast<int>(g->numel() * ratio);
|
|
|
|
|