|
|
@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) {
|
|
|
|
trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
|
|
|
|
trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
|
|
|
|
CHECK(dataBatch.getSize()) << "No data from data provider";
|
|
|
|
CHECK(dataBatch.getSize()) << "No data from data provider";
|
|
|
|
vector<Argument>& inArgs = dataBatch.getStreams();
|
|
|
|
vector<Argument>& inArgs = dataBatch.getStreams();
|
|
|
|
trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
|
|
|
|
trainer.getGradientMachine()->start();
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
trainer.getGradientMachine()->forwardBackward(
|
|
|
|
trainer.getGradientMachine()->forwardBackward(
|
|
|
|
inArgs, &Data.outArgs, PASS_TRAIN);
|
|
|
|
inArgs, &Data.outArgs, PASS_TRAIN);
|
|
|
|