|  |  |  | @ -14,11 +14,13 @@ limitations under the License. */ | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #include "ParameterUpdaterHook.h" | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #include <algorithm> | 
			
		
	
		
			
				
					|  |  |  |  | #include <atomic> | 
			
		
	
		
			
				
					|  |  |  |  | #include <fstream> | 
			
		
	
		
			
				
					|  |  |  |  | #include <mutex> | 
			
		
	
		
			
				
					|  |  |  |  | #include <thread> | 
			
		
	
		
			
				
					|  |  |  |  | #include <unordered_map> | 
			
		
	
		
			
				
					|  |  |  |  | #include <vector> | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #include "paddle/math/Vector.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "paddle/parameter/Parameter.h" | 
			
		
	
	
		
			
				
					|  |  |  | @ -29,40 +31,21 @@ namespace paddle { | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | /**
 | 
			
		
	
		
			
				
					|  |  |  |  |  * The static pruning hook | 
			
		
	
		
			
				
					|  |  |  |  |  * | 
			
		
	
		
			
				
					|  |  |  |  |  * Static means user load a mask map before training started. This map will | 
			
		
	
		
			
				
					|  |  |  |  |  * define which link/weight between neural is disabled. | 
			
		
	
		
			
				
					|  |  |  |  |  * Static means user specify a sparsity_ratio before training started, and the | 
			
		
	
		
			
				
					|  |  |  |  |  * network will prune the parameters based on the sparsity_ratio. More details | 
			
		
	
		
			
				
					|  |  |  |  |  * can be found https://arxiv.org/pdf/1506.02626.pdf.
 | 
			
		
	
		
			
				
					|  |  |  |  |  */ | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | class StaticPruningHook : public IParameterUpdaterHook { | 
			
		
	
		
			
				
					|  |  |  |  | public: | 
			
		
	
		
			
				
					|  |  |  |  |   /**
 | 
			
		
	
		
			
				
					|  |  |  |  |    * The Mask Map Header. | 
			
		
	
		
			
				
					|  |  |  |  |    * The map file started with this header. | 
			
		
	
		
			
				
					|  |  |  |  |    * | 
			
		
	
		
			
				
					|  |  |  |  |    * In Version 0, reset file will be: | 
			
		
	
		
			
				
					|  |  |  |  |    *  contains header.size bit, each bit means such weight is enabled or not. | 
			
		
	
		
			
				
					|  |  |  |  |    *    if bit is 1, then such weight is enabled. | 
			
		
	
		
			
				
					|  |  |  |  |    *  at end, the file will round to byte, and the low bits of end byte will be | 
			
		
	
		
			
				
					|  |  |  |  |    *  filled by zero. | 
			
		
	
		
			
				
					|  |  |  |  |    * | 
			
		
	
		
			
				
					|  |  |  |  |    */ | 
			
		
	
		
			
				
					|  |  |  |  |   struct StaticMaskHeader { | 
			
		
	
		
			
				
					|  |  |  |  |     uint32_t version; | 
			
		
	
		
			
				
					|  |  |  |  |     size_t size; | 
			
		
	
		
			
				
					|  |  |  |  |   } __attribute__((__packed__)); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) { | 
			
		
	
		
			
				
					|  |  |  |  |     bool ok = this->loadMaskFile(mask_filename); | 
			
		
	
		
			
				
					|  |  |  |  |     if (!ok) { | 
			
		
	
		
			
				
					|  |  |  |  |       LOG(WARNING) << "Fail to load mask file " << mask_filename | 
			
		
	
		
			
				
					|  |  |  |  |                    << " in current directory, searching in init_model_path"; | 
			
		
	
		
			
				
					|  |  |  |  |       std::string combineMaskFilename = | 
			
		
	
		
			
				
					|  |  |  |  |           path::join(FLAGS_init_model_path, mask_filename); | 
			
		
	
		
			
				
					|  |  |  |  |       CHECK(this->loadMaskFile(combineMaskFilename)) | 
			
		
	
		
			
				
					|  |  |  |  |           << "Cannot load " << mask_filename << " in ./" << mask_filename | 
			
		
	
		
			
				
					|  |  |  |  |           << " and " << combineMaskFilename; | 
			
		
	
		
			
				
					|  |  |  |  |   explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig) | 
			
		
	
		
			
				
					|  |  |  |  |       : initCount_(0) { | 
			
		
	
		
			
				
					|  |  |  |  |     sparsityRatio_ = hookConfig.sparsity_ratio(); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  |     VLOG(3) << mask_filename << " mask size = " << this->mask_.size(); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   static bool sortPairAscend(const std::pair<real, size_t> &pair1, | 
			
		
	
		
			
				
					|  |  |  |  |                              const std::pair<real, size_t> &pair2) { | 
			
		
	
		
			
				
					|  |  |  |  |     return pair1.first > pair2.first; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   void update(Parameter *para) { | 
			
		
	
	
		
			
				
					|  |  |  | @ -73,62 +56,51 @@ public: | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   void init(Parameter* para) { | 
			
		
	
		
			
				
					|  |  |  |  |     size_t initCount = this->initCount_.fetch_add(1); | 
			
		
	
		
			
				
					|  |  |  |  |     CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " | 
			
		
	
		
			
				
					|  |  |  |  |                                 "in same ParamterUpdater"; | 
			
		
	
		
			
				
					|  |  |  |  |     VLOG(3) << "Initialize Parameter " << para; | 
			
		
	
		
			
				
					|  |  |  |  |     SetDevice device(para->getDeviceId()); | 
			
		
	
		
			
				
					|  |  |  |  |   void generateMask(Parameter *para) { | 
			
		
	
		
			
				
					|  |  |  |  |     VectorPtr maskTemp = Vector::create(para->getSize(), false); | 
			
		
	
		
			
				
					|  |  |  |  |     maskTemp->zeroMem(); | 
			
		
	
		
			
				
					|  |  |  |  |     real *maskTempData = maskTemp->getData(); | 
			
		
	
		
			
				
					|  |  |  |  |     size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     auto maskVec = Vector::create(this->mask_.size(), false); | 
			
		
	
		
			
				
					|  |  |  |  |     {  // Initialize maskVec with float mask vector
 | 
			
		
	
		
			
				
					|  |  |  |  |       real* dataPtr = maskVec->getData(); | 
			
		
	
		
			
				
					|  |  |  |  |       size_t i = 0; | 
			
		
	
		
			
				
					|  |  |  |  |       for (bool m : mask_) { | 
			
		
	
		
			
				
					|  |  |  |  |         dataPtr[i++] = m ? 1.0 : 0.0; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     VectorPtr paraVec = para->getBuf(PARAMETER_VALUE); | 
			
		
	
		
			
				
					|  |  |  |  |     VectorPtr paraCpuCopy = Vector::create(para->getSize(), false); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     paraCpuCopy->copyFrom(*paraVec); | 
			
		
	
		
			
				
					|  |  |  |  |     std::vector<std::pair<real, size_t>> param; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     for (size_t i = 0; i < para->getSize(); i++) | 
			
		
	
		
			
				
					|  |  |  |  |       param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     std::partial_sort( | 
			
		
	
		
			
				
					|  |  |  |  |         param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); | 
			
		
	
		
			
				
					|  |  |  |  |     for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     // Currently just use a mask vector for hack.
 | 
			
		
	
		
			
				
					|  |  |  |  |     // @TODO(yuyang18): Implemented the mask operation in vector.
 | 
			
		
	
		
			
				
					|  |  |  |  |     if (para->useGpu()) { | 
			
		
	
		
			
				
					|  |  |  |  |       maskVec_ = Vector::create(this->mask_.size(), para->useGpu()); | 
			
		
	
		
			
				
					|  |  |  |  |       maskVec_->copyFrom(*maskVec); | 
			
		
	
		
			
				
					|  |  |  |  |       maskVec_ = Vector::create(para->getSize(), para->useGpu()); | 
			
		
	
		
			
				
					|  |  |  |  |       maskVec_->copyFrom(*maskTemp); | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       maskVec_ = maskVec; | 
			
		
	
		
			
				
					|  |  |  |  |       maskVec_ = maskTemp; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     auto& vec = para->getBuf(PARAMETER_VALUE); | 
			
		
	
		
			
				
					|  |  |  |  |     vec->dotMul(*maskVec_); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | private: | 
			
		
	
		
			
				
					|  |  |  |  |   bool loadMaskFile(const std::string& mask_filename) { | 
			
		
	
		
			
				
					|  |  |  |  |     std::ifstream fin; | 
			
		
	
		
			
				
					|  |  |  |  |     fin.open(mask_filename); | 
			
		
	
		
			
				
					|  |  |  |  |     if (fin.is_open()) { | 
			
		
	
		
			
				
					|  |  |  |  |       StaticMaskHeader header; | 
			
		
	
		
			
				
					|  |  |  |  |       fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader)); | 
			
		
	
		
			
				
					|  |  |  |  |       CHECK_EQ(header.version, 0UL); | 
			
		
	
		
			
				
					|  |  |  |  |       mask_.resize(header.size); | 
			
		
	
		
			
				
					|  |  |  |  |       uint8_t buf; | 
			
		
	
		
			
				
					|  |  |  |  |       for (size_t i = 0; i < header.size; ++i, buf <<= 1) { | 
			
		
	
		
			
				
					|  |  |  |  |         if (i % 8 == 0) { | 
			
		
	
		
			
				
					|  |  |  |  |           fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t)); | 
			
		
	
		
			
				
					|  |  |  |  |         } | 
			
		
	
		
			
				
					|  |  |  |  |         mask_[i] = buf & 0x80; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |       fin.close(); | 
			
		
	
		
			
				
					|  |  |  |  |       return true; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       return false; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   void init(Parameter *para) { | 
			
		
	
		
			
				
					|  |  |  |  |     generateMask(para); | 
			
		
	
		
			
				
					|  |  |  |  |     size_t initCount = this->initCount_.fetch_add(1); | 
			
		
	
		
			
				
					|  |  |  |  |     CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " | 
			
		
	
		
			
				
					|  |  |  |  |                                 "in same ParamterUpdater"; | 
			
		
	
		
			
				
					|  |  |  |  |     VLOG(3) << "Initialize Parameter " << para; | 
			
		
	
		
			
				
					|  |  |  |  |     SetDevice device(para->getDeviceId()); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     auto ¶Vec = para->getBuf(PARAMETER_VALUE); | 
			
		
	
		
			
				
					|  |  |  |  |     paraVec->dotMul(*maskVec_); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | private: | 
			
		
	
		
			
				
					|  |  |  |  |   SameThreadChecker updateThreadChecker_; | 
			
		
	
		
			
				
					|  |  |  |  |   std::atomic<size_t> initCount_; | 
			
		
	
		
			
				
					|  |  |  |  |   VectorPtr maskVec_; | 
			
		
	
		
			
				
					|  |  |  |  |   std::vector<bool> mask_; | 
			
		
	
		
			
				
					|  |  |  |  |   real sparsityRatio_; | 
			
		
	
		
			
				
					|  |  |  |  | }; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | IParameterUpdaterHook::IParameterUpdaterHook() {} | 
			
		
	
	
		
			
				
					|  |  |  | @ -166,10 +138,10 @@ static IParameterUpdaterHook* createImpl( | 
			
		
	
		
			
				
					|  |  |  |  |     const ParameterUpdaterHookConfig &config) { | 
			
		
	
		
			
				
					|  |  |  |  |   auto &type = config.type(); | 
			
		
	
		
			
				
					|  |  |  |  |   if (type == "pruning") { | 
			
		
	
		
			
				
					|  |  |  |  |     if (config.has_purning_mask_filename()) { | 
			
		
	
		
			
				
					|  |  |  |  |       return new StaticPruningHook(config.purning_mask_filename()); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     return new StaticPruningHook(config); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   LOG(FATAL) << "Unknown Hook type:  " << type; | 
			
		
	
		
			
				
					|  |  |  |  |   return nullptr; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | 
 |