|
|
|
@ -38,29 +38,28 @@ namespace paddle {
|
|
|
|
|
|
|
|
|
|
class StaticPruningHook : public IParameterUpdaterHook {
|
|
|
|
|
public:
|
|
|
|
|
explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig)
|
|
|
|
|
explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
|
|
|
|
|
: initCount_(0) {
|
|
|
|
|
sparsityRatio_ = hookConfig.sparsity_ratio();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool sortPairAscend(const std::pair<real, size_t>& pair1,
|
|
|
|
|
const std::pair<real, size_t>& pair2) {
|
|
|
|
|
static bool sortPairAscend(const std::pair<real, size_t> &pair1,
|
|
|
|
|
const std::pair<real, size_t> &pair2) {
|
|
|
|
|
return pair1.first > pair2.first;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Applies the mask to the gradient buffer of `para`.
 *
 * Runs updateThreadChecker_.check() first, then, if the PARAMETER_GRADIENT
 * buffer is non-null, calls dotMul(*maskVec_) on it (presumably an
 * element-wise multiply that zeroes the masked-out entries -- confirm against
 * Vector::dotMul).
 *
 * @param para parameter whose gradient buffer is masked in place.
 */
void update(Parameter *para) {
  updateThreadChecker_.check();
  auto &vec = para->getBuf(PARAMETER_GRADIENT);
  if (vec) {
    vec->dotMul(*maskVec_);
  }
}
|
|
|
|
|
|
|
|
|
|
// Builds a mask for `para`: a zero-filled vector with para->getSize() entries
// in which nonZeroNum entries are set to 1.0. The entries chosen are the first
// nonZeroNum of `param` after std::partial_sort with sortPairAscend, i.e. the
// (fabs(value), index) pairs with the largest absolute values.
//
// NOTE(review): this span is a diff rendering. Several statements appear in
// both their pre- and post-reformat form, `|` lines are filler, and `@ ... @`
// lines mark source lines that are elided from this view.
void generateMask(Parameter* para) {
|
|
|
|
|
|
|
|
|
|
void generateMask(Parameter *para) {
|
|
|
|
|
// Temporary mask buffer, cleared to zero before selected entries are set.
VectorPtr maskTemp = Vector::create(para->getSize(), false);
|
|
|
|
|
maskTemp->zeroMem();
|
|
|
|
|
real* maskTempData = maskTemp->getData();
|
|
|
|
|
real *maskTempData = maskTemp->getData();
|
|
|
|
|
// Number of entries to keep. Assuming sparsityRatio_ is floating-point, the
// product is truncated when converted to size_t -- TODO confirm its type and
// that it is constrained to [0, 1].
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
|
|
|
|
|
|
|
|
|
|
VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
// NOTE(review): lines elided here; the declarations of `param` and
// `paraCpuCopy` used below are not visible in this view.
@ -72,9 +71,10 @@ public:
|
|
|
|
|
// Pair each element's absolute value with its index.
for (size_t i = 0; i < para->getSize(); i++)
|
|
|
|
|
param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
|
|
|
|
|
|
|
|
|
|
// Move the nonZeroNum largest-magnitude pairs to the front of `param`
// (sortPairAscend compares with '>'), then mark their indices in the mask.
std::partial_sort(
|
|
|
|
|
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
|
|
|
|
|
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
|
|
|
|
|
std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(),
|
|
|
|
|
sortPairAscend);
|
|
|
|
|
for (size_t i = 0; i < nonZeroNum; i++)
|
|
|
|
|
maskTempData[param[i].second] = 1.0;
|
|
|
|
|
|
|
|
|
|
// Currently a plain mask vector is used as a workaround (hack).
|
|
|
|
|
if (para->useGpu()) {
|
|
|
|
// NOTE(review): lines elided here; the body of the GPU branch and whatever
// follows it are not visible in this view.
@ -85,7 +85,7 @@ public:
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Initializes pruning for `para`: calls generateMask(para), increments
// initCount_ with fetch_add(1) and CHECK_EQs the previous value against 0,
// logs at VLOG(3), constructs a SetDevice for para->getDeviceId(), and calls
// dotMul(*maskVec_) on the PARAMETER_VALUE buffer.
//
// NOTE(review): this span is a diff rendering. Some lines appear in both
// pre- and post-reformat form, `|` lines are filler, and the `@ ... @` line
// marks elided source lines.
void init(Parameter* para) {
|
|
|
|
|
void init(Parameter *para) {
|
|
|
|
|
generateMask(para);
|
|
|
|
|
// fetch_add returns the value before the increment, so the check below fails
// on any call after the first.
size_t initCount = this->initCount_.fetch_add(1);
|
|
|
|
|
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
|
|
|
|
// NOTE(review): lines elided here; the remainder of the CHECK_EQ message is
// not visible in this view.
@ -93,7 +93,7 @@ public:
|
|
|
|
|
VLOG(3) << "Initialize Parameter " << para;
|
|
|
|
|
SetDevice device(para->getDeviceId());
|
|
|
|
|
|
|
|
|
|
auto& paraVec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
// NOTE(review): the pilcrow on the next line looks like a mis-encoded
// "&para" (HTML entity decoding); the intended text is presumably
// "auto &paraVec = ...".
auto ¶Vec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
paraVec->dotMul(*maskVec_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -118,7 +118,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
|
|
|
|
|
*/
|
|
|
|
|
// Hash functor for std::pair<std::string, int> keys: hashes the string with
// strHasher_, adds the int, and passes the sum through intHasher_.
//
// NOTE(review): this span is a diff rendering. operator() appears in both
// pre- and post-reformat form, `|` lines are filler, and the `@ ... @` line
// marks elided source lines (the declaration of strHasher_ is not visible).
class StringIntPairHasher {
|
|
|
|
|
public:
|
|
|
|
|
size_t operator()(const std::pair<std::string, int>& k) const {
|
|
|
|
|
size_t operator()(const std::pair<std::string, int> &k) const {
|
|
|
|
|
// intHasher_ is std::hash<int>, so the sum is converted to int before
// hashing; where int is narrower than the sum's type, high bits are dropped.
return intHasher_(strHasher_(k.first) + k.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -127,17 +127,15 @@ private:
|
|
|
|
|
std::hash<int> intHasher_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// File-local cache of IParameterUpdaterHook objects keyed by a
// (std::string, int) pair and hashed with StringIntPairHasher.
static WeakKVCache<std::pair<std::string, int>, IParameterUpdaterHook,
                   StringIntPairHasher> g_hookCache_;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* The actual factory function for ParameterUpdaterHook objects.
|
|
|
|
|
*/
|
|
|
|
|
// Returns a newly allocated StaticPruningHook when config.type() is
// "pruning"; the visible tail of the function returns nullptr.
//
// NOTE(review): this span is a diff rendering. The signature and the `type`
// declaration appear in both pre- and post-reformat form, `|` lines are
// filler, and the `@ ... @` line marks elided source lines.
static IParameterUpdaterHook* createImpl(
|
|
|
|
|
const ParameterUpdaterHookConfig& config) {
|
|
|
|
|
auto& type = config.type();
|
|
|
|
|
static IParameterUpdaterHook *
|
|
|
|
|
createImpl(const ParameterUpdaterHookConfig &config) {
|
|
|
|
|
auto &type = config.type();
|
|
|
|
|
if (type == "pruning") {
|
|
|
|
|
// Raw `new`: ownership passes to the caller.
return new StaticPruningHook(config);
|
|
|
|
|
}
|
|
|
|
// NOTE(review): lines elided here; what happens for other type values before
// the final return is not visible in this view.
@ -146,11 +144,11 @@ static IParameterUpdaterHook* createImpl(
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create(
|
|
|
|
|
const ParameterConfig& paramConfig, int idx) {
|
|
|
|
|
std::shared_ptr<IParameterUpdaterHook>
|
|
|
|
|
IParameterUpdaterHook::create(const ParameterConfig ¶mConfig, int idx) {
|
|
|
|
|
std::pair<std::string, int> key = {paramConfig.name(), idx};
|
|
|
|
|
return g_hookCache_.get(
|
|
|
|
|
key, [&] { return createImpl(paramConfig.update_hooks(idx)); });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
} // namespace paddle
|
|
|
|
|