|
|
|
@ -31,9 +31,9 @@ namespace paddle {
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* The static pruning hook
|
|
|
|
|
* Static means user specific a sparsity_ratio before training start, and the
|
|
|
|
|
* Static means the user specifies a sparsity_ratio before training starts, and the
|
|
|
|
|
* network will prune the parameters based on the sparsity_ratio. More details
|
|
|
|
|
* can see https://arxiv.org/pdf/1506.02626.pdf.
|
|
|
|
|
* can be found at https://arxiv.org/pdf/1506.02626.pdf.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
class StaticPruningHook : public IParameterUpdaterHook {
|
|
|
|
@ -57,29 +57,31 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void generateMask(Parameter* para) {
|
|
|
|
|
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
maskTemp_ = Vector::create(para->getSize(), false);
|
|
|
|
|
maskTemp_->zeroMem();
|
|
|
|
|
real* dataPtr = maskTemp_->getData();
|
|
|
|
|
|
|
|
|
|
VectorPtr maskTemp = Vector::create(para->getSize(), false);
|
|
|
|
|
maskTemp->zeroMem();
|
|
|
|
|
real* maskTempData = maskTemp->getData();
|
|
|
|
|
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
|
|
|
|
|
|
|
|
|
|
VectorPtr vecCpu = Vector::create(para->getSize(), false);
|
|
|
|
|
vecCpu->copyFrom(*vec);
|
|
|
|
|
VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
|
|
|
|
|
|
|
|
|
|
paraCpuCopy->copyFrom(*paraVec);
|
|
|
|
|
std::vector<std::pair<real, size_t>> param;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < para->getSize(); i++)
|
|
|
|
|
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
|
|
|
|
|
param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
|
|
|
|
|
|
|
|
|
|
std::partial_sort(
|
|
|
|
|
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
|
|
|
|
|
for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0;
|
|
|
|
|
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
|
|
|
|
|
|
|
|
|
|
// Currently just use a mask vector for hack.
|
|
|
|
|
if (para->useGpu()) {
|
|
|
|
|
maskVec_ = Vector::create(para->getSize(), para->useGpu());
|
|
|
|
|
maskVec_->copyFrom(*maskTemp_);
|
|
|
|
|
maskVec_->copyFrom(*maskTemp);
|
|
|
|
|
} else {
|
|
|
|
|
maskVec_ = maskTemp_;
|
|
|
|
|
maskVec_ = maskTemp;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -91,15 +93,14 @@ public:
|
|
|
|
|
VLOG(3) << "Initialize Parameter " << para;
|
|
|
|
|
SetDevice device(para->getDeviceId());
|
|
|
|
|
|
|
|
|
|
auto& vec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
vec->dotMul(*maskVec_);
|
|
|
|
|
auto& paraVec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
paraVec->dotMul(*maskVec_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
SameThreadChecker updateThreadChecker_;
|
|
|
|
|
std::atomic<size_t> initCount_;
|
|
|
|
|
VectorPtr maskVec_;
|
|
|
|
|
VectorPtr maskTemp_;
|
|
|
|
|
real sparsityRatio_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|