Remove not used params in GradientMachine::start

avx_docs
Yu Yang 8 years ago
parent 2965df5113
commit 56f29658ba

@ -212,11 +212,7 @@ public:
* @note This function will only been implemented and used in a * @note This function will only been implemented and used in a
* multithreaded environment. * multithreaded environment.
*/ */
virtual void start(const TrainerConfig& config, virtual void start() {}
DataProviderPtr dataProvider) {
(void)config;
(void)dataProvider;
}
/** /**
* @brief check each work-thread whether is failed/error/finish, * @brief check each work-thread whether is failed/error/finish,

@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
TrainerThread::~TrainerThread() { stop(); } TrainerThread::~TrainerThread() { stop(); }
void TrainerThread::start() { void TrainerThread::start() {
gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr); gradientMachine_->start();
computeThread_.reset(new std::thread([this]() { computeThread(); })); computeThread_.reset(new std::thread([this]() { computeThread(); }));

@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() {
} }
} }
void MultiNetwork::start(const TrainerConfig& config, void MultiNetwork::start() {
DataProviderPtr dataProvider) {
for (auto& subNetwork : subNetworks_) { for (auto& subNetwork : subNetworks_) {
subNetwork->start(config, dataProvider); subNetwork->start();
} }
} }

@ -54,7 +54,7 @@ public:
return subNetworks_; return subNetworks_;
} }
virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider); virtual void start();
virtual void finish(); virtual void finish();

@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector<Argument>& inArgs,
backward(callback); backward(callback);
} }
void ParallelNeuralNetwork::start(const TrainerConfig& config, void ParallelNeuralNetwork::start() {
DataProviderPtr dataProvider) {
(void)config;
(void)dataProvider;
for (auto& thread : threads_) { for (auto& thread : threads_) {
thread->start(); thread->start();
} }

@ -56,7 +56,7 @@ public:
PassType passType, PassType passType,
const UpdateCallback &callback = NULL); const UpdateCallback &callback = NULL);
virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider); virtual void start();
void addComputeThread(int deviceId); void addComputeThread(int deviceId);

@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) {
parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]); parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
} }
} }
gradientMachine->start(trainer.getConfig(), nullptr); gradientMachine->start();
gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN); gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN);
for (size_t i = 0; i < in.outGrads.size(); i++) { for (size_t i = 0; i < in.outGrads.size(); i++) {
// If the all the layers in the config have no parameters, also // If the all the layers in the config have no parameters, also

@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer {
public: public:
void startTrain() { void startTrain() {
GradientMachine& gm = *this->trainerInternal_.getGradientMachine(); GradientMachine& gm = *this->trainerInternal_.getGradientMachine();
gm.start(this->getConfig(), dataProvider_); gm.start();
} }
void finishTrain() { void finishTrain() {

@ -257,7 +257,7 @@ void Tester::test() {
CHECK(testDataProvider_) << "TestData is not specified"; CHECK(testDataProvider_) << "TestData is not specified";
testDataProvider_->setSkipShuffle(); testDataProvider_->setSkipShuffle();
testDataProvider_->reset(); testDataProvider_->reset();
gradientMachine_->start(*config_, testDataProvider_); gradientMachine_->start();
// For evaluation // For evaluation
std::vector<std::string> modelList; std::vector<std::string> modelList;

@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
} }
real Trainer::checkGradient() { real Trainer::checkGradient() {
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); trainerInternal_.getGradientMachine()->start();
std::vector<ParameterPtr>& parameters = std::vector<ParameterPtr>& parameters =
trainerInternal_.getGradientMachine()->getNonStaticParameters(); trainerInternal_.getGradientMachine()->getNonStaticParameters();
DataBatch dataBatch; DataBatch dataBatch;
@ -390,7 +390,7 @@ void Trainer::startTrain() {
dataProvider_->reset(); dataProvider_->reset();
} }
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); trainerInternal_.getGradientMachine()->start();
} }
void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); } void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); }

@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) {
trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch); trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
CHECK(dataBatch.getSize()) << "No data from data provider"; CHECK(dataBatch.getSize()) << "No data from data provider";
vector<Argument>& inArgs = dataBatch.getStreams(); vector<Argument>& inArgs = dataBatch.getStreams();
trainer.getGradientMachine()->start(trainer.getConfig(), nullptr); trainer.getGradientMachine()->start();
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
trainer.getGradientMachine()->forwardBackward( trainer.getGradientMachine()->forwardBackward(
inArgs, &Data.outArgs, PASS_TRAIN); inArgs, &Data.outArgs, PASS_TRAIN);

@ -72,7 +72,7 @@ void calcGradient(ComData& data, const string configFile) {
CHECK(dataBatch.getSize()) << "No data from data provider"; CHECK(dataBatch.getSize()) << "No data from data provider";
vector<Argument>& inArgs = dataBatch.getStreams(); vector<Argument>& inArgs = dataBatch.getStreams();
trainer.getGradientMachine()->start(trainer.getConfig(), nullptr); trainer.getGradientMachine()->start();
trainer.getGradientMachine()->forwardBackward( trainer.getGradientMachine()->forwardBackward(
inArgs, &data.outArgs, PASS_TRAIN); inArgs, &data.outArgs, PASS_TRAIN);

Loading…
Cancel
Save