From 5413af8d11eb8662598e5f89d5f4fa284e1c6fdf Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 2 Jun 2017 14:28:20 +0800 Subject: [PATCH 01/35] imporve pruning module --- paddle/parameter/ParameterUpdaterHook.cpp | 90 +++++++++++++++++-- proto/ParameterConfig.proto | 2 + python/paddle/trainer/config_parser.py | 15 +++- python/paddle/trainer_config_helpers/attrs.py | 46 +++++++++- python/paddle/v2/attr.py | 2 + 5 files changed, 144 insertions(+), 11 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index f826e8448c..76cc3ecad1 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -25,6 +25,9 @@ limitations under the License. */ #include "paddle/utils/Flags.h" #include "paddle/utils/Util.h" +using std::vector; +using std::pair; + namespace paddle { /** @@ -131,6 +134,73 @@ private: std::vector mask_; }; +class DynamicPruningHook : public IParameterUpdaterHook { +public: + explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig) + : initCount_(0) { + sparsityRatio_ = hookConfig.sparsity_ratio(); + } + + static bool sortPairAscend(const pair& pair1, + const pair& pair2) { + return pair1.first > pair2.first; + } + + void update(Parameter* para) { + updateThreadChecker_.check(); + auto& vec = para->getBuf(PARAMETER_GRADIENT); + if (vec) { + vec->dotMul(*maskVec_); + } + } + + void generateMask(Parameter* para) { + VectorPtr vec = para->getBuf(PARAMETER_VALUE); + maskTemp_ = Vector::create(para->getSize(), false); + maskTemp_->zeroMem(); + real* dataPtr = maskTemp_->getData(); + + VectorPtr vecCpu = Vector::create(para->getSize(), false); + vecCpu->copyFrom(*vec); + vector> param; + + for (size_t i = 0; i < para->getSize(); i++) + param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); + std::sort(param.begin(), param.end(), sortPairAscend); + + for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++) + dataPtr[param[i].second] = 
1.0; + } + + void init(Parameter* para) { + generateMask(para); + size_t initCount = this->initCount_.fetch_add(1); + CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must invoke " + "in same ParamterUpdater"; + VLOG(3) << "Initialize Parameter " << para; + SetDevice device(para->getDeviceId()); + + // Currently just use a mask vector for hack. + // @TODO(yuyang18): Implemented the mask operation in vector. + if (para->useGpu()) { + maskVec_ = Vector::create(para->getSize(), para->useGpu()); + maskVec_->copyFrom(*maskTemp_); + } else { + maskVec_ = maskTemp_; + } + + auto& vec = para->getBuf(PARAMETER_VALUE); + vec->dotMul(*maskVec_); + } + +private: + SameThreadChecker updateThreadChecker_; + std::atomic initCount_; + VectorPtr maskVec_; + VectorPtr maskTemp_; + real sparsityRatio_; +}; + IParameterUpdaterHook::IParameterUpdaterHook() {} IParameterUpdaterHook::~IParameterUpdaterHook() {} @@ -156,8 +226,7 @@ private: static WeakKVCache, IParameterUpdaterHook, - StringIntPairHasher> - g_hookCache_; + StringIntPairHasher> g_hookCache_; /** * ParameterUpdaterHook actually factory method. 
@@ -165,11 +234,22 @@ static WeakKVCache, static IParameterUpdaterHook* createImpl( const ParameterUpdaterHookConfig& config) { auto& type = config.type(); - if (type == "pruning") { - if (config.has_purning_mask_filename()) { + if (type == "pruning_static") { + if (config.has_purning_mask_filename()) return new StaticPruningHook(config.purning_mask_filename()); - } + else + LOG(FATAL) << "There must be mask_filename parameter for " << type + << " Hook"; + + } else if (type == "pruning") { + if (config.has_sparsity_ratio()) + return new DynamicPruningHook(config); + else + LOG(FATAL) << "There must be sparsity_ratio parameter for " << type + << " Hook"; } + + LOG(FATAL) << "Unknown Hook type: " << type; return nullptr; } diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto index cbcd0af598..61f4b037cf 100644 --- a/proto/ParameterConfig.proto +++ b/proto/ParameterConfig.proto @@ -26,7 +26,9 @@ enum ParameterInitStrategy { message ParameterUpdaterHookConfig { required string type = 1; + //hook type such as 'pruning', 'pruning_static' optional string purning_mask_filename = 2; + optional double sparsity_ratio = 3; } message ParameterConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 9fe8794691..d80590210f 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3171,12 +3171,19 @@ def Layer(name, type, **xargs): @config_func def ParameterHook(type, **kwargs): - if type == 'pruning': + if type == 'pruning_static': + hook = ParameterUpdaterHookConfig() + hook.type = type mask_filename = kwargs.get('mask_filename', None) assert mask_filename is not None + hook.pruning_mask_filename = mask_filename + return hook + elif type == 'pruning': hook = ParameterUpdaterHookConfig() hook.type = type - hook.purning_mask_filename = mask_filename + sparsity_ratio = kwargs.get('sparsity_ratio', None) + assert sparsity_ratio is not None + hook.sparsity_ratio = 
sparsity_ratio return hook else: return None @@ -3283,13 +3290,13 @@ def Parameter(name, if update_hooks is not None: if hasattr(update_hooks, '__call__'): - update_hooks = update_hooks(para.name) + update_hooks = update_hooks() if isinstance(update_hooks, list): for hook in update_hooks: para.update_hooks.extend([hook]) else: - para.update_hooks.extend(update_hooks) + para.update_hooks.extend([update_hooks]) g_parameter_map[name] = para diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index d1167a234c..011147a368 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -14,7 +14,8 @@ from paddle.trainer.config_parser import * __all__ = [ - 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute' + 'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', + 'ExtraLayerAttribute' ] @@ -55,6 +56,42 @@ def is_compatible_with(x, Type): return False +class HookAttribute(object): + """ + Hook Attribute object. The hook is an auxiliary operation that occurs + during network propagation. Such as pruning operation, It will cut off + redundant parameters in the network before training. More detail can see + here paddle/parameter/ParameterUpdaterHook.cpp + NOTE: IT IS A HIGH LEVEL USER INTERFACE. + + :param type: Hook type, eg: 'pruning', 'pruning_static' + :type type: string + + :param mask_file: Must be specified if hook type is 'pruning_static', + the network reads the mask from the file to determine which parameters should be cut off + :type mask_file: string + + :param sparsity_ratio: Must be specified if hook type is 'pruning', + the network will hold the sparsity_ratio maximum parameters, and cut off the rest. 
+ :type sparsity_ratio: float number between 0 and 1 + + """ + + def __init__(self, type, mask_filename=None, sparsity_ratio=None): + self.type = type + self.mask_filename = mask_filename + self.sparsity_ratio = sparsity_ratio + assert is_compatible_with(self.sparsity_ratio, + float), 'sparisity_ratio must be float type' + assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' + + def __call__(self): + return ParameterHook( + self.type, + mask_filename=self.mask_filename, + sparsity_ratio=self.sparsity_ratio) + + class ParameterAttribute(object): """ Parameter Attributes object. To fine-tuning network training process, user @@ -109,7 +146,8 @@ class ParameterAttribute(object): learning_rate=None, momentum=None, gradient_clipping_threshold=None, - sparse_update=False): + sparse_update=False, + update_hooks=None): self.attr = {} if is_static: @@ -162,6 +200,9 @@ class ParameterAttribute(object): self.attr['gradient_clipping_threshold'] = \ gradient_clipping_threshold + if update_hooks: + self.attr['update_hooks'] = update_hooks + def set_default_parameter_name(self, name): """ Set default parameter name. 
If parameter not set, then will use default @@ -237,5 +278,6 @@ class ExtraLayerAttribute(object): return attr.attr +HookAttr = HookAttribute ParamAttr = ParameterAttribute ExtraAttr = ExtraLayerAttribute diff --git a/python/paddle/v2/attr.py b/python/paddle/v2/attr.py index 32f78614e7..5d23894d73 100644 --- a/python/paddle/v2/attr.py +++ b/python/paddle/v2/attr.py @@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs __all__ = [ "Param", "Extra", + "Hook", ] Param = paddle.trainer_config_helpers.attrs.ParameterAttribute Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute +Hook = paddle.trainer_config_helpers.attrs.HookAttribute for each in paddle.trainer_config_helpers.attrs.__all__: globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each) From 18435f2a738b2baec680eea6fc2648dd094e5c87 Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 2 Jun 2017 16:31:49 +0800 Subject: [PATCH 02/35] modify the pruning from reading mask to specify sparsity_ratio --- paddle/parameter/ParameterUpdaterHook.cpp | 130 ++---------------- proto/ParameterConfig.proto | 3 +- python/paddle/trainer/config_parser.py | 9 +- python/paddle/trainer_config_helpers/attrs.py | 14 +- 4 files changed, 17 insertions(+), 139 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 76cc3ecad1..e29494868b 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -19,130 +19,31 @@ limitations under the License. */ #include #include #include +#include #include "paddle/math/Vector.h" #include "paddle/parameter/Parameter.h" #include "paddle/utils/Flags.h" #include "paddle/utils/Util.h" -using std::vector; -using std::pair; - namespace paddle { /** * The static pruning hook - * - * Static means user load a mask map before training started. This map will - * define which link/weight between neural is disabled. 
+ * Static means user specific a sparsity_ratio map before training started. The + * network will + * hold the sparsity_ratio maximum numbers of parameters, and cut off the rest. */ -class StaticPruningHook : public IParameterUpdaterHook { -public: - /** - * The Mask Map Header. - * The map file started with this header. - * - * In Version 0, reset file will be: - * contains header.size bit, each bit means such weight is enabled or not. - * if bit is 1, then such weight is enabled. - * at end, the file will round to byte, and the low bits of end byte will be - * filled by zero. - * - */ - struct StaticMaskHeader { - uint32_t version; - size_t size; - } __attribute__((__packed__)); - - explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) { - bool ok = this->loadMaskFile(mask_filename); - if (!ok) { - LOG(WARNING) << "Fail to load mask file " << mask_filename - << " in current directory, searching in init_model_path"; - std::string combineMaskFilename = - path::join(FLAGS_init_model_path, mask_filename); - CHECK(this->loadMaskFile(combineMaskFilename)) - << "Cannot load " << mask_filename << " in ./" << mask_filename - << " and " << combineMaskFilename; - } - VLOG(3) << mask_filename << " mask size = " << this->mask_.size(); - } - void update(Parameter* para) { - updateThreadChecker_.check(); - auto& vec = para->getBuf(PARAMETER_GRADIENT); - if (vec) { - vec->dotMul(*maskVec_); - } - } - - void init(Parameter* para) { - size_t initCount = this->initCount_.fetch_add(1); - CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " - "in same ParamterUpdater"; - VLOG(3) << "Initialize Parameter " << para; - SetDevice device(para->getDeviceId()); - - auto maskVec = Vector::create(this->mask_.size(), false); - { // Initialize maskVec with float mask vector - real* dataPtr = maskVec->getData(); - size_t i = 0; - for (bool m : mask_) { - dataPtr[i++] = m ? 1.0 : 0.0; - } - } - - // Currently just use a mask vector for hack. 
- // @TODO(yuyang18): Implemented the mask operation in vector. - if (para->useGpu()) { - maskVec_ = Vector::create(this->mask_.size(), para->useGpu()); - maskVec_->copyFrom(*maskVec); - } else { - maskVec_ = maskVec; - } - - auto& vec = para->getBuf(PARAMETER_VALUE); - vec->dotMul(*maskVec_); - } - -private: - bool loadMaskFile(const std::string& mask_filename) { - std::ifstream fin; - fin.open(mask_filename); - if (fin.is_open()) { - StaticMaskHeader header; - fin.read(reinterpret_cast(&header), sizeof(StaticMaskHeader)); - CHECK_EQ(header.version, 0UL); - mask_.resize(header.size); - uint8_t buf; - for (size_t i = 0; i < header.size; ++i, buf <<= 1) { - if (i % 8 == 0) { - fin.read(reinterpret_cast(&buf), sizeof(uint8_t)); - } - mask_[i] = buf & 0x80; - } - fin.close(); - return true; - } else { - return false; - } - } - - SameThreadChecker updateThreadChecker_; - std::atomic initCount_; - VectorPtr maskVec_; - std::vector mask_; -}; - -class DynamicPruningHook : public IParameterUpdaterHook { +class StaticPruningHook : public IParameterUpdaterHook { public: - explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig) + explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig) : initCount_(0) { sparsityRatio_ = hookConfig.sparsity_ratio(); } - static bool sortPairAscend(const pair& pair1, - const pair& pair2) { + static bool sortPairAscend(const std::pair& pair1, + const std::pair& pair2) { return pair1.first > pair2.first; } @@ -162,7 +63,7 @@ public: VectorPtr vecCpu = Vector::create(para->getSize(), false); vecCpu->copyFrom(*vec); - vector> param; + std::vector> param; for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); @@ -175,7 +76,7 @@ public: void init(Parameter* para) { generateMask(para); size_t initCount = this->initCount_.fetch_add(1); - CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must invoke " + CHECK_EQ(initCount, 0UL) << "Currently the 
StaticPruningHook must invoke " "in same ParamterUpdater"; VLOG(3) << "Initialize Parameter " << para; SetDevice device(para->getDeviceId()); @@ -234,16 +135,9 @@ static WeakKVCache, static IParameterUpdaterHook* createImpl( const ParameterUpdaterHookConfig& config) { auto& type = config.type(); - if (type == "pruning_static") { - if (config.has_purning_mask_filename()) - return new StaticPruningHook(config.purning_mask_filename()); - else - LOG(FATAL) << "There must be mask_filename parameter for " << type - << " Hook"; - - } else if (type == "pruning") { + if (type == "pruning") { if (config.has_sparsity_ratio()) - return new DynamicPruningHook(config); + return new StaticPruningHook(config); else LOG(FATAL) << "There must be sparsity_ratio parameter for " << type << " Hook"; diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto index 61f4b037cf..53e3b94f03 100644 --- a/proto/ParameterConfig.proto +++ b/proto/ParameterConfig.proto @@ -26,8 +26,7 @@ enum ParameterInitStrategy { message ParameterUpdaterHookConfig { required string type = 1; - //hook type such as 'pruning', 'pruning_static' - optional string purning_mask_filename = 2; + //hook type such as 'pruning' optional double sparsity_ratio = 3; } diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 3775375c9b..bebb76d984 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3171,14 +3171,7 @@ def Layer(name, type, **xargs): @config_func def ParameterHook(type, **kwargs): - if type == 'pruning_static': - hook = ParameterUpdaterHookConfig() - hook.type = type - mask_filename = kwargs.get('mask_filename', None) - assert mask_filename is not None - hook.pruning_mask_filename = mask_filename - return hook - elif type == 'pruning': + if type == 'pruning': hook = ParameterUpdaterHookConfig() hook.type = type sparsity_ratio = kwargs.get('sparsity_ratio', None) diff --git 
a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 011147a368..a0ad8c4452 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -64,32 +64,24 @@ class HookAttribute(object): here paddle/parameter/ParameterUpdaterHook.cpp NOTE: IT IS A HIGH LEVEL USER INTERFACE. - :param type: Hook type, eg: 'pruning', 'pruning_static' + :param type: Hook type, eg: 'pruning' :type type: string - :param mask_file: Must be specified if hook type is 'pruning_static', - the network reads the mask from the file to determine which parameters should be cut off - :type mask_file: string - :param sparsity_ratio: Must be specified if hook type is 'pruning', the network will hold the sparsity_ratio maximum parameters, and cut off the rest. :type sparsity_ratio: float number between 0 and 1 """ - def __init__(self, type, mask_filename=None, sparsity_ratio=None): + def __init__(self, type, sparsity_ratio=None): self.type = type - self.mask_filename = mask_filename self.sparsity_ratio = sparsity_ratio assert is_compatible_with(self.sparsity_ratio, float), 'sparisity_ratio must be float type' assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' def __call__(self): - return ParameterHook( - self.type, - mask_filename=self.mask_filename, - sparsity_ratio=self.sparsity_ratio) + return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) class ParameterAttribute(object): From 092828fbe30e40b72fc25d8ab9c56ac7ecb5afe4 Mon Sep 17 00:00:00 2001 From: xzl Date: Mon, 5 Jun 2017 17:42:33 +0800 Subject: [PATCH 03/35] modify the doc of the interface --- paddle/parameter/ParameterUpdaterHook.cpp | 6 +++--- proto/ParameterConfig.proto | 4 ++-- python/paddle/trainer_config_helpers/attrs.py | 11 ++++------- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp 
b/paddle/parameter/ParameterUpdaterHook.cpp index e29494868b..5e8c77ced0 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -30,9 +30,9 @@ namespace paddle { /** * The static pruning hook - * Static means user specific a sparsity_ratio map before training started. The - * network will - * hold the sparsity_ratio maximum numbers of parameters, and cut off the rest. + * Static means user specific a sparsity_ratio before training start, and the + * network will prune the parameters based on the sparsity_ratio. More deatils + * can see https://arxiv.org/pdf/1506.02626.pdf. */ class StaticPruningHook : public IParameterUpdaterHook { diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto index 53e3b94f03..360342bac6 100644 --- a/proto/ParameterConfig.proto +++ b/proto/ParameterConfig.proto @@ -25,9 +25,9 @@ enum ParameterInitStrategy { } message ParameterUpdaterHookConfig { + // hook type such as 'pruning' required string type = 1; - //hook type such as 'pruning' - optional double sparsity_ratio = 3; + optional double sparsity_ratio = 2 [default = 0.8]; } message ParameterConfig { diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index a0ad8c4452..556701ca7a 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -59,17 +59,14 @@ def is_compatible_with(x, Type): class HookAttribute(object): """ Hook Attribute object. The hook is an auxiliary operation that occurs - during network propagation. Such as pruning operation, It will cut off - redundant parameters in the network before training. More detail can see - here paddle/parameter/ParameterUpdaterHook.cpp + during network propagation. NOTE: IT IS A HIGH LEVEL USER INTERFACE. 
- + :param type: Hook type, eg: 'pruning' :type type: string - :param sparsity_ratio: Must be specified if hook type is 'pruning', - the network will hold the sparsity_ratio maximum parameters, and cut off the rest. - :type sparsity_ratio: float number between 0 and 1 + :param sparsity_ratio: Must be specified if hook type is 'pruning' + :type sparsity_ratio: float or None """ From 597a58c3efe015be43e1e20a20a04921a9ae7c60 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 13 Jun 2017 23:52:16 +0800 Subject: [PATCH 04/35] Add DetectionMAPEvaluator. --- .../evaluators/DetectionMAPEvaluator.cpp | 312 ++++++++++++++++++ paddle/gserver/tests/test_Evaluator.cpp | 17 + proto/ModelConfig.proto | 9 + python/paddle/trainer/config_parser.py | 43 ++- .../trainer_config_helpers/evaluators.py | 105 ++++-- 5 files changed, 453 insertions(+), 33 deletions(-) create mode 100644 paddle/gserver/evaluators/DetectionMAPEvaluator.cpp diff --git a/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp b/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp new file mode 100644 index 0000000000..7d326c2db1 --- /dev/null +++ b/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp @@ -0,0 +1,312 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "Evaluator.h" +#include "paddle/gserver/layers/DetectionUtil.h" + +using std::map; +using std::vector; +using std::pair; +using std::make_pair; + +namespace paddle { + +/** + * @brief detection map Evaluator + * + * The config file api is detection_map_evaluator. + */ +class DetectionMAPEvaluator : public Evaluator { +public: + DetectionMAPEvaluator() + : evaluateDifficult_(false), cpuOutput_(nullptr), cpuLabel_(nullptr) {} + + virtual void start() { + Evaluator::start(); + allTruePos_.clear(); + allFalsePos_.clear(); + numPos_.clear(); + } + + virtual real evalImp(std::vector& arguments) { + overlapThreshold_ = config_.overlap_threshold(); + backgroundId_ = config_.background_id(); + evaluateDifficult_ = config_.evaluate_difficult(); + apType_ = config_.ap_type(); + + MatrixPtr detectTmpValue = arguments[0].value; + Matrix::resizeOrCreate(cpuOutput_, + detectTmpValue->getHeight(), + detectTmpValue->getWidth(), + false, + false); + + MatrixPtr labelTmpValue = arguments[1].value; + Matrix::resizeOrCreate(cpuLabel_, + labelTmpValue->getHeight(), + labelTmpValue->getWidth(), + false, + false); + + cpuOutput_->copyFrom(*detectTmpValue); + cpuLabel_->copyFrom(*labelTmpValue); + + Argument label = arguments[1]; + const int* labelIndex = label.sequenceStartPositions->getData(false); + size_t batchSize = label.getNumSequences(); + + vector>> allGTBBoxes; + vector>>> allDetectBBoxes; + + for (size_t n = 0; n < batchSize; ++n) { + map> bboxes; + for (int i = labelIndex[n]; i < labelIndex[n + 1]; ++i) { + vector bbox; + getBBoxFromLabelData(cpuLabel_->getData() + i * 6, 1, bbox); + int c = cpuLabel_->getData()[i * 6]; + bboxes[c].push_back(bbox[0]); + } + allGTBBoxes.push_back(bboxes); + } + + size_t imgId = 0; + for (size_t n = 0; n < cpuOutput_->getHeight();) { + map>> bboxes; + while (cpuOutput_->getData()[n * 7] == imgId && + n < cpuOutput_->getHeight()) { + vector label; + vector score; + vector bbox; + getBBoxFromDetectData( + cpuOutput_->getData() + n * 
7, 1, label, score, bbox); + bboxes[label[0]].push_back(make_pair(score[0], bbox[0])); + ++n; + } + ++imgId; + if (imgId > batchSize) break; + allDetectBBoxes.push_back(bboxes); + } + + for (size_t n = 0; n < batchSize; ++n) { + for (map>::iterator it = + allGTBBoxes[n].begin(); + it != allGTBBoxes[n].end(); + ++it) { + size_t count = 0; + if (evaluateDifficult_) { + count = it->second.size(); + } else { + for (size_t i = 0; i < it->second.size(); ++i) + if (!(it->second[i].isDifficult)) ++count; + } + if (numPos_.find(it->first) == numPos_.end() && count != 0) { + numPos_[it->first] = count; + } else { + numPos_[it->first] += count; + } + } + } + + // calcTFPos + calcTFPos( + batchSize, allGTBBoxes, allDetectBBoxes, &allTruePos_, &allFalsePos_); + + return 0; + } + + virtual void printStats(std::ostream& os) const { + real mAP = calcMAP(); + os << "Detection mAP=" << mAP * 100; + } + + virtual void distributeEval(ParameterClient2* client) { + LOG(FATAL) << "Distribute detection evaluation not implemented."; + } + +protected: + void calcTFPos(const size_t batchSize, + const vector>>& allGTBBoxes, + const vector>>>& + allDetectBBoxes, + map>>* allTruePos, + map>>* allFalsePos) { + for (size_t n = 0; n < allDetectBBoxes.size(); ++n) { + if (allGTBBoxes[n].size() == 0) { + for (map>>::const_iterator + it = allDetectBBoxes[n].begin(); + it != allDetectBBoxes[n].end(); + ++it) { + size_t label = it->first; + for (size_t i = 0; i < it->second.size(); ++i) { + (*allTruePos)[label].push_back(make_pair(it->second[i].first, 0)); + (*allFalsePos)[label].push_back(make_pair(it->second[i].first, 1)); + } + } + } else { + for (map>>::const_iterator + it = allDetectBBoxes[n].begin(); + it != allDetectBBoxes[n].end(); + ++it) { + size_t label = it->first; + vector> predBBoxes = it->second; + if (allGTBBoxes[n].find(label) == allGTBBoxes[n].end()) { + for (size_t i = 0; i < predBBoxes.size(); ++i) { + (*allTruePos)[label].push_back(make_pair(predBBoxes[i].first, 0)); + 
(*allFalsePos)[label].push_back( + make_pair(predBBoxes[i].first, 1)); + } + } else { + vector gtBBoxes = + allGTBBoxes[n].find(label)->second; + vector visited(gtBBoxes.size(), false); + // Sort detections in descend order based on scores + std::sort(predBBoxes.begin(), + predBBoxes.end(), + sortScorePairDescend); + for (size_t i = 0; i < predBBoxes.size(); ++i) { + real maxOverlap = -1.0; + size_t maxIdx = 0; + for (size_t j = 0; j < gtBBoxes.size(); ++j) { + real overlap = + jaccardOverlap(predBBoxes[i].second, gtBBoxes[j]); + if (overlap > maxOverlap) { + maxOverlap = overlap; + maxIdx = j; + } + } + if (maxOverlap > overlapThreshold_) { + if (evaluateDifficult_ || + (!evaluateDifficult_ && !gtBBoxes[maxIdx].isDifficult)) { + if (!visited[maxIdx]) { + (*allTruePos)[label].push_back( + make_pair(predBBoxes[i].first, 1)); + (*allFalsePos)[label].push_back( + make_pair(predBBoxes[i].first, 0)); + visited[maxIdx] = true; + } else { + (*allTruePos)[label].push_back( + make_pair(predBBoxes[i].first, 0)); + (*allFalsePos)[label].push_back( + make_pair(predBBoxes[i].first, 1)); + } + } + } else { + (*allTruePos)[label].push_back( + make_pair(predBBoxes[i].first, 0)); + (*allFalsePos)[label].push_back( + make_pair(predBBoxes[i].first, 1)); + } + } + } + } + } + } + } + + real calcMAP() const { + real mAP = 0.0; + size_t count = 0; + for (map::const_iterator it = numPos_.begin(); + it != numPos_.end(); + ++it) { + size_t label = it->first; + size_t labelNumPos = it->second; + if (labelNumPos == 0 || allTruePos_.find(label) == allTruePos_.end()) + continue; + vector> labelTruePos = allTruePos_.find(label)->second; + vector> labelFalsePos = + allFalsePos_.find(label)->second; + // Compute average precision. + vector tpCumSum; + getAccumulation(labelTruePos, &tpCumSum); + vector fpCumSum; + getAccumulation(labelFalsePos, &fpCumSum); + std::vector precision, recall; + size_t num = tpCumSum.size(); + // Compute Precision. 
+ for (size_t i = 0; i < num; ++i) { + CHECK_LE(tpCumSum[i], labelNumPos); + precision.push_back(static_cast(tpCumSum[i]) / + static_cast(tpCumSum[i] + fpCumSum[i])); + recall.push_back(static_cast(tpCumSum[i]) / labelNumPos); + } + // VOC2007 style + if (apType_ == "11point") { + vector maxPrecisions(11, 0.0); + int startIdx = num - 1; + for (int j = 10; j >= 0; --j) + for (int i = startIdx; i >= 0; --i) { + if (recall[i] < j / 10.) { + startIdx = i; + if (j > 0) maxPrecisions[j - 1] = maxPrecisions[j]; + break; + } else { + if (maxPrecisions[j] < precision[i]) + maxPrecisions[j] = precision[i]; + } + } + for (int j = 10; j >= 0; --j) mAP += maxPrecisions[j] / 11; + ++count; + } else if (apType_ == "Integral") { + // Nature integral + real averagePrecisions = 0.; + real prevRecall = 0.; + for (size_t i = 0; i < num; ++i) { + if (fabs(recall[i] - prevRecall) > 1e-6) + averagePrecisions += precision[i] * fabs(recall[i] - prevRecall); + prevRecall = recall[i]; + } + mAP += averagePrecisions; + ++count; + } else { + LOG(FATAL) << "Unkown ap version: " << apType_; + } + } + if (count != 0) mAP /= count; + return mAP; + } + + void getAccumulation(vector> inPairs, + vector* accuVec) const { + std::stable_sort( + inPairs.begin(), inPairs.end(), sortScorePairDescend); + accuVec->clear(); + size_t sum = 0; + for (size_t i = 0; i < inPairs.size(); ++i) { + sum += inPairs[i].second; + accuVec->push_back(sum); + } + } + + std::string getTypeImpl() const { return "detection_map"; } + + real getValueImpl() const { return calcMAP() * 100; } + +private: + real overlapThreshold_; + bool evaluateDifficult_; + size_t backgroundId_; + std::string apType_; + + MatrixPtr cpuOutput_; + MatrixPtr cpuLabel_; + + map numPos_; + map>> allTruePos_; + map>> allFalsePos_; +}; + +REGISTER_EVALUATOR(detection_map, DetectionMAPEvaluator); + +} // namespace paddle diff --git a/paddle/gserver/tests/test_Evaluator.cpp b/paddle/gserver/tests/test_Evaluator.cpp index 4f5fdbb37c..93996392d2 100644 --- 
a/paddle/gserver/tests/test_Evaluator.cpp +++ b/paddle/gserver/tests/test_Evaluator.cpp @@ -138,6 +138,23 @@ void testEvaluatorAll(TestConfig testConf, testEvaluator(testConf, testEvaluatorName, batchSize, false); } +TEST(Evaluator, detection_map) { + TestConfig config; + config.evaluatorConfig.set_type("detection_map"); + config.evaluatorConfig.set_overlap_threshold(0.5); + config.evaluatorConfig.set_background_id(0); + config.evaluatorConfig.set_ap_type("Integral"); + config.evaluatorConfig.set_evaluate_difficult(0); + + config.inputDefs.push_back({INPUT_DATA, "output", 7}); + config.inputDefs.push_back({INPUT_SEQUENCE_DATA, "label", 6}); + config.evaluatorConfig.set_evaluate_difficult(false); + testEvaluatorAll(config, "detection_map", 100); + + config.evaluatorConfig.set_evaluate_difficult(true); + testEvaluatorAll(config, "detection_map", 100); +} + TEST(Evaluator, classification_error) { TestConfig config; config.evaluatorConfig.set_type("classification_error"); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 29270829bb..ebe4f5cbb5 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -489,6 +489,15 @@ message EvaluatorConfig { // Used by ClassificationErrorEvaluator // top # classification error optional int32 top_k = 13 [default = 1]; + + // Used by DetectionMAPEvaluator + optional double overlap_threshold = 14 [default = 0.5]; + + optional int32 background_id = 15 [default = 0]; + + optional bool evaluate_difficult = 16 [default = false]; + + optional string ap_type = 17 [default = "11point"]; } message LinkConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 0792e2d40b..e78dc4f3b4 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1300,20 +1300,23 @@ def parse_maxout(maxout, input_layer_name, maxout_conf): # Define an evaluator @config_func -def Evaluator( - name, - type, - inputs, - chunk_scheme=None, - 
num_chunk_types=None, - classification_threshold=None, - positive_label=None, - dict_file=None, - result_file=None, - num_results=None, - top_k=None, - delimited=None, - excluded_chunk_types=None, ): +def Evaluator(name, + type, + inputs, + chunk_scheme=None, + num_chunk_types=None, + classification_threshold=None, + positive_label=None, + dict_file=None, + result_file=None, + num_results=None, + top_k=None, + delimited=None, + excluded_chunk_types=None, + overlap_threshold=None, + background_id=None, + evaluate_difficult=None, + ap_type=None): evaluator = g_config.model_config.evaluators.add() evaluator.type = type evaluator.name = MakeLayerNameInSubmodel(name) @@ -1347,6 +1350,18 @@ def Evaluator( if excluded_chunk_types: evaluator.excluded_chunk_types.extend(excluded_chunk_types) + if overlap_threshold is not None: + evaluator.overlap_threshold = overlap_threshold + + if background_id is not None: + evaluator.background_id = background_id + + if evaluate_difficult is not None: + evaluator.evaluate_difficult = evaluate_difficult + + if ap_type is not None: + evaluator.ap_type = ap_type + class LayerBase(object): def __init__( diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index a5234f3e47..1dcd804803 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -21,7 +21,8 @@ __all__ = [ "chunk_evaluator", "sum_evaluator", "column_sum_evaluator", "value_printer_evaluator", "gradient_printer_evaluator", "maxid_printer_evaluator", "maxframe_printer_evaluator", - "seqtext_printer_evaluator", "classification_error_printer_evaluator" + "seqtext_printer_evaluator", "classification_error_printer_evaluator", + "detection_map_evaluator" ] @@ -31,10 +32,11 @@ class EvaluatorAttribute(object): FOR_RANK = 1 << 2 FOR_PRINT = 1 << 3 FOR_UTILS = 1 << 4 + FOR_DETECTION = 1 << 5 KEYS = [ "for_classification", "for_regression", "for_rank", 
"for_print", - "for_utils" + "for_utils", "for_detection" ] @staticmethod @@ -57,22 +59,25 @@ def evaluator(*attrs): return impl -def evaluator_base( - input, - type, - label=None, - weight=None, - name=None, - chunk_scheme=None, - num_chunk_types=None, - classification_threshold=None, - positive_label=None, - dict_file=None, - result_file=None, - num_results=None, - delimited=None, - top_k=None, - excluded_chunk_types=None, ): +def evaluator_base(input, + type, + label=None, + weight=None, + name=None, + chunk_scheme=None, + num_chunk_types=None, + classification_threshold=None, + positive_label=None, + dict_file=None, + result_file=None, + num_results=None, + delimited=None, + top_k=None, + excluded_chunk_types=None, + overlap_threshold=None, + background_id=None, + evaluate_difficult=None, + ap_type=None): """ Evaluator will evaluate the network status while training/testing. @@ -107,6 +112,14 @@ def evaluator_base( :type weight: LayerOutput. :param top_k: number k in top-k error rate :type top_k: int + :param overlap_threshold: In detection tasks to filter detection results + :type overlap_threshold: float + :param background_id: Identifier of background class + :type background_id: int + :param evaluate_difficult: Whether to evaluate difficult objects + :type evaluate_difficult: bool + :param ap_type: How to calculate average persicion + :type ap_type: str """ # inputs type assertions. 
assert classification_threshold is None or isinstance( @@ -136,7 +149,61 @@ def evaluator_base( delimited=delimited, num_results=num_results, top_k=top_k, - excluded_chunk_types=excluded_chunk_types, ) + excluded_chunk_types=excluded_chunk_types, + overlap_threshold=overlap_threshold, + background_id=background_id, + evaluate_difficult=evaluate_difficult, + ap_type=ap_type) + + +@evaluator(EvaluatorAttribute.FOR_DETECTION) +@wrap_name_default() +def detection_map_evaluator(input, + label, + overlap_threshold=0.5, + background_id=0, + evaluate_difficult=False, + ap_type="11point", + name=None): + """ + Detection mAP Evaluator. It will print mean Average Precision for detection. + + The detection mAP Evaluator according to the detection_output's output count + the true positive and the false positive bbox and integral them to get the + mAP. + + The simple usage is: + + .. code-block:: python + + eval = detection_map_evaluator(input=det_output,label=lbl) + + :param input: Input layer. + :type input: LayerOutput + :param label: Label layer. + :type label: LayerOutput + :param overlap_threshold: The bbox overlap threshold of a true positive. + :type overlap_threshold: float + :param background_id: The background class index. + :type background_id: int + :param evaluate_difficult: Wether evaluate a difficult ground truth. 
+ :type evaluate_difficult: bool + """ + if not isinstance(input, list): + input = [input] + + if label: + input.append(label) + + evaluator_base( + name=name, + type="detection_map", + input=input, + label=label, + overlap_threshold=overlap_threshold, + background_id=background_id, + evaluate_difficult=evaluate_difficult, + ap_type=ap_type) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) From 997cef2e63ef4d7c99c58710289f7581d2af08c6 Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 14 Jun 2017 17:26:08 +0800 Subject: [PATCH 05/35] tiny modify --- paddle/parameter/ParameterUpdaterHook.cpp | 33 +++++++++---------- python/paddle/trainer/config_parser.py | 4 +-- python/paddle/trainer_config_helpers/attrs.py | 8 +++-- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 5e8c77ced0..a581cc047d 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -20,6 +20,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/math/Vector.h" #include "paddle/parameter/Parameter.h" @@ -60,6 +61,7 @@ public: maskTemp_ = Vector::create(para->getSize(), false); maskTemp_->zeroMem(); real* dataPtr = maskTemp_->getData(); + size_t sparsityNum = para->getSize() * (1 - sparsityRatio_); VectorPtr vecCpu = Vector::create(para->getSize(), false); vecCpu->copyFrom(*vec); @@ -67,10 +69,20 @@ public: for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); - std::sort(param.begin(), param.end(), sortPairAscend); - for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++) - dataPtr[param[i].second] = 1.0; + std::partial_sort(param.begin(), + param.begin() + sparsityNum, + param.end(), + sortPairAscend); + for (size_t i = 0; i < sparsityNum; i++) dataPtr[param[i].second] = 1.0; + + // Currently just use a mask vector for hack. 
+ if (para->useGpu()) { + maskVec_ = Vector::create(para->getSize(), para->useGpu()); + maskVec_->copyFrom(*maskTemp_); + } else { + maskVec_ = maskTemp_; + } } void init(Parameter* para) { @@ -81,15 +93,6 @@ public: VLOG(3) << "Initialize Parameter " << para; SetDevice device(para->getDeviceId()); - // Currently just use a mask vector for hack. - // @TODO(yuyang18): Implemented the mask operation in vector. - if (para->useGpu()) { - maskVec_ = Vector::create(para->getSize(), para->useGpu()); - maskVec_->copyFrom(*maskTemp_); - } else { - maskVec_ = maskTemp_; - } - auto& vec = para->getBuf(PARAMETER_VALUE); vec->dotMul(*maskVec_); } @@ -136,11 +139,7 @@ static IParameterUpdaterHook* createImpl( const ParameterUpdaterHookConfig& config) { auto& type = config.type(); if (type == "pruning") { - if (config.has_sparsity_ratio()) - return new StaticPruningHook(config); - else - LOG(FATAL) << "There must be sparsity_ratio parameter for " << type - << " Hook"; + return new StaticPruningHook(config); } LOG(FATAL) << "Unknown Hook type: " << type; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index e0147b1b37..3a29c91807 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3175,8 +3175,8 @@ def ParameterHook(type, **kwargs): hook = ParameterUpdaterHookConfig() hook.type = type sparsity_ratio = kwargs.get('sparsity_ratio', None) - assert sparsity_ratio is not None - hook.sparsity_ratio = sparsity_ratio + if sparsity_ratio is not None: + hook.sparsity_ratio = sparsity_ratio return hook else: return None diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 556701ca7a..27b54ffdea 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -73,9 +73,11 @@ class HookAttribute(object): def __init__(self, type, sparsity_ratio=None): self.type = type self.sparsity_ratio = 
sparsity_ratio - assert is_compatible_with(self.sparsity_ratio, - float), 'sparisity_ratio must be float type' - assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' + if self.sparsity_ratio is not None: + assert is_compatible_with( + self.sparsity_ratio, + float), 'sparisity_ratio must be float type' + assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' def __call__(self): return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) From 98e4bb79ea3f569b35c69272e0ffebf6613c985a Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Wed, 14 Jun 2017 18:48:32 +0800 Subject: [PATCH 06/35] Create ParameterUpdaterHook.cpp --- paddle/parameter/ParameterUpdaterHook.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index a581cc047d..a4c0cb3099 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -130,7 +130,8 @@ private: static WeakKVCache, IParameterUpdaterHook, - StringIntPairHasher> g_hookCache_; + StringIntPairHasher> + g_hookCache_; /** * ParameterUpdaterHook actually factory method. 
From 4fbec8233b08dec0608b342563c62ecee946e460 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Wed, 14 Jun 2017 18:49:00 +0800 Subject: [PATCH 07/35] Update ParameterUpdaterHook.cpp --- paddle/parameter/ParameterUpdaterHook.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index a4c0cb3099..3e3dcd6575 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -130,7 +130,7 @@ private: static WeakKVCache, IParameterUpdaterHook, - StringIntPairHasher> + StringIntPairHasher> g_hookCache_; /** From 5405dc0a65e3bb4a9b807a46bb1296cddce44a7e Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Wed, 14 Jun 2017 19:15:51 +0800 Subject: [PATCH 08/35] Create ParameterUpdaterHook.cpp --- paddle/parameter/ParameterUpdaterHook.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 3e3dcd6575..44fac59200 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -14,13 +14,13 @@ limitations under the License. 
*/ #include "ParameterUpdaterHook.h" +#include #include #include #include #include #include #include -#include #include "paddle/math/Vector.h" #include "paddle/parameter/Parameter.h" From fc9e3e4bda6a4ceaa1ae9e45eb3ef522382bf8e3 Mon Sep 17 00:00:00 2001 From: zlx Date: Fri, 16 Jun 2017 14:29:16 +0800 Subject: [PATCH 09/35] explain the sparsity ratio --- paddle/parameter/ParameterUpdaterHook.cpp | 6 +++--- proto/ParameterConfig.proto | 3 ++- python/paddle/trainer_config_helpers/attrs.py | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 44fac59200..1cc91b727a 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -61,7 +61,7 @@ public: maskTemp_ = Vector::create(para->getSize(), false); maskTemp_->zeroMem(); real* dataPtr = maskTemp_->getData(); - size_t sparsityNum = para->getSize() * (1 - sparsityRatio_); + size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); VectorPtr vecCpu = Vector::create(para->getSize(), false); vecCpu->copyFrom(*vec); @@ -71,10 +71,10 @@ public: param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); std::partial_sort(param.begin(), - param.begin() + sparsityNum, + param.begin() + nonZeroNum, param.end(), sortPairAscend); - for (size_t i = 0; i < sparsityNum; i++) dataPtr[param[i].second] = 1.0; + for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0; // Currently just use a mask vector for hack. 
if (para->useGpu()) { diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto index 360342bac6..580d663246 100644 --- a/proto/ParameterConfig.proto +++ b/proto/ParameterConfig.proto @@ -27,7 +27,8 @@ enum ParameterInitStrategy { message ParameterUpdaterHookConfig { // hook type such as 'pruning' required string type = 1; - optional double sparsity_ratio = 2 [default = 0.8]; + // this represents the ratio of zero element to be set by the Parameter + optional double sparsity_ratio = 2 [default = 0.6]; } message ParameterConfig { diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 2e4e082efb..bf12ad644d 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -65,7 +65,8 @@ class HookAttribute(object): :param type: Hook type, eg: 'pruning' :type type: string - :param sparsity_ratio: Must be specified if hook type is 'pruning' + :param sparsity_ratio: Must be specified if hook type is 'pruning', + it represents the ratio of the zero elements to be set by the Parameter. 
:type sparsity_ratio: float or None """ From 885275ee77ddafa28cda0135fa752ca9d8afe1c8 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Fri, 16 Jun 2017 14:59:18 +0800 Subject: [PATCH 10/35] Update ParameterUpdaterHook.cpp --- paddle/parameter/ParameterUpdaterHook.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 1cc91b727a..738e86a622 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -70,10 +70,8 @@ public: for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); - std::partial_sort(param.begin(), - param.begin() + nonZeroNum, - param.end(), - sortPairAscend); + std::partial_sort( + param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0; // Currently just use a mask vector for hack. From 5f924d5d533831c29f1f5243eb1790467c9aac1a Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 19 Jun 2017 18:15:15 +0800 Subject: [PATCH 11/35] Follow comments. --- doc/api/v2/config/evaluators.rst | 9 +++ .../evaluators/DetectionMAPEvaluator.cpp | 66 +++++++++---------- .../trainer_config_helpers/evaluators.py | 6 +- 3 files changed, 43 insertions(+), 38 deletions(-) diff --git a/doc/api/v2/config/evaluators.rst b/doc/api/v2/config/evaluators.rst index 39db51fa4a..9ac972fb19 100644 --- a/doc/api/v2/config/evaluators.rst +++ b/doc/api/v2/config/evaluators.rst @@ -99,3 +99,12 @@ value_printer .. automodule:: paddle.v2.evaluator :members: value_printer :noindex: + +Detection +===== + +detection_map +------------- +.. 
automodule:: paddle.v2.evaluator + :members: detection_map + :noindex: diff --git a/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp b/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp index 7d326c2db1..9b825db574 100644 --- a/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp +++ b/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp @@ -80,21 +80,20 @@ public: allGTBBoxes.push_back(bboxes); } - size_t imgId = 0; - for (size_t n = 0; n < cpuOutput_->getHeight();) { + size_t n = 0; + const real* cpuOutputData = cpuOutput_->getData(); + for (size_t imgId = 0; imgId < batchSize; ++imgId) { map>> bboxes; - while (cpuOutput_->getData()[n * 7] == imgId && - n < cpuOutput_->getHeight()) { + size_t curImgId = static_cast((cpuOutputData + n * 7)[0]); + while (curImgId == imgId && n < cpuOutput_->getHeight()) { vector label; vector score; vector bbox; - getBBoxFromDetectData( - cpuOutput_->getData() + n * 7, 1, label, score, bbox); + getBBoxFromDetectData(cpuOutputData + n * 7, 1, label, score, bbox); bboxes[label[0]].push_back(make_pair(score[0], bbox[0])); ++n; + curImgId = static_cast((cpuOutputData + n * 7)[0]); } - ++imgId; - if (imgId > batchSize) break; allDetectBBoxes.push_back(bboxes); } @@ -119,15 +118,14 @@ public: } // calcTFPos - calcTFPos( - batchSize, allGTBBoxes, allDetectBBoxes, &allTruePos_, &allFalsePos_); + calcTFPos(batchSize, allGTBBoxes, allDetectBBoxes); return 0; } virtual void printStats(std::ostream& os) const { real mAP = calcMAP(); - os << "Detection mAP=" << mAP * 100; + os << "Detection mAP=" << mAP; } virtual void distributeEval(ParameterClient2* client) { @@ -138,9 +136,7 @@ protected: void calcTFPos(const size_t batchSize, const vector>>& allGTBBoxes, const vector>>>& - allDetectBBoxes, - map>>* allTruePos, - map>>* allFalsePos) { + allDetectBBoxes) { for (size_t n = 0; n < allDetectBBoxes.size(); ++n) { if (allGTBBoxes[n].size() == 0) { for (map>>::const_iterator @@ -149,8 +145,8 @@ protected: ++it) { size_t label = it->first; for 
(size_t i = 0; i < it->second.size(); ++i) { - (*allTruePos)[label].push_back(make_pair(it->second[i].first, 0)); - (*allFalsePos)[label].push_back(make_pair(it->second[i].first, 1)); + allTruePos_[label].push_back(make_pair(it->second[i].first, 0)); + allFalsePos_[label].push_back(make_pair(it->second[i].first, 1)); } } } else { @@ -162,9 +158,8 @@ protected: vector> predBBoxes = it->second; if (allGTBBoxes[n].find(label) == allGTBBoxes[n].end()) { for (size_t i = 0; i < predBBoxes.size(); ++i) { - (*allTruePos)[label].push_back(make_pair(predBBoxes[i].first, 0)); - (*allFalsePos)[label].push_back( - make_pair(predBBoxes[i].first, 1)); + allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0)); + allFalsePos_[label].push_back(make_pair(predBBoxes[i].first, 1)); } } else { vector gtBBoxes = @@ -189,22 +184,21 @@ protected: if (evaluateDifficult_ || (!evaluateDifficult_ && !gtBBoxes[maxIdx].isDifficult)) { if (!visited[maxIdx]) { - (*allTruePos)[label].push_back( + allTruePos_[label].push_back( make_pair(predBBoxes[i].first, 1)); - (*allFalsePos)[label].push_back( + allFalsePos_[label].push_back( make_pair(predBBoxes[i].first, 0)); visited[maxIdx] = true; } else { - (*allTruePos)[label].push_back( + allTruePos_[label].push_back( make_pair(predBBoxes[i].first, 0)); - (*allFalsePos)[label].push_back( + allFalsePos_[label].push_back( make_pair(predBBoxes[i].first, 1)); } } } else { - (*allTruePos)[label].push_back( - make_pair(predBBoxes[i].first, 0)); - (*allFalsePos)[label].push_back( + allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0)); + allFalsePos_[label].push_back( make_pair(predBBoxes[i].first, 1)); } } @@ -274,7 +268,7 @@ protected: } } if (count != 0) mAP /= count; - return mAP; + return mAP * 100; } void getAccumulation(vector> inPairs, @@ -291,20 +285,22 @@ protected: std::string getTypeImpl() const { return "detection_map"; } - real getValueImpl() const { return calcMAP() * 100; } + real getValueImpl() const { return calcMAP(); } 
private: - real overlapThreshold_; - bool evaluateDifficult_; - size_t backgroundId_; - std::string apType_; + real overlapThreshold_; // overlap threshold when determining whether matched + bool evaluateDifficult_; // whether evaluate difficult ground truth + size_t backgroundId_; // class index of background + std::string apType_; // how to calculate mAP (Integral or 11point) MatrixPtr cpuOutput_; MatrixPtr cpuLabel_; - map numPos_; - map>> allTruePos_; - map>> allFalsePos_; + map numPos_; // counts of true objects each classification + map>> + allTruePos_; // true positive prediction + map>> + allFalsePos_; // false positive prediction }; REGISTER_EVALUATOR(detection_map, DetectionMAPEvaluator); diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index 1dcd804803..44d52edfa7 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -166,9 +166,9 @@ def detection_map_evaluator(input, ap_type="11point", name=None): """ - Detection mAP Evaluator. It will print mean Average Precision for detection. + Detection mAP Evaluator. It will print mean Average Precision (mAP) for detection. - The detection mAP Evaluator according to the detection_output's output count + The detection mAP Evaluator based on the output of detection_output layer counts the true positive and the false positive bbox and integral them to get the mAP. @@ -186,7 +186,7 @@ def detection_map_evaluator(input, :type overlap_threshold: float :param background_id: The background class index. :type background_id: int - :param evaluate_difficult: Wether evaluate a difficult ground truth. + :param evaluate_difficult: Whether evaluate a difficult ground truth. 
:type evaluate_difficult: bool """ if not isinstance(input, list): From 1eab8cce32b61f201098be482359defbfffc941b Mon Sep 17 00:00:00 2001 From: zlx Date: Wed, 21 Jun 2017 14:31:29 +0800 Subject: [PATCH 12/35] modify the annotations of HookAttribute, Variable declaration --- paddle/parameter/ParameterUpdaterHook.cpp | 31 ++++++++++--------- python/paddle/trainer_config_helpers/attrs.py | 20 +++++++----- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 738e86a622..66e554a70d 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -31,9 +31,9 @@ namespace paddle { /** * The static pruning hook - * Static means user specific a sparsity_ratio before training start, and the + * Static means user specify a sparsity_ratio before training started, and the * network will prune the parameters based on the sparsity_ratio. More deatils - * can see https://arxiv.org/pdf/1506.02626.pdf. + * can be found https://arxiv.org/pdf/1506.02626.pdf. 
*/ class StaticPruningHook : public IParameterUpdaterHook { @@ -57,29 +57,31 @@ public: } void generateMask(Parameter* para) { - VectorPtr vec = para->getBuf(PARAMETER_VALUE); - maskTemp_ = Vector::create(para->getSize(), false); - maskTemp_->zeroMem(); - real* dataPtr = maskTemp_->getData(); + + VectorPtr maskTemp = Vector::create(para->getSize(), false); + maskTemp->zeroMem(); + real* maskTempData = maskTemp->getData(); size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); - VectorPtr vecCpu = Vector::create(para->getSize(), false); - vecCpu->copyFrom(*vec); + VectorPtr paraVec = para->getBuf(PARAMETER_VALUE); + VectorPtr paraCpuCopy = Vector::create(para->getSize(), false); + + paraCpuCopy->copyFrom(*paraVec); std::vector> param; for (size_t i = 0; i < para->getSize(); i++) - param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); + param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); std::partial_sort( param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); - for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0; + for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; // Currently just use a mask vector for hack. 
if (para->useGpu()) { maskVec_ = Vector::create(para->getSize(), para->useGpu()); - maskVec_->copyFrom(*maskTemp_); + maskVec_->copyFrom(*maskTemp); } else { - maskVec_ = maskTemp_; + maskVec_ = maskTemp; } } @@ -91,15 +93,14 @@ public: VLOG(3) << "Initialize Parameter " << para; SetDevice device(para->getDeviceId()); - auto& vec = para->getBuf(PARAMETER_VALUE); - vec->dotMul(*maskVec_); + auto& paraVec = para->getBuf(PARAMETER_VALUE); + paraVec->dotMul(*maskVec_); } private: SameThreadChecker updateThreadChecker_; std::atomic initCount_; VectorPtr maskVec_; - VectorPtr maskTemp_; real sparsityRatio_; }; diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index bf12ad644d..66163bdc8d 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -58,15 +58,21 @@ def is_compatible_with(x, Type): class HookAttribute(object): """ - Hook Attribute object. The hook is an auxiliary operation that occurs - during network propagation. - NOTE: IT IS A HIGH LEVEL USER INTERFACE. - - :param type: Hook type, eg: 'pruning' + Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs + during training process of a layer with parameters, such as img_conv layer, fc layer. + + :param type: Hook type, currently supported types: + 'pruning' : user specify a sparsity_ratio before training started, and the + network will prune the parameters based on the sparsity_ratio. 
+ eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6) + The specific usage can be paddle.layer.img_conv(input=img, filter_size=3, + num_channels=3, num_filters=64, + param_attr=ParameterAttribute(update_hooks=hk) ) + The pruning deatils can be found https://arxiv.org/pdf/1506.02626.pdf :type type: string :param sparsity_ratio: Must be specified if hook type is 'pruning', - it represents the ratio of the zero elements to be set by the Parameter. + it represents the ratio of the zero elements to be set by the Parameter. :type sparsity_ratio: float or None """ @@ -78,7 +84,7 @@ class HookAttribute(object): assert is_compatible_with( self.sparsity_ratio, float), 'sparisity_ratio must be float type' - assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' + assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity_ratio must be a float between [0, 1] ' def __call__(self): return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) From badcdfe1e539ffcad75f601e687a83fd1512cff1 Mon Sep 17 00:00:00 2001 From: wuyi05 Date: Wed, 21 Jun 2017 15:05:41 +0800 Subject: [PATCH 13/35] pserver etcd registration --- go/cmd/pserver/pserver.go | 20 ++++++- go/pserver/client_test.go | 8 ++- go/pserver/service.go | 112 ++++++++++++++++++++++++++++++++++++- go/pserver/service_test.go | 25 ++++++--- go/utils/helper.go | 45 +++++++++++++++ go/utils/helper_test.go | 10 ++++ 6 files changed, 206 insertions(+), 14 deletions(-) create mode 100644 go/utils/helper.go create mode 100644 go/utils/helper_test.go diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index f0be251c24..ddf5ad40fd 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -5,18 +5,34 @@ import ( "net/http" "net/rpc" "strconv" + "time" "github.com/namsral/flag" "github.com/PaddlePaddle/Paddle/go/pserver" + log "github.com/sirupsen/logrus" ) func main() { port := flag.Int("port", 0, "port of the 
pserver") + etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", + "comma separated endpoint string for pserver to connect to etcd") + etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") + logLevel := flag.String("log-level", "info", "log level, one of debug") flag.Parse() - s := pserver.NewService() - err := rpc.Register(s) + level, err := log.ParseLevel(*logLevel) + if err != nil { + panic(err) + } + log.SetLevel(level) + + timeout := time.Second * time.Duration((*etcdTimeout)) + s, err := pserver.NewService(*etcdEndpoint, timeout) + if err != nil { + panic(err) + } + err = rpc.Register(s) if err != nil { panic(err) } diff --git a/go/pserver/client_test.go b/go/pserver/client_test.go index d0371a26a1..6ecf1fa08a 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client_test.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/PaddlePaddle/Paddle/go/pserver" ) @@ -30,9 +31,12 @@ func init() { port[i] = p go func(l net.Listener) { - s := pserver.NewService() + s, err := pserver.NewService("", time.Second*5) + if err != nil { + panic(err) + } server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/pserver/service.go b/go/pserver/service.go index 78a2bfaf63..a5c76857ab 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,9 +1,18 @@ package pserver import ( + "context" "errors" "fmt" + "strconv" + "strings" "sync" + "time" + + "github.com/PaddlePaddle/Paddle/go/utils" + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -47,14 +56,113 @@ type Service struct { mu sync.Mutex opt *optimizer paramMap map[string]Parameter + + etcdEndpoints string + etcdClient *clientv3.Client + // etcdTimeout is also used as retry intervals. + etcdTimeout time.Duration + // desired number of pservers in the job. 
+ // assume desired will not change during one training job. + desired int + // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. + externalIP string } // NewService creates a new service. -func NewService() *Service { +func NewService(endpoints string, timeout time.Duration) (*Service, error) { s := &Service{opt: newOptimizer(sgd, 0.005)} s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) - return s + s.etcdEndpoints = endpoints + s.etcdTimeout = timeout + + var err error + s.externalIP, err = utils.GetExternalIP() + if err != nil { + return nil, err + } + + if endpoints != "" { + // initialize connection to etcd, try + ep := strings.Split(s.etcdEndpoints, ",") + for { + cli, err := clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: s.etcdTimeout, + }) + if err != nil { + log.Errorf("connect to etcd error: %v", err) + time.Sleep(s.etcdTimeout) + continue + } + s.etcdClient = cli + log.Debugf("inited client to %s", s.etcdEndpoints) + break + } + // wait and set s.desired init value + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err := s.etcdClient.Get(ctx, "/ps_desired") + cancel() + if err != nil { + log.Errorf("getting /ps_desired error: %v", err) + time.Sleep(s.etcdTimeout) + continue + } + for _, ev := range resp.Kvs { + log.Debugf("key: %s, value: %s", ev.Key, ev.Value) + if string(ev.Key) == "/ps_desired" { + s.desired, err = strconv.Atoi(string(ev.Value)) + if err != nil { + log.Errorf("value of /ps_desired invalid %v\n", err) + time.Sleep(s.etcdTimeout) + // NOTE: wait util ps_desired value change + continue + } + } + } + break + } + s.registerPserverEtcd() + } // if endpoints != "" + // Bypass etcd registration if no endpoints specified + return s, nil +} + +// registerPserverEtcd registers pserver node on etcd using transaction. 
+func (s *Service) registerPserverEtcd() (*clientv3.TxnResponse, error) { + return concurrency.NewSTMRepeatable(context.TODO(), s.etcdClient, func(c concurrency.STM) error { + for i := 0; i < s.desired; i++ { + psKey := "/ps/" + strconv.Itoa(i) + log.Debugf("checking %s", psKey) + ps := c.Get(psKey) + log.Debugf("got value (%s) for key: %s", ps, psKey) + + resp, err := s.etcdClient.Grant(context.TODO(), 5) + if err != nil { + log.Fatal(err) + } + + if ps == "" { + // find the first id and write info + c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) + log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) + ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) + if kaerr != nil { + log.Errorf("keepalive etcd node error: %v", kaerr) + return kaerr + } + // FIXME: does this really needed? + go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { + ka := <-ch + log.Debugf("keepalive: %d\n", ka.TTL) + }(ch) + break + } + } + log.Debug("register finished") + return nil + }) } // InitParam initializes a parameter. 
diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index b746d13e1c..f317535592 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,12 +10,15 @@ import ( ) func TestFull(t *testing.T) { - s := pserver.NewService() + s, err := pserver.NewService("", time.Second*5) + if err != nil { + t.Error(err) + } var p pserver.Parameter p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) + err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) if err != nil { t.FailNow() } @@ -72,8 +75,11 @@ func TestFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s := pserver.NewService() - err := s.FinishInitParams(0, nil) + s, err := pserver.NewService("", time.Second*5) + if err != nil { + t.Error(err) + } + err = s.FinishInitParams(0, nil) if err != nil { t.FailNow() } @@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s := pserver.NewService() - err := s.SendGrad(pserver.Gradient{}, nil) + s, err := pserver.NewService("", time.Second*5) + err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() } } func TestBlockUntilInitialized(t *testing.T) { - s := pserver.NewService() + s, err := pserver.NewService("", time.Second*5) + if err != nil { + t.Error(err) + } ch := make(chan struct{}, 2) errCh := make(chan error, 2) var wg sync.WaitGroup @@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) + err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) if err != nil { t.FailNow() } diff --git a/go/utils/helper.go b/go/utils/helper.go new file mode 100644 index 0000000000..3220fd6c78 --- 
/dev/null +++ b/go/utils/helper.go @@ -0,0 +1,45 @@ +package utils + +import ( + "errors" + "net" +) + +// GetExternalIP returns the ip address of local network interface, not the +// loopback device. +func GetExternalIP() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", err + } + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if iface.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + addrs, err := iface.Addrs() + if err != nil { + return "", err + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsLoopback() { + continue + } + ip = ip.To4() + if ip == nil { + continue // not an ipv4 address + } + return ip.String(), nil + } + } + return "", errors.New("are you connected to the network?") +} diff --git a/go/utils/helper_test.go b/go/utils/helper_test.go new file mode 100644 index 0000000000..aa7c509768 --- /dev/null +++ b/go/utils/helper_test.go @@ -0,0 +1,10 @@ +package utils + +import "testing" + +func TestGetIP(t *testing.T) { + _, err := GetExternalIP() + if err != nil { + t.Errorf("GetExternalIP returns error : %v\n", err) + } +} From b7a52bd9767de41d65382929b1629e95e35a3fe5 Mon Sep 17 00:00:00 2001 From: wuyi05 Date: Wed, 21 Jun 2017 15:25:02 +0800 Subject: [PATCH 14/35] add started info log --- go/cmd/pserver/pserver.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index ddf5ad40fd..f42c90c6c6 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -43,7 +43,9 @@ func main() { panic(err) } + log.Infof("start pserver at port %d", *port) err = http.Serve(l, nil) + if err != nil { panic(err) } From aaf11fa6259dc0c4cc248a102141b68d94685ad7 Mon Sep 17 00:00:00 2001 From: zlx Date: Wed, 21 Jun 2017 15:44:07 +0800 Subject: [PATCH 15/35] modify the format --- 
paddle/parameter/ParameterUpdaterHook.cpp | 46 +++++++++++------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 66e554a70d..ba2cb37fa2 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -38,29 +38,28 @@ namespace paddle { class StaticPruningHook : public IParameterUpdaterHook { public: - explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig) + explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig) : initCount_(0) { sparsityRatio_ = hookConfig.sparsity_ratio(); } - static bool sortPairAscend(const std::pair& pair1, - const std::pair& pair2) { + static bool sortPairAscend(const std::pair &pair1, + const std::pair &pair2) { return pair1.first > pair2.first; } - void update(Parameter* para) { + void update(Parameter *para) { updateThreadChecker_.check(); - auto& vec = para->getBuf(PARAMETER_GRADIENT); + auto &vec = para->getBuf(PARAMETER_GRADIENT); if (vec) { vec->dotMul(*maskVec_); } } - void generateMask(Parameter* para) { - + void generateMask(Parameter *para) { VectorPtr maskTemp = Vector::create(para->getSize(), false); maskTemp->zeroMem(); - real* maskTempData = maskTemp->getData(); + real *maskTempData = maskTemp->getData(); size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); VectorPtr paraVec = para->getBuf(PARAMETER_VALUE); @@ -72,9 +71,10 @@ public: for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); - std::partial_sort( - param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); - for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; + std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(), + sortPairAscend); + for (size_t i = 0; i < nonZeroNum; i++) + maskTempData[param[i].second] = 1.0; // Currently just use a mask vector for hack. 
if (para->useGpu()) { @@ -85,7 +85,7 @@ public: } } - void init(Parameter* para) { + void init(Parameter *para) { generateMask(para); size_t initCount = this->initCount_.fetch_add(1); CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " @@ -93,7 +93,7 @@ public: VLOG(3) << "Initialize Parameter " << para; SetDevice device(para->getDeviceId()); - auto& paraVec = para->getBuf(PARAMETER_VALUE); + auto &paraVec = para->getBuf(PARAMETER_VALUE); paraVec->dotMul(*maskVec_); } @@ -118,7 +118,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {} */ class StringIntPairHasher { public: - size_t operator()(const std::pair& k) const { + size_t operator()(const std::pair &k) const { return intHasher_(strHasher_(k.first) + k.second); } @@ -127,17 +127,15 @@ private: std::hash intHasher_; }; -static WeakKVCache, - IParameterUpdaterHook, - StringIntPairHasher> - g_hookCache_; +static WeakKVCache, IParameterUpdaterHook, + StringIntPairHasher> g_hookCache_; /** * ParameterUpdaterHook actually factory method.
*/ -static IParameterUpdaterHook* createImpl( - const ParameterUpdaterHookConfig& config) { - auto& type = config.type(); +static IParameterUpdaterHook * +createImpl(const ParameterUpdaterHookConfig &config) { + auto &type = config.type(); if (type == "pruning") { return new StaticPruningHook(config); } @@ -146,11 +144,11 @@ static IParameterUpdaterHook* createImpl( return nullptr; } -std::shared_ptr IParameterUpdaterHook::create( - const ParameterConfig& paramConfig, int idx) { +std::shared_ptr +IParameterUpdaterHook::create(const ParameterConfig &paramConfig, int idx) { std::pair key = {paramConfig.name(), idx}; return g_hookCache_.get( key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); } -} // namespace paddle +} // namespace paddle From a266292a57613d16806cccd939d68f436731927c Mon Sep 17 00:00:00 2001 From: zlx Date: Wed, 21 Jun 2017 18:37:37 +0800 Subject: [PATCH 16/35] modify format --- paddle/parameter/ParameterUpdaterHook.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index ba2cb37fa2..968803fc0f 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -71,10 +71,9 @@ public: for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); - std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(), - sortPairAscend); - for (size_t i = 0; i < nonZeroNum; i++) - maskTempData[param[i].second] = 1.0; + std::partial_sort( + param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); + for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; // Currently just use a mask vector for hack.
if (para->useGpu()) { @@ -127,14 +126,16 @@ private: std::hash intHasher_; }; -static WeakKVCache, IParameterUpdaterHook, - StringIntPairHasher> g_hookCache_; +static WeakKVCache, + IParameterUpdaterHook, + StringIntPairHasher> + g_hookCache_; /** * ParameterUpdaterHook actually factory method. */ -static IParameterUpdaterHook * -createImpl(const ParameterUpdaterHookConfig &config) { +static IParameterUpdaterHook *createImpl( + const ParameterUpdaterHookConfig &config) { auto &type = config.type(); if (type == "pruning") { return new StaticPruningHook(config); @@ -144,11 +145,11 @@ createImpl(const ParameterUpdaterHookConfig &config) { return nullptr; } -std::shared_ptr -IParameterUpdaterHook::create(const ParameterConfig &paramConfig, int idx) { +std::shared_ptr IParameterUpdaterHook::create( + const ParameterConfig &paramConfig, int idx) { std::pair key = {paramConfig.name(), idx}; return g_hookCache_.get( key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); } -} // namespace paddle +} // namespace paddle From 0b936e9399f2a5f01f6fde1d1b78b56306a8f9ac Mon Sep 17 00:00:00 2001 From: wuyi05 Date: Thu, 22 Jun 2017 15:00:39 +0800 Subject: [PATCH 17/35] update pserver etcd --- go/cmd/pserver/pserver.go | 3 +- go/pserver/service.go | 75 ++++++++++++--------- go/utils/{ => networkhelper}/helper.go | 2 +- go/utils/{ => networkhelper}/helper_test.go | 2 +- 4 files changed, 47 insertions(+), 35 deletions(-) rename go/utils/{ => networkhelper}/helper.go (97%) rename go/utils/{ => networkhelper}/helper_test.go (87%) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index f42c90c6c6..fe1fe5f6f0 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -18,7 +18,8 @@ func main() { etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") - logLevel := flag.String("log-level", "info", "log level, 
one of debug") + logLevel := flag.String("log-level", "info", + "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() level, err := log.ParseLevel(*logLevel) diff --git a/go/pserver/service.go b/go/pserver/service.go index a5c76857ab..7400b48832 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/PaddlePaddle/Paddle/go/utils" + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" log "github.com/sirupsen/logrus" @@ -33,6 +33,9 @@ const ( Float64 ) +// PsDesired is etcd path for store desired pserver count +const PsDesired = "/ps_desired" + // Parameter is a piece of data to sync with the parameter server. type Parameter struct { Name string @@ -68,7 +71,8 @@ type Service struct { externalIP string } -// NewService creates a new service. +// NewService creates a new service, will bypass etcd registration if no +// endpoints specified. 
func NewService(endpoints string, timeout time.Duration) (*Service, error) { s := &Service{opt: newOptimizer(sgd, 0.005)} s.paramMap = make(map[string]Parameter) @@ -77,7 +81,7 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { s.etcdTimeout = timeout var err error - s.externalIP, err = utils.GetExternalIP() + s.externalIP, err = networkhelper.GetExternalIP() if err != nil { return nil, err } @@ -102,67 +106,74 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { // wait and set s.desired init value for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - resp, err := s.etcdClient.Get(ctx, "/ps_desired") + resp, err := s.etcdClient.Get(ctx, PsDesired) cancel() if err != nil { - log.Errorf("getting /ps_desired error: %v", err) + log.Errorf("getting %s error: %v", PsDesired, err) time.Sleep(s.etcdTimeout) continue } - for _, ev := range resp.Kvs { - log.Debugf("key: %s, value: %s", ev.Key, ev.Value) - if string(ev.Key) == "/ps_desired" { - s.desired, err = strconv.Atoi(string(ev.Value)) - if err != nil { - log.Errorf("value of /ps_desired invalid %v\n", err) - time.Sleep(s.etcdTimeout) - // NOTE: wait util ps_desired value change - continue - } + if len(resp.Kvs) != 0 { + s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("value of %s invalid %v\n", PsDesired, err) + time.Sleep(s.etcdTimeout) + // NOTE: wait util ps_desired value change + continue } + break + } + } + // try register pserver node on etcd + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := s.registerPserverEtcd(ctx) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(s.etcdTimeout) + continue } break } - s.registerPserverEtcd() } // if endpoints != "" // Bypass etcd registration if no endpoints specified return s, nil } // registerPserverEtcd registers pserver node on etcd using transaction. 
-func (s *Service) registerPserverEtcd() (*clientv3.TxnResponse, error) { - return concurrency.NewSTMRepeatable(context.TODO(), s.etcdClient, func(c concurrency.STM) error { +func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { + return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { + registered := false for i := 0; i < s.desired; i++ { psKey := "/ps/" + strconv.Itoa(i) log.Debugf("checking %s", psKey) ps := c.Get(psKey) log.Debugf("got value (%s) for key: %s", ps, psKey) - resp, err := s.etcdClient.Grant(context.TODO(), 5) - if err != nil { - log.Fatal(err) - } - if ps == "" { + resp, err := s.etcdClient.Grant(context.TODO(), 5) + if err != nil { + log.Fatal(err) + } // find the first id and write info c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) - ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) + _, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) if kaerr != nil { log.Errorf("keepalive etcd node error: %v", kaerr) return kaerr } - // FIXME: does this really needed? - go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { - ka := <-ch - log.Debugf("keepalive: %d\n", ka.TTL) - }(ch) + log.Debug("register finished") + registered = true break } } - log.Debug("register finished") - return nil - }) + if registered == true { + return nil + } + return errors.New("not registerd, may due to already have enough pservers") + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) } // InitParam initializes a parameter. 
diff --git a/go/utils/helper.go b/go/utils/networkhelper/helper.go similarity index 97% rename from go/utils/helper.go rename to go/utils/networkhelper/helper.go index 3220fd6c78..fbeaea8f5e 100644 --- a/go/utils/helper.go +++ b/go/utils/networkhelper/helper.go @@ -1,4 +1,4 @@ -package utils +package networkhelper import ( "errors" diff --git a/go/utils/helper_test.go b/go/utils/networkhelper/helper_test.go similarity index 87% rename from go/utils/helper_test.go rename to go/utils/networkhelper/helper_test.go index aa7c509768..4208f9e358 100644 --- a/go/utils/helper_test.go +++ b/go/utils/networkhelper/helper_test.go @@ -1,4 +1,4 @@ -package utils +package networkhelper import "testing" From e558ed0d5eade5d2bf6a1bb37beeb39486e9dd76 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Fri, 23 Jun 2017 01:48:04 +0000 Subject: [PATCH 18/35] fix etcd lease I made a comment in WuYi's PR that this is not necessary, so WuYi removed it. Turns out it's necessary after confirming with coreOS developer. --- go/pserver/service.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/go/pserver/service.go b/go/pserver/service.go index 7400b48832..7e2b841dd8 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -159,11 +159,18 @@ func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnRespons // find the first id and write info c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) - _, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) + ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) if kaerr != nil { log.Errorf("keepalive etcd node error: %v", kaerr) return kaerr } + + // Eat the keep alive message so etcd + // will not expire the lease. 
+ go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { + ka := <-ch + log.Debugf("keepalive: %d\n", ka.TTL) + }(ch) log.Debug("register finished") registered = true break From c2fc896f5b2896fc6509e720e7dc08527495927f Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Thu, 22 Jun 2017 19:05:28 -0700 Subject: [PATCH 19/35] Simplify Travis CI configuration --- .travis.yml | 2 -- paddle/scripts/travis/build_and_test.sh | 12 ------------ paddle/scripts/travis/{docs.sh => build_doc.sh} | 13 ++++++++----- .../scripts/travis/{precommit.sh => check_style.sh} | 8 ++++---- paddle/scripts/travis/main.sh | 12 +++++------- 5 files changed, 17 insertions(+), 30 deletions(-) delete mode 100755 paddle/scripts/travis/build_and_test.sh rename paddle/scripts/travis/{docs.sh => build_doc.sh} (84%) rename paddle/scripts/travis/{precommit.sh => check_style.sh} (54%) diff --git a/.travis.yml b/.travis.yml index 87cef10b2b..915c23b7ab 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,6 @@ group: deprecated-2017Q2 language: cpp cache: directories: - - $HOME/third_party - $HOME/.ccache - $HOME/.cache/pip sudo: required @@ -18,7 +17,6 @@ addons: packages: - gcc-4.8 - g++-4.8 - - gfortran-4.8 - git - build-essential - python diff --git a/paddle/scripts/travis/build_and_test.sh b/paddle/scripts/travis/build_and_test.sh deleted file mode 100755 index f2cbc56165..0000000000 --- a/paddle/scripts/travis/build_and_test.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -source ./common.sh - -NPROC=1 -export PYTHONPATH=/opt/python/2.7.12/lib/python2.7/site-packages -export PYTHONHOME=/opt/python/2.7.12 -export PATH=/opt/python/2.7.12/bin:${PATH} -cmake .. 
-DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DON_TRAVIS=ON -DWITH_COVERAGE=ON -DCOVERALLS_UPLOAD=ON ${EXTRA_CMAKE_OPTS} -NRPOC=`nproc` -make -j $NPROC -make coveralls -sudo make install diff --git a/paddle/scripts/travis/docs.sh b/paddle/scripts/travis/build_doc.sh similarity index 84% rename from paddle/scripts/travis/docs.sh rename to paddle/scripts/travis/build_doc.sh index c784293695..88264d8c26 100755 --- a/paddle/scripts/travis/docs.sh +++ b/paddle/scripts/travis/build_doc.sh @@ -1,15 +1,18 @@ #!/bin/bash +set -e + +# Create the build directory for CMake. +mkdir -p $TRAVIS_BUILD_DIR/build +cd $TRAVIS_BUILD_DIR/build -# Add set -e, cd to directory. -source ./common.sh # Compile Documentation only. -cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF ${EXTRA_CMAKE_OPTS} +cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF mkdir output make -j `nproc` find .. -name '*whl' | xargs pip install # install all wheels. rm -rf * -cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=ON ${EXTRA_CMAKE_OPTS} -make paddle_docs paddle_docs_cn +cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON +make -j `nproc` paddle_docs paddle_docs_cn # check websites for broken links linkchecker doc/en/html/index.html diff --git a/paddle/scripts/travis/precommit.sh b/paddle/scripts/travis/check_style.sh similarity index 54% rename from paddle/scripts/travis/precommit.sh rename to paddle/scripts/travis/check_style.sh index 7a59b1131d..4754bdd4c8 100755 --- a/paddle/scripts/travis/precommit.sh +++ b/paddle/scripts/travis/check_style.sh @@ -1,14 +1,14 @@ #!/bin/bash function abort(){ - echo "Your commit not fit PaddlePaddle code style" 1>&2 - echo "Please use pre-commit scripts to auto-format your code" 1>&2 + echo "Your change doesn't follow PaddlePaddle's code style." 
1>&2 + echo "Please use pre-commit to reformat your code and git push again." 1>&2 exit 1 } trap 'abort' 0 set -e -source common.sh -cd .. + +cd $TRAVIS_BUILD_DIR export PATH=/usr/bin:$PATH pre-commit install clang-format --version diff --git a/paddle/scripts/travis/main.sh b/paddle/scripts/travis/main.sh index 13f2552d29..30afe60f60 100755 --- a/paddle/scripts/travis/main.sh +++ b/paddle/scripts/travis/main.sh @@ -1,13 +1,11 @@ #!/bin/bash cd `dirname $0` -if [ ${JOB} == "BUILD_AND_TEST" ]; then - ./build_and_test.sh -elif [ ${JOB} == "DOCS" ]; then - ./docs.sh +if [ ${JOB} == "DOCS" ]; then + ./build_doc.sh elif [ ${JOB} == "PRE_COMMIT" ]; then - ./precommit.sh + ./check_style.sh else - echo Unknown job ${JOB} - exit 1 + echo "Unknown Travis CI job: ${JOB}" + exit 0 # Don't fail due to unknown Travis CI job. fi From 7cf640b58ddeb2cc91d027ade8a6f326d42b5a8d Mon Sep 17 00:00:00 2001 From: Peng Li Date: Fri, 23 Jun 2017 10:26:46 +0800 Subject: [PATCH 20/35] add coeff parameter to classification_cost --- python/paddle/trainer_config_helpers/layers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b8ce0373c0..84ed160773 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -3839,7 +3839,8 @@ def classification_cost(input, weight=None, name=None, evaluator=classification_error_evaluator, - layer_attr=None): + layer_attr=None, + coeff=1.): """ classification cost Layer. @@ -3855,6 +3856,8 @@ def classification_cost(input, :param evaluator: Evaluator method. :param layer_attr: layer's extra attribute. :type layer_attr: ExtraLayerAttribute + :param coeff: The coefficient affects the gradient in the backward. + :type coeff: float :return: LayerOutput object. 
:rtype: LayerOutput """ @@ -3868,6 +3871,7 @@ def classification_cost(input, name=name, type="multi-class-cross-entropy", inputs=ipts, + coeff=coeff, **ExtraLayerAttribute.to_kwargs(layer_attr)) def __add_evaluator__(e): From fba4649bcac265ce720fc8e71f0625f228ad2812 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 23 Jun 2017 10:31:21 +0800 Subject: [PATCH 21/35] Remove `BUILD_AND_TEST` section in travis.yaml --- .travis.yml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/.travis.yml b/.travis.yml index 915c23b7ab..6b4fb4c4b6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,6 @@ os: - linux env: - JOB=DOCS - - JOB=BUILD_AND_TEST - JOB=PRE_COMMIT addons: apt: @@ -33,17 +32,6 @@ addons: - libtool - ccache before_install: - - | - if [ ${JOB} == "BUILD_AND_TEST" ]; then - local change_list=`git diff --name-only $TRAVIS_COMMIT_RANGE` - if [ $? -eq 0 ]; then # if git diff return no zero, then rerun unit test. - if ! echo ${change_list} | grep -qvE '(\.md$)|(\.rst$)|(\.jpg$)|(\.png$)' - then - echo "Only markdown docs were updated, stopping build process." - exit - fi - fi - fi - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # protobuf version. 
From 260416559264c7a8d4dc63cd79619752a862cdf4 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Fri, 23 Jun 2017 10:37:19 +0800 Subject: [PATCH 22/35] "resolve clock skewed" --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 39af60966b..bf227737c5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/ RUN apt-get update && \ apt-get install -y \ git python-pip python-dev openssh-server bison \ - wget unzip tar xz-utils bzip2 gzip coreutils \ + wget unzip tar xz-utils bzip2 gzip coreutils ntp \ curl sed grep graphviz libjpeg-dev zlib1g-dev \ python-numpy python-matplotlib gcc g++ \ automake locales clang-format-3.8 swig doxygen cmake \ From fdde4eff0da95a170f2f727a8345057f20be09ef Mon Sep 17 00:00:00 2001 From: zlx Date: Fri, 23 Jun 2017 12:00:45 +0800 Subject: [PATCH 23/35] modify some topo --- paddle/parameter/ParameterUpdaterHook.cpp | 2 +- python/paddle/trainer_config_helpers/attrs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 968803fc0f..c8b47687f5 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -32,7 +32,7 @@ namespace paddle { /** * The static pruning hook * Static means user specify a sparsity_ratio before training started, and the - * network will prune the parameters based on the sparsity_ratio. More deatils + * network will prune the parameters based on the sparsity_ratio. More details * can be found https://arxiv.org/pdf/1506.02626.pdf. 
*/ diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 66163bdc8d..c02306f394 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -84,7 +84,7 @@ class HookAttribute(object): assert is_compatible_with( self.sparsity_ratio, float), 'sparisity_ratio must be float type' - assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity_ratio must be a float between [0, 1] ' + assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparsity_ratio must be a float between [0, 1] ' def __call__(self): return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) From 72c1a7fb5e2871ba3f6384ea28eaeed10aa5e76a Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Thu, 22 Jun 2017 21:06:07 -0700 Subject: [PATCH 24/35] Remove common.sh --- paddle/scripts/travis/common.sh | 6 ------ 1 file changed, 6 deletions(-) delete mode 100755 paddle/scripts/travis/common.sh diff --git a/paddle/scripts/travis/common.sh b/paddle/scripts/travis/common.sh deleted file mode 100755 index f05c7530a3..0000000000 --- a/paddle/scripts/travis/common.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -e -mkdir -p ../../../build -cd ../../../build -mkdir -p $HOME/third_party -EXTRA_CMAKE_OPTS="-DTHIRD_PARTY_PATH=${HOME}/third_party" From 0cbe120d8c06a1c064293918986264cb320bdb78 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Thu, 22 Jun 2017 21:16:07 -0700 Subject: [PATCH 25/35] Remove paddle/script/travis/main.sh --- .travis.yml | 10 ++++------ paddle/scripts/travis/main.sh | 11 ----------- 2 files changed, 4 insertions(+), 17 deletions(-) delete mode 100755 paddle/scripts/travis/main.sh diff --git a/.travis.yml b/.travis.yml index 6b4fb4c4b6..2c46da71e7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,8 +9,8 @@ dist: trusty os: - linux env: - - JOB=DOCS - - JOB=PRE_COMMIT + - JOB=build_doc + - JOB=check_style addons: apt: packages: @@ -32,7 +32,7 @@ addons: - libtool - 
ccache before_install: - - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi + - if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # protobuf version. - pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker @@ -41,9 +41,7 @@ before_install: - | function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } script: - - | - timeout 2580 paddle/scripts/travis/main.sh # 43min timeout - RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi; + - paddle/scripts/travis/$JOB.sh notifications: email: on_success: change diff --git a/paddle/scripts/travis/main.sh b/paddle/scripts/travis/main.sh deleted file mode 100755 index 30afe60f60..0000000000 --- a/paddle/scripts/travis/main.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -cd `dirname $0` - -if [ ${JOB} == "DOCS" ]; then - ./build_doc.sh -elif [ ${JOB} == "PRE_COMMIT" ]; then - ./check_style.sh -else - echo "Unknown Travis CI job: ${JOB}" - exit 0 # Don't fail due to unknown Travis CI job. 
-fi From 1d6b8595490d0d679a18329eccfa53d8bb285b96 Mon Sep 17 00:00:00 2001 From: zlx Date: Fri, 23 Jun 2017 13:11:31 +0800 Subject: [PATCH 26/35] modity topo --- python/paddle/trainer_config_helpers/attrs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index c02306f394..9b9f979bb6 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -68,7 +68,7 @@ class HookAttribute(object): The specific usage can be paddle.layer.img_conv(input=img, filter_size=3, num_channels=3, num_filters=64, param_attr=ParameterAttribute(update_hooks=hk) ) - The pruning deatils can be found https://arxiv.org/pdf/1506.02626.pdf + The pruning details can be found https://arxiv.org/pdf/1506.02626.pdf :type type: string :param sparsity_ratio: Must be specified if hook type is 'pruning', From c89fe83a775b0c8264f00de589d263fc6faec615 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 23 Jun 2017 16:32:05 +0800 Subject: [PATCH 27/35] Fix the problem that protobuf cannot be used as a DEPS argument in cc_library. --- cmake/external/protobuf.cmake | 61 ++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 7340394b1e..ce32b2531e 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -14,11 +14,41 @@ INCLUDE(ExternalProject) +# Print and set the protobuf library information, +# finish this cmake process and exit from this file. macro(PROMPT_PROTOBUF_LIB) + SET(protobuf_DEPS ${ARGN}) + MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}") MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}") MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}") INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) + + # Assuming that all the protobuf libraries are of the same type. 
+ IF(${PROTOBUF_LIBRARY} MATCHES "${STATIC_LIBRARY_SUFFIX}$") + SET(protobuf_LIBTYPE STATIC) + ELSEIF(${PROTOBUF_LIBRARY} MATCHES "${DYNAMIC_LIBRARY_SUFFIX}$") + SET(protobuf_LIBTYPE SHARED) + ELSE() + MESSAGE(FATAL_ERROR "Unknown library type: ${PROTOBUF_LIBRARY}") + ENDIF() + + ADD_LIBRARY(protobuf ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protobuf PROPERTY IMPORTED_LOCATION ${PROTOBUF_LIBRARY}) + + ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) + + ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) + + FOREACH(dep ${protobuf_DEPS}) + ADD_DEPENDENCIES(protobuf ${dep}) + ADD_DEPENDENCIES(protobuf_lite ${dep}) + ADD_DEPENDENCIES(protoc ${dep}) + ENDFOREACH() + + LIST(APPEND external_project_dependencies protobuf) RETURN() endmacro() macro(SET_PROTOBUF_VERSION) @@ -43,8 +73,9 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "") endif() FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) - SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_NAME}) - SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_NAME}) + STRING(REPLACE "extern_" "" TARGET_DIR_NAME "${TARGET_NAME}") + SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_DIR_NAME}) + SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_DIR_NAME}) SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) @@ -109,6 +140,8 @@ IF(NOT CMAKE_CROSSCOMPILING) SET_PROTOBUF_VERSION() IF("${PROTOBUF_VERSION}" VERSION_LESS "3.1.0") SET(PROTOBUF_FOUND OFF) + ELSE() + PROMPT_PROTOBUF_LIB() ENDIF() ENDIF(PROTOBUF_FOUND) ELSE() @@ -120,18 +153,22 @@ ELSE() ENDIF() IF(NOT PROTOBUF_FOUND) - build_protobuf(protobuf FALSE) - LIST(APPEND external_project_dependencies protobuf) + build_protobuf(extern_protobuf FALSE) - 
SET(PROTOBUF_INCLUDE_DIR ${protobuf_INCLUDE_DIR} + SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR} CACHE PATH "protobuf include directory." FORCE) - IF(NOT CMAKE_CROSSCOMPILING) - SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_PROTOC_EXECUTABLE} + SET(PROTOBUF_LITE_LIBRARY ${extern_protobuf_LITE_LIBRARY} + CACHE FILEPATH "protobuf lite library." FORCE) + SET(PROTOBUF_LIBRARY ${extern_protobuf_LIBRARY} + CACHE FILEPATH "protobuf library." FORCE) + SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY} + CACHE FILEPATH "protoc library." FORCE) + + IF(CMAKE_CROSSCOMPILING) + PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf) + ELSE() + SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE} CACHE FILEPATH "protobuf executable." FORCE) + PROMPT_PROTOBUF_LIB(extern_protobuf) ENDIF() - SET(PROTOBUF_LITE_LIBRARY ${protobuf_LITE_LIBRARY} CACHE FILEPATH "protobuf lite library." FORCE) - SET(PROTOBUF_LIBRARY ${protobuf_LIBRARY} CACHE FILEPATH "protobuf library." FORCE) - SET(PROTOBUF_PROTOC_LIBRARY ${protobuf_PROTOC_LIBRARY} CACHE FILEPATH "protoc library." FORCE) ENDIF(NOT PROTOBUF_FOUND) - -PROMPT_PROTOBUF_LIB() \ No newline at end of file From 16f8508d74bd7d40776ad442a927f89d17960d6b Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 23 Jun 2017 17:46:32 +0800 Subject: [PATCH 28/35] Use CMake system variables, such as CMAKE_STATIC_LIBRARY_PREFIX/SUFFIX, instead. --- cmake/external/openblas.cmake | 3 ++- cmake/external/protobuf.cmake | 12 ++++++------ cmake/system.cmake | 18 ------------------ 3 files changed, 8 insertions(+), 25 deletions(-) diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 2341e3785b..5b9d9844ed 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -21,7 +21,8 @@ IF(NOT ${CBLAS_FOUND}) SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." 
FORCE) - SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${LIBRARY_PREFIX}openblas${STATIC_LIBRARY_SUFFIX}" + SET(CBLAS_LIBRARIES + "${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "openblas library." FORCE) SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs) diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index ce32b2531e..d43badc1da 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -25,9 +25,9 @@ macro(PROMPT_PROTOBUF_LIB) INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) # Assuming that all the protobuf libraries are of the same type. - IF(${PROTOBUF_LIBRARY} MATCHES "${STATIC_LIBRARY_SUFFIX}$") + IF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$") SET(protobuf_LIBTYPE STATIC) - ELSEIF(${PROTOBUF_LIBRARY} MATCHES "${DYNAMIC_LIBRARY_SUFFIX}$") + ELSEIF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}$") SET(protobuf_LIBTYPE SHARED) ELSE() MESSAGE(FATAL_ERROR "Unknown library type: ${PROTOBUF_LIBRARY}") @@ -80,16 +80,16 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(${TARGET_NAME}_LITE_LIBRARY - "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${STATIC_LIBRARY_SUFFIX}" + "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${CMAKE_STATIC_LIBRARY_SUFFIX}" PARENT_SCOPE) SET(${TARGET_NAME}_LIBRARY - "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${STATIC_LIBRARY_SUFFIX}" + "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}" PARENT_SCOPE) SET(${TARGET_NAME}_PROTOC_LIBRARY - "${PROTOBUF_INSTALL_DIR}/lib/libprotoc${STATIC_LIBRARY_SUFFIX}" + "${PROTOBUF_INSTALL_DIR}/lib/libprotoc${CMAKE_STATIC_LIBRARY_SUFFIX}" PARENT_SCOPE) SET(${TARGET_NAME}_PROTOC_EXECUTABLE - "${PROTOBUF_INSTALL_DIR}/bin/protoc${EXECUTABLE_SUFFIX}" + 
"${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}" PARENT_SCOPE) SET(OPTIONAL_CACHE_ARGS "") diff --git a/cmake/system.cmake b/cmake/system.cmake index 904652413e..3b5cbfdd63 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -84,24 +84,6 @@ IF(DEFINED CMAKE_SYSTEM_NAME) ENDIF() ENDIF() -# prefix and suffix on different os -IF(WIN32) - SET(LIBRARY_PREFIX "") - SET(SHARED_LIBRARY_SUFFIX ".dll") - SET(STATIC_LIBRARY_SUFFIX ".lib") - SET(EXECUTABLE_SUFFIX ".exe") -ELSE(WIN32) - SET(LIBRARY_PREFIX "lib") - IF(APPLE) - SET(SHARED_LIBRARY_SUFFIX ".dylib") - ELSE(APPLE) - SET(SHARED_LIBRARY_SUFFIX ".so") - ENDIF(APPLE) - - SET(STATIC_LIBRARY_SUFFIX ".a") - SET(EXECUTABLE_SUFFIX "") -ENDIF(WIN32) - # external dependencies log output SET(EXTERNAL_PROJECT_LOG_ARGS LOG_DOWNLOAD 0 # Wrap download in script to log output From fd8937556f95db4086ce095efa1e83041c896334 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 20 Jun 2017 23:57:07 +0000 Subject: [PATCH 29/35] Master save and load state from etcd --- go/cmd/master/master.go | 55 ++++++++++-- go/master/client_internal_test.go | 21 ++++- go/master/client_test.go | 21 ++++- go/master/etcd_store.go | 133 ++++++++++++++++++++++++++++ go/master/service.go | 142 +++++++++++++++++++++++------- go/pserver/cclient/cclient.go | 6 +- 6 files changed, 330 insertions(+), 48 deletions(-) create mode 100644 go/master/etcd_store.go diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 25cd1cafcd..49ad0300b8 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -5,41 +5,80 @@ import ( "net/http" "net/rpc" "strconv" + "strings" + "sync" "time" "github.com/namsral/flag" + log "github.com/sirupsen/logrus" "github.com/PaddlePaddle/Paddle/go/master" ) +type inMemStore struct { + mu sync.Mutex + buf []byte +} + +func (m *inMemStore) Save(b []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.buf = b + return nil +} + +func (m *inMemStore) Load() ([]byte, error) { + m.mu.Lock() + defer 
m.mu.Unlock() + + return m.buf, nil +} + func main() { port := flag.Int("port", 8080, "port of the master server.") - faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") + ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") + endpoints := flag.String("endpoints", "", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") flag.Parse() - if *faultTolerance { - panic("fault tolernance not implemented.") + if *endpoints == "" { + log.Warningln("-endpoints not set, fault tolerance not be enabled.") + } + + var store master.Store + if *endpoints != "" { + eps := strings.Split(*endpoints, ",") + var err error + store, err = master.NewEtcdStore(eps, master.DefaultLockPath, master.DefaultStatePath, *ttlSec) + if err != nil { + log.Fatal(err) + } + } else { + store = &inMemStore{} + } + s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) + if err != nil { + log.Fatal(err) } - s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) - err := rpc.Register(s) + err = rpc.Register(s) if err != nil { - panic(err) + log.Fatal(err) } rpc.HandleHTTP() l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) if err != nil { - panic(err) + log.Fatal(err) } err = http.Serve(l, nil) if err != nil { - panic(err) + log.Fatal(err) } } diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index 00fcca0e2c..a5b76fe853 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -32,6 +32,19 @@ func (a TestAddresser) Address() string { return string(a) } +type myStore struct { + buf []byte +} + +func (m *myStore) Save(b []byte) error { + 
m.buf = b + return nil +} + +func (m *myStore) Load() ([]byte, error) { + return m.buf, nil +} + func TestGetFinishTask(t *testing.T) { const path = "/tmp/master_client_test_0" @@ -47,9 +60,13 @@ func TestGetFinishTask(t *testing.T) { } go func(l net.Listener) { - s := NewService(chunkPerTask, time.Second, 1) + s, err := NewService(&myStore{}, chunkPerTask, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/client_test.go b/go/master/client_test.go index 2b3f873ecf..ae5f17c2d4 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -15,6 +15,19 @@ import ( "github.com/PaddlePaddle/recordio" ) +type myStore struct { + buf []byte +} + +func (m *myStore) Save(b []byte) error { + m.buf = b + return nil +} + +func (m *myStore) Load() ([]byte, error) { + return m.buf, nil +} + func TestNextRecord(t *testing.T) { const ( path = "/tmp/master_client_TestFull" @@ -33,9 +46,13 @@ func TestNextRecord(t *testing.T) { } go func(l net.Listener) { - s := master.NewService(10, time.Second, 1) + s, err := master.NewService(&myStore{}, 10, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/etcd_store.go b/go/master/etcd_store.go new file mode 100644 index 0000000000..ce178370ff --- /dev/null +++ b/go/master/etcd_store.go @@ -0,0 +1,133 @@ +package master + +import ( + "context" + "sync" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" +) + +const ( + // DefaultLockPath is the default etcd master lock path. + DefaultLockPath = "/master/lock" + // DefaultStatePath is the default etcd key for master state. + DefaultStatePath = "/master/state" +) + +// EtcdStore is the Store implementation backed by etcd. 
+type EtcdStore struct { + lockPath string + statePath string + ttlSec int + client *clientv3.Client + + mu sync.Mutex + lock *concurrency.Mutex +} + +// NewEtcdStore creates a new EtcdStore. +func NewEtcdStore(endpoints []string, lockPath, statePath string, ttlSec int) (*EtcdStore, error) { + cli, err := clientv3.New(clientv3.Config{ + Endpoints: endpoints, + DialTimeout: dialTimeout, + }) + if err != nil { + return nil, err + } + + sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec)) + if err != nil { + return nil, err + } + + lock := concurrency.NewMutex(sess, lockPath) + // It's fine for the lock to get stuck, in this case we have + // multiple master servers running (only configured to have + // one master running, but split-brain problem may cuase + // multiple master servers running), and the cluster management + // software will kill one of them. + log.Infof("Trying to acquire lock at %s.", lockPath) + err = lock.Lock(context.TODO()) + if err != nil { + return nil, err + } + log.Infof("Successfully acquired lock at %s.", lockPath) + + e := &EtcdStore{} + e.client = cli + e.lock = lock + e.lockPath = lockPath + e.statePath = statePath + e.ttlSec = ttlSec + return e, nil +} + +// Save saves the state into the etcd. 
+func (e *EtcdStore) Save(state []byte) error { + e.mu.Lock() + defer e.mu.Unlock() + + ctx := context.TODO() + put := clientv3.OpPut(e.statePath, string(state)) + resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit() + if err != nil { + return err + } + + if !resp.Succeeded { + log.Errorln("No longer owns the lock, trying to lock and save again.") + sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(e.ttlSec)) + if err != nil { + return err + } + + e.lock = concurrency.NewMutex(sess, e.lockPath) + log.Infof("Try to acquire lock at %s.", e.lockPath) + err = e.lock.Lock(context.TODO()) + if err != nil { + return err + } + log.Infof("Successfully acquired lock at %s.", e.lockPath) + return e.Save(state) + } + + return nil +} + +// Load loads the state from etcd. +func (e *EtcdStore) Load() ([]byte, error) { + e.mu.Lock() + ctx := context.TODO() + get := clientv3.OpGet(e.statePath) + + resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit() + if err != nil { + return nil, err + } + + if !resp.Succeeded { + log.Errorln("No longer owns the lock, trying to lock and load again.") + sess, err := concurrency.NewSession(e.client) + if err != nil { + return nil, err + } + + e.lock = concurrency.NewMutex(sess, e.lockPath) + e.lock.Lock(context.TODO()) + e.mu.Unlock() + return e.Load() + } + + kvs := resp.Responses[0].GetResponseRange().Kvs + if len(kvs) == 0 { + // No state exists + e.mu.Unlock() + return nil, nil + } + + state := kvs[0].Value + e.mu.Unlock() + return state, nil +} diff --git a/go/master/service.go b/go/master/service.go index 55e1e2d1a4..d453777b05 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -1,6 +1,9 @@ package master import ( + "bytes" + "compress/gzip" + "encoding/gob" "errors" "os" "path/filepath" @@ -12,24 +15,54 @@ import ( "github.com/PaddlePaddle/recordio" ) +const ( + dialTimeout = 5 * time.Second +) + +// Store is the interface for save and load the master state. 
+type Store interface { + Save([]byte) error + Load() ([]byte, error) +} + +// Chunk is a chunk of data consisted of several data instances. +type Chunk struct { + Path string + Index recordio.Index // chunk index +} + +// Task is the basic unit of data instances assigned to trainers. +type Task struct { + ID int + Chunks []Chunk +} + +type taskEntry struct { + Epoch int + NumTimeout int + Task Task +} + +type taskQueues struct { + Todo []taskEntry + Pending map[int]taskEntry // map from task ID to task entry + Done []taskEntry + Failed []Task +} + // Service is the master server service. type Service struct { chunksPerTask int timeoutDur time.Duration timeoutMax int ready chan struct{} + store Store mu sync.Mutex initDone bool taskQueues taskQueues } -// Recover recovers service state from etcd. -func Recover() (*Service, error) { - // TODO(helin): recover from snapshot state from etcd. - return nil, nil -} - func partition(chunks []Chunk, chunksPerTask int) []taskEntry { id := 0 if chunksPerTask <= 0 { @@ -58,7 +91,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { } // NewService creates a new service. -func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { +func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { s := &Service{} s.chunksPerTask = chunksPerTask s.timeoutDur = timeoutDur @@ -66,38 +99,81 @@ func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Se s.taskQueues = taskQueues{} s.taskQueues.Pending = make(map[int]taskEntry) s.ready = make(chan struct{}) - return s -} + s.store = store + recovered, err := s.recover() + if err != nil { + return nil, err + } -// Chunk is a chunk of data consisted of several data instances. -type Chunk struct { - Path string - Index recordio.Index // chunk index -} + if recovered { + // Recovered. Now the state is already initialized, + // and the master is ready. 
+ s.initDone = true + close(s.ready) + } -// Task is the basic unit of data instances assigned to trainers. -type Task struct { - ID int - Chunks []Chunk + return s, nil } -type taskEntry struct { - Epoch int - NumTimeout int - Task Task -} +// recover recovers service state from etcd. +func (s *Service) recover() (bool, error) { + state, err := s.store.Load() + if err != nil { + return false, err + } -type taskQueues struct { - Todo []taskEntry - Pending map[int]taskEntry // map from task ID to task entry - Done []taskEntry - Failed []Task + if state == nil { + log.Infoln("No state exists, not recovered.") + return false, nil + } + + log.Infof("Loaded snapshot of size: %d bytes.", len(state)) + gr, err := gzip.NewReader(bytes.NewReader(state)) + if err != nil { + return false, err + } + + dec := gob.NewDecoder(gr) + var tqs taskQueues + err = dec.Decode(&tqs) + if err != nil { + return false, err + } + + err = gr.Close() + if err != nil { + // Only close failed, recover actually succeed, so + // just log error. + log.Errorln(err) + } + + s.taskQueues = tqs + return true, nil } -// *must* be called with s.mu being held. +// snapshot *must* be called with s.mu being held. func (s *Service) snapshot() error { - // TODO(helin): snapshot state on etcd. - return nil + // TOOD(helin): etcd request has a size limit, so the snapshot + // size is limited by the max request size. 
We should either + // divide the snapshot into smaller chunks and save under + // different keys, or configure the request size to be big + // enough: + // https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44 + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + enc := gob.NewEncoder(gw) + err := enc.Encode(s.taskQueues) + if err != nil { + return err + } + err = gw.Close() + if err != nil { + return err + } + + state := buf.Bytes() + log.Infof("Saving snapshot of size: %d bytes.", len(state)) + return s.store.Save(state) } func readChunks(globPaths []string) ([]Chunk, error) { @@ -207,12 +283,12 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { t.NumTimeout++ if t.NumTimeout > s.timeoutMax { - log.Warningf("Task %v timed out %d times, discard.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout) s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) return } - log.Warningf("Task %v timed out %d times, retry.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout) s.taskQueues.Todo = append(s.taskQueues.Todo, t) } } diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 92a41b7f54..bbaf43d9f1 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -133,7 +133,7 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, if err != nil { if err.Error() == pserver.AlreadyInitialized { - log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) + log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name) return C.PSERVER_OK } log.Errorln(err) @@ -200,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong number of parameters. 
Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } @@ -210,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } } From 44226853029119e195530e78ff7d0ab883b72dff Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 21 Jun 2017 18:55:49 +0000 Subject: [PATCH 30/35] put InMemStore into master package --- go/cmd/master/master.go | 23 +---------------------- go/master/client_internal_test.go | 15 +-------------- go/master/client_test.go | 15 +-------------- go/master/inmem_store.go | 28 ++++++++++++++++++++++++++++ 4 files changed, 31 insertions(+), 50 deletions(-) create mode 100644 go/master/inmem_store.go diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 49ad0300b8..48fe2e6f75 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -6,7 +6,6 @@ import ( "net/rpc" "strconv" "strings" - "sync" "time" "github.com/namsral/flag" @@ -15,26 +14,6 @@ import ( "github.com/PaddlePaddle/Paddle/go/master" ) -type inMemStore struct { - mu sync.Mutex - buf []byte -} - -func (m *inMemStore) Save(b []byte) error { - m.mu.Lock() - defer m.mu.Unlock() - - m.buf = b - return nil -} - -func (m *inMemStore) Load() ([]byte, error) { - m.mu.Lock() - defer m.mu.Unlock() - - return m.buf, nil -} - func main() { port := flag.Int("port", 8080, "port of the master server.") @@ -58,7 +37,7 @@ func main() { log.Fatal(err) } } else { - store = &inMemStore{} + store = 
&master.InMemStore{} } s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index a5b76fe853..251225780a 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -32,19 +32,6 @@ func (a TestAddresser) Address() string { return string(a) } -type myStore struct { - buf []byte -} - -func (m *myStore) Save(b []byte) error { - m.buf = b - return nil -} - -func (m *myStore) Load() ([]byte, error) { - return m.buf, nil -} - func TestGetFinishTask(t *testing.T) { const path = "/tmp/master_client_test_0" @@ -60,7 +47,7 @@ func TestGetFinishTask(t *testing.T) { } go func(l net.Listener) { - s, err := NewService(&myStore{}, chunkPerTask, time.Second, 1) + s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) if err != nil { panic(err) } diff --git a/go/master/client_test.go b/go/master/client_test.go index ae5f17c2d4..85a86761c2 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -15,19 +15,6 @@ import ( "github.com/PaddlePaddle/recordio" ) -type myStore struct { - buf []byte -} - -func (m *myStore) Save(b []byte) error { - m.buf = b - return nil -} - -func (m *myStore) Load() ([]byte, error) { - return m.buf, nil -} - func TestNextRecord(t *testing.T) { const ( path = "/tmp/master_client_TestFull" @@ -46,7 +33,7 @@ func TestNextRecord(t *testing.T) { } go func(l net.Listener) { - s, err := master.NewService(&myStore{}, 10, time.Second, 1) + s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) if err != nil { panic(err) } diff --git a/go/master/inmem_store.go b/go/master/inmem_store.go new file mode 100644 index 0000000000..bcd549b20e --- /dev/null +++ b/go/master/inmem_store.go @@ -0,0 +1,28 @@ +package master + +import "sync" + +// InMemStore is an in memory implementation of Store interface. +// +// It does not tolerate the fault that casues the program to crash. 
+type InMemStore struct { + mu sync.Mutex + buf []byte +} + +// Save saves the state into the in-memory store. +func (m *InMemStore) Save(state []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.buf = state + return nil +} + +// Load loads the state from the in-memory store. +func (m *InMemStore) Load() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.buf, nil +} From a4ba403e792fc21b5e032ad6116f1fc00fb4ba8d Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 21 Jun 2017 19:00:25 +0000 Subject: [PATCH 31/35] add comment for gracefully stop etcd store --- go/master/etcd_store.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/go/master/etcd_store.go b/go/master/etcd_store.go index ce178370ff..d8e95056d5 100644 --- a/go/master/etcd_store.go +++ b/go/master/etcd_store.go @@ -29,6 +29,10 @@ type EtcdStore struct { // NewEtcdStore creates a new EtcdStore. func NewEtcdStore(endpoints []string, lockPath, statePath string, ttlSec int) (*EtcdStore, error) { + // TODO(helin): gracefully shutdown etcd store. Becuase etcd + // store holds a etcd lock, even though the lock will expire + // when the lease timeout, we need to implement graceful + // shutdown to release the lock. cli, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, DialTimeout: dialTimeout, From bf79c9e5bba41dd9f1e122a779e27e3e8dca9ee3 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 21 Jun 2017 19:02:21 +0000 Subject: [PATCH 32/35] add log when master recovered from saved state. --- go/master/service.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/master/service.go b/go/master/service.go index d453777b05..58e68e7448 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -110,6 +110,7 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeou // and the master is ready. 
s.initDone = true close(s.ready) + log.Info("Master recovered from saved state.") } return s, nil From 42313a3c35637b8d706aa4dbdef65c671e7d6665 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Fri, 23 Jun 2017 22:11:45 +0000 Subject: [PATCH 33/35] rename EtcdStore to Etcd --- go/cmd/master/master.go | 2 +- go/master/etcd_store.go | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 48fe2e6f75..a62bc4310e 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -32,7 +32,7 @@ func main() { if *endpoints != "" { eps := strings.Split(*endpoints, ",") var err error - store, err = master.NewEtcdStore(eps, master.DefaultLockPath, master.DefaultStatePath, *ttlSec) + store, err = master.NewEtcd(eps, master.DefaultLockPath, master.DefaultStatePath, *ttlSec) if err != nil { log.Fatal(err) } diff --git a/go/master/etcd_store.go b/go/master/etcd_store.go index d8e95056d5..21b3e2cb0f 100644 --- a/go/master/etcd_store.go +++ b/go/master/etcd_store.go @@ -16,8 +16,9 @@ const ( DefaultStatePath = "/master/state" ) -// EtcdStore is the Store implementation backed by etcd. -type EtcdStore struct { +// Etcd is the etcd abstraction that master uses for fault tolerance +// and service registry. +type Etcd struct { lockPath string statePath string ttlSec int @@ -27,8 +28,8 @@ type EtcdStore struct { lock *concurrency.Mutex } -// NewEtcdStore creates a new EtcdStore. -func NewEtcdStore(endpoints []string, lockPath, statePath string, ttlSec int) (*EtcdStore, error) { +// NewEtcd creates a new Etcd. +func NewEtcd(endpoints []string, lockPath, statePath string, ttlSec int) (*Etcd, error) { // TODO(helin): gracefully shutdown etcd store. 
Becuase etcd // store holds a etcd lock, even though the lock will expire // when the lease timeout, we need to implement graceful @@ -59,7 +60,7 @@ func NewEtcdStore(endpoints []string, lockPath, statePath string, ttlSec int) (* } log.Infof("Successfully acquired lock at %s.", lockPath) - e := &EtcdStore{} + e := &Etcd{} e.client = cli e.lock = lock e.lockPath = lockPath @@ -69,7 +70,7 @@ func NewEtcdStore(endpoints []string, lockPath, statePath string, ttlSec int) (* } // Save saves the state into the etcd. -func (e *EtcdStore) Save(state []byte) error { +func (e *Etcd) Save(state []byte) error { e.mu.Lock() defer e.mu.Unlock() @@ -101,7 +102,7 @@ func (e *EtcdStore) Save(state []byte) error { } // Load loads the state from etcd. -func (e *EtcdStore) Load() ([]byte, error) { +func (e *Etcd) Load() ([]byte, error) { e.mu.Lock() ctx := context.TODO() get := clientv3.OpGet(e.statePath) @@ -119,8 +120,12 @@ func (e *EtcdStore) Load() ([]byte, error) { } e.lock = concurrency.NewMutex(sess, e.lockPath) - e.lock.Lock(context.TODO()) + err = e.lock.Lock(context.TODO()) e.mu.Unlock() + if err != nil { + return nil, err + } + return e.Load() } From 7dad02661f1cd7406eac871354c94cebf4d38345 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Sat, 24 Jun 2017 00:04:26 +0000 Subject: [PATCH 34/35] Master server registers itself to etcd. 
--- go/cmd/master/master.go | 14 +++- go/master/{etcd_store.go => etcd_client.go} | 90 +++++++++++---------- 2 files changed, 56 insertions(+), 48 deletions(-) rename go/master/{etcd_store.go => etcd_client.go} (56%) diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index a62bc4310e..54fa254863 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "net" "net/http" "net/rpc" @@ -12,13 +13,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/PaddlePaddle/Paddle/go/master" + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" ) func main() { port := flag.Int("port", 8080, "port of the master server.") - ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") - endpoints := flag.String("endpoints", "", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") + endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") @@ -31,8 +32,13 @@ func main() { var store master.Store if *endpoints != "" { eps := strings.Split(*endpoints, ",") - var err error - store, err = master.NewEtcd(eps, master.DefaultLockPath, master.DefaultStatePath, *ttlSec) + ip, err := networkhelper.GetExternalIP() + if err != nil { + log.Fatal(err) + } + + addr := fmt.Sprintf("%s:%d", ip, *port) + store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) if err != nil { log.Fatal(err) } diff --git a/go/master/etcd_store.go b/go/master/etcd_client.go similarity index 56% rename from go/master/etcd_store.go rename to go/master/etcd_client.go index 
21b3e2cb0f..b7293a7598 100644 --- a/go/master/etcd_store.go +++ b/go/master/etcd_client.go @@ -2,7 +2,7 @@ package master import ( "context" - "sync" + "time" "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" @@ -14,22 +14,22 @@ const ( DefaultLockPath = "/master/lock" // DefaultStatePath is the default etcd key for master state. DefaultStatePath = "/master/state" + // DefaultAddrPath is the default etcd key for master address. + DefaultAddrPath = "/master/addr" ) -// Etcd is the etcd abstraction that master uses for fault tolerance +// EtcdClient is the etcd client that master uses for fault tolerance // and service registry. -type Etcd struct { +type EtcdClient struct { lockPath string statePath string - ttlSec int client *clientv3.Client - - mu sync.Mutex - lock *concurrency.Mutex + lock *concurrency.Mutex } -// NewEtcd creates a new Etcd. -func NewEtcd(endpoints []string, lockPath, statePath string, ttlSec int) (*Etcd, error) { +// NewEtcdClient creates a new EtcdClient. +func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { + log.Debugf("Connecting to etcd at %v", endpoints) // TODO(helin): gracefully shutdown etcd store. Becuase etcd // store holds a etcd lock, even though the lock will expire // when the lease timeout, we need to implement graceful @@ -53,27 +53,35 @@ func NewEtcd(endpoints []string, lockPath, statePath string, ttlSec int) (*Etcd, // one master running, but split-brain problem may cuase // multiple master servers running), and the cluster management // software will kill one of them. 
- log.Infof("Trying to acquire lock at %s.", lockPath) + log.Debugf("Trying to acquire lock at %s.", lockPath) err = lock.Lock(context.TODO()) if err != nil { return nil, err } - log.Infof("Successfully acquired lock at %s.", lockPath) - - e := &Etcd{} - e.client = cli - e.lock = lock - e.lockPath = lockPath - e.statePath = statePath - e.ttlSec = ttlSec + log.Debugf("Successfully acquired lock at %s.", lockPath) + + put := clientv3.OpPut(addrPath, string(addr)) + resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() + if err != nil { + return nil, err + } + + if !resp.Succeeded { + log.Fatal("No longer owns the master lock. Exiting.") + } + + e := &EtcdClient{ + lockPath: lockPath, + statePath: statePath, + client: cli, + lock: lock, + } + return e, nil } // Save saves the state into the etcd. -func (e *Etcd) Save(state []byte) error { - e.mu.Lock() - defer e.mu.Unlock() - +func (e *EtcdClient) Save(state []byte) error { ctx := context.TODO() put := clientv3.OpPut(e.statePath, string(state)) resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit() @@ -82,17 +90,21 @@ func (e *Etcd) Save(state []byte) error { } if !resp.Succeeded { - log.Errorln("No longer owns the lock, trying to lock and save again.") - sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(e.ttlSec)) - if err != nil { - return err - } - - e.lock = concurrency.NewMutex(sess, e.lockPath) - log.Infof("Try to acquire lock at %s.", e.lockPath) - err = e.lock.Lock(context.TODO()) + log.Errorln("No longer owns the lock, trying to lock again") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err := e.lock.Lock(ctx) + cancel() if err != nil { - return err + // We lost the master lock and can not acquire + // it back, it means some other master is + // already started. We don't want cluster + // managment system to kill the master server + // who is holding the lock and running + // correctly. 
So the most feasible solution is + // to kill current master server. The current + // state is not saved, but the trainer's RPC + // call will fail, so the trainer will retry. + log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err) } log.Infof("Successfully acquired lock at %s.", e.lockPath) return e.Save(state) @@ -102,8 +114,7 @@ func (e *Etcd) Save(state []byte) error { } // Load loads the state from etcd. -func (e *Etcd) Load() ([]byte, error) { - e.mu.Lock() +func (e *EtcdClient) Load() ([]byte, error) { ctx := context.TODO() get := clientv3.OpGet(e.statePath) @@ -114,14 +125,7 @@ func (e *Etcd) Load() ([]byte, error) { if !resp.Succeeded { log.Errorln("No longer owns the lock, trying to lock and load again.") - sess, err := concurrency.NewSession(e.client) - if err != nil { - return nil, err - } - - e.lock = concurrency.NewMutex(sess, e.lockPath) - err = e.lock.Lock(context.TODO()) - e.mu.Unlock() + err = e.lock.Lock(context.Background()) if err != nil { return nil, err } @@ -132,11 +136,9 @@ func (e *Etcd) Load() ([]byte, error) { kvs := resp.Responses[0].GetResponseRange().Kvs if len(kvs) == 0 { // No state exists - e.mu.Unlock() return nil, nil } state := kvs[0].Value - e.mu.Unlock() return state, nil } From a7865a37768b8d320378c50b517ddd1fdf6db934 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 24 Jun 2017 17:01:48 +0800 Subject: [PATCH 35/35] Fix macos compile Please use `override` not `virtual` in sub-classes. `override` can check if there is a method in `parent` while compiling. 
--- .../gserver/gradientmachines/NeuralNetwork.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index 514c0759e1..2e839f6405 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -309,35 +309,35 @@ public: void addEvaluator(std::unique_ptr&& evaluator) { evaluators_.emplace_back(std::move(evaluator)); } - virtual void start() { + void start() override { for (auto& evaluator : evaluators_) { evaluator->start(); } } - virtual void finish() { + void finish() override { for (auto& evaluator : evaluators_) { evaluator->finish(); } } - virtual void eval(const NeuralNetwork& nn) override { + void eval(const NeuralNetwork& nn) override { for (auto& evaluator : evaluators_) { evaluator->eval(nn); } } - virtual real evalImp(std::vector& arguments) { + real evalImp(std::vector& arguments) override { (void)arguments; return -1; } - virtual void printStats(std::ostream& os) const { + void printStats(std::ostream& os) const override { for (auto& evaluator : evaluators_) { evaluator->printStats(os); os << ' '; } } - virtual void distributeEval(ParameterClient2* client) { + void distributeEval(ParameterClient2* client) override { for (auto& evaluator : evaluators_) { evaluator->distributeEval(client); } @@ -352,7 +352,7 @@ public: * @brief getNames will return all inside evaluators' names. * @param names [out]: return names. */ - void getNames(std::vector* names) { + void getNames(std::vector* names) override { for (auto& eval : evaluators_) { eval->getNames(names); } @@ -361,7 +361,7 @@ public: /** * @brief getValue could get all inside evaluators' value. 
*/ - real getValue(const std::string& name, Error* err) const { + real getValue(const std::string& name, Error* err) const override { return this->getMethodHelper( name, err, [&name, err](const std::unique_ptr& eval) { return eval->getValue(name, err); @@ -371,7 +371,7 @@ public: /** * @brief getType could get all inside evaluators' type. */ - std::string getType(const std::string& name, Error* err) const { + std::string getType(const std::string& name, Error* err) const override { return this->getMethodHelper( name, err, [&name, err](const std::unique_ptr& eval) { return eval->getType(name, err);