From 55684af208071bd788381946ac76c9da2b5b7329 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 25 Jun 2017 13:13:10 +0800 Subject: [PATCH 1/2] fix MultiGradientMachine train and infer --- .../gradientmachines/MultiGradientMachine.cpp | 12 ++++++------ .../gserver/gradientmachines/MultiGradientMachine.h | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 3159026e6b..9abda18d54 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -171,6 +171,12 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, } } +MultiGradientMachine::~MultiGradientMachine() { + for (auto& thread : threads_) { + thread->stop(); + } +} + std::vector*> MultiGradientMachine::getSlaveParameters() { std::vector*> vec; @@ -326,12 +332,6 @@ void MultiGradientMachine::onPassEnd() { } } -void MultiGradientMachine::finish() { - for (auto& thread : threads_) { - thread->stop(); - } -} - Evaluator* MultiGradientMachine::makeEvaluator() const { return threads_[0]->getGradientMachine()->makeEvaluator(); } diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h index 70203bbb97..c005c0ed67 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.h +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h @@ -176,6 +176,8 @@ public: explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); + virtual ~MultiGradientMachine(); + virtual void prefetch(const std::vector& inArgs); virtual void forward(const std::vector& inArgs, @@ -193,8 +195,6 @@ public: virtual void onPassEnd(); - virtual void finish(); - virtual Evaluator* makeEvaluator() const; virtual void eval(Evaluator* evaluator) const; From 9f05a0f80225bf4f630817c413b82b23d7579091 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 26 Jun 2017 14:22:18 +0800 Subject: [PATCH 2/2] use GradientMachine::start and finish --- .../gradientmachines/MultiGradientMachine.cpp | 12 ++++++++++-- .../gserver/gradientmachines/MultiGradientMachine.h | 4 +++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 9abda18d54..8ef5e9d0c1 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -166,12 +166,16 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, outArgStream_ = HPPL_STREAM_1; + start(); +} + +void MultiGradientMachine::start() { for (auto& thread : threads_) { thread->start(); } } -MultiGradientMachine::~MultiGradientMachine() { +void MultiGradientMachine::finish() { for (auto& thread : threads_) { thread->stop(); } @@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config, gradStream_ = HPPL_STREAM_2; valueStream_ = HPPL_STREAM_3; - stopping_ = false; + stopping_ = true; updateCounter_ = 0; parameterUpdated_ = false; } @@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config, TrainerThread::~TrainerThread() { stop(); } void TrainerThread::start() { + if (!stopping_) return; + + stopping_ = false; + gradientMachine_->start(); computeThread_.reset(new std::thread([this]() { computeThread(); })); diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h index c005c0ed67..5e7622f929 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.h +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h @@ -176,7 +176,9 @@ public: explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); - virtual ~MultiGradientMachine(); + virtual void start(); + + virtual void finish(); virtual void prefetch(const std::vector& inArgs);