|
|
|
@ -25,6 +25,9 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/utils/Flags.h"
|
|
|
|
|
#include "paddle/utils/Util.h"
|
|
|
|
|
|
|
|
|
|
using std::vector;
|
|
|
|
|
using std::pair;
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -131,6 +134,73 @@ private:
|
|
|
|
|
std::vector<bool> mask_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class DynamicPruningHook : public IParameterUpdaterHook {
|
|
|
|
|
public:
|
|
|
|
|
explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig)
|
|
|
|
|
: initCount_(0) {
|
|
|
|
|
sparsityRatio_ = hookConfig.sparsity_ratio();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool sortPairAscend(const pair<real, size_t>& pair1,
|
|
|
|
|
const pair<real, size_t>& pair2) {
|
|
|
|
|
return pair1.first > pair2.first;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void update(Parameter* para) {
|
|
|
|
|
updateThreadChecker_.check();
|
|
|
|
|
auto& vec = para->getBuf(PARAMETER_GRADIENT);
|
|
|
|
|
if (vec) {
|
|
|
|
|
vec->dotMul(*maskVec_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void generateMask(Parameter* para) {
|
|
|
|
|
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
maskTemp_ = Vector::create(para->getSize(), false);
|
|
|
|
|
maskTemp_->zeroMem();
|
|
|
|
|
real* dataPtr = maskTemp_->getData();
|
|
|
|
|
|
|
|
|
|
VectorPtr vecCpu = Vector::create(para->getSize(), false);
|
|
|
|
|
vecCpu->copyFrom(*vec);
|
|
|
|
|
vector<pair<real, size_t>> param;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < para->getSize(); i++)
|
|
|
|
|
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
|
|
|
|
|
std::sort(param.begin(), param.end(), sortPairAscend);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++)
|
|
|
|
|
dataPtr[param[i].second] = 1.0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void init(Parameter* para) {
|
|
|
|
|
generateMask(para);
|
|
|
|
|
size_t initCount = this->initCount_.fetch_add(1);
|
|
|
|
|
CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must invoke "
|
|
|
|
|
"in same ParamterUpdater";
|
|
|
|
|
VLOG(3) << "Initialize Parameter " << para;
|
|
|
|
|
SetDevice device(para->getDeviceId());
|
|
|
|
|
|
|
|
|
|
// Currently just use a mask vector for hack.
|
|
|
|
|
// @TODO(yuyang18): Implemented the mask operation in vector.
|
|
|
|
|
if (para->useGpu()) {
|
|
|
|
|
maskVec_ = Vector::create(para->getSize(), para->useGpu());
|
|
|
|
|
maskVec_->copyFrom(*maskTemp_);
|
|
|
|
|
} else {
|
|
|
|
|
maskVec_ = maskTemp_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto& vec = para->getBuf(PARAMETER_VALUE);
|
|
|
|
|
vec->dotMul(*maskVec_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
SameThreadChecker updateThreadChecker_;
|
|
|
|
|
std::atomic<size_t> initCount_;
|
|
|
|
|
VectorPtr maskVec_;
|
|
|
|
|
VectorPtr maskTemp_;
|
|
|
|
|
real sparsityRatio_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
IParameterUpdaterHook::IParameterUpdaterHook() {}
|
|
|
|
|
|
|
|
|
|
IParameterUpdaterHook::~IParameterUpdaterHook() {}
|
|
|
|
@ -156,8 +226,7 @@ private:
|
|
|
|
|
|
|
|
|
|
static WeakKVCache<std::pair<std::string, int>,
|
|
|
|
|
IParameterUpdaterHook,
|
|
|
|
|
StringIntPairHasher>
|
|
|
|
|
g_hookCache_;
|
|
|
|
|
StringIntPairHasher> g_hookCache_;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* ParameterUpdaterHook actually factory method.
|
|
|
|
@ -165,11 +234,22 @@ static WeakKVCache<std::pair<std::string, int>,
|
|
|
|
|
static IParameterUpdaterHook* createImpl(
|
|
|
|
|
const ParameterUpdaterHookConfig& config) {
|
|
|
|
|
auto& type = config.type();
|
|
|
|
|
if (type == "pruning") {
|
|
|
|
|
if (config.has_purning_mask_filename()) {
|
|
|
|
|
if (type == "pruning_static") {
|
|
|
|
|
if (config.has_purning_mask_filename())
|
|
|
|
|
return new StaticPruningHook(config.purning_mask_filename());
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
LOG(FATAL) << "There must be mask_filename parameter for " << type
|
|
|
|
|
<< " Hook";
|
|
|
|
|
|
|
|
|
|
} else if (type == "pruning") {
|
|
|
|
|
if (config.has_sparsity_ratio())
|
|
|
|
|
return new DynamicPruningHook(config);
|
|
|
|
|
else
|
|
|
|
|
LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
|
|
|
|
|
<< " Hook";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LOG(FATAL) << "Unknown Hook type: " << type;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|