change return type Argument

avx_docs
liaogang 8 years ago
parent bbfcee20fd
commit 950b4a3126

@ -148,7 +148,8 @@ Arguments* GradientMachine::getLayerOutput(const std::string& layerName) const
throw(UnsupportError) { throw(UnsupportError) {
auto nn = m->machine; auto nn = m->machine;
if (nn) { if (nn) {
return Arguments::createByPaddleArgument(&nn->getLayerOutput(layerName)); auto arg = nn->getLayerOutput(layerName);
return Arguments::createByPaddleArgument(&arg);
} else { } else {
throw UnsupportError(); throw UnsupportError();
} }

@ -134,7 +134,8 @@ void Trainer::finishTestPeriod() { m->finishTestPeriod(); }
Arguments* Trainer::getLayerOutput(const std::string& layerName) const { Arguments* Trainer::getLayerOutput(const std::string& layerName) const {
auto nn = this->m->getGradientMachine(); auto nn = this->m->getGradientMachine();
CHECK(nn) << "trainerInternal_.getGradientMachine() is not NeuralNetwork"; CHECK(nn) << "trainerInternal_.getGradientMachine() is not NeuralNetwork";
return Arguments::createByPaddleArgument(&nn->getLayerOutput(layerName)); auto arg = nn->getLayerOutput(layerName);
return Arguments::createByPaddleArgument(&arg);
} }
void Trainer::forwardOneBatch(size_t batchSize) { void Trainer::forwardOneBatch(size_t batchSize) {

@ -134,7 +134,7 @@ public:
backward(callback); backward(callback);
} }
virtual const Argument& getLayerOutput(const std::string& layerName) { virtual Argument getLayerOutput(const std::string& layerName) {
return *((Argument*)nullptr); return *((Argument*)nullptr);
} }

@ -282,8 +282,7 @@ void MultiGradientMachine::forwardBackward(const std::vector<Argument>& inArgs,
backwardImp(callback); backwardImp(callback);
} }
const Argument& MultiGradientMachine::getLayerOutput( Argument MultiGradientMachine::getLayerOutput(const std::string& layerName) {
const std::string& layerName) {
std::vector<Argument> args; std::vector<Argument> args;
args.reserve(threads_.size()); args.reserve(threads_.size());

@ -189,7 +189,7 @@ public:
PassType passType, PassType passType,
const UpdateCallback& callback); const UpdateCallback& callback);
virtual const Argument& getLayerOutput(const std::string& layerName); virtual Argument getLayerOutput(const std::string& layerName);
virtual void onPassEnd(); virtual void onPassEnd();

@ -293,7 +293,7 @@ void NeuralNetwork::backward(const UpdateCallback& callback) {
} }
} }
const Argument& NeuralNetwork::getLayerOutput(const std::string& layerName) { Argument NeuralNetwork::getLayerOutput(const std::string& layerName) {
return getLayer(layerName)->getOutput(); return getLayer(layerName)->getOutput();
} }

@ -87,7 +87,7 @@ public:
virtual void backward(const UpdateCallback& callback = nullptr); virtual void backward(const UpdateCallback& callback = nullptr);
virtual const Argument& getLayerOutput(const std::string& layerName); virtual Argument getLayerOutput(const std::string& layerName);
const LayerPtr& getLayer(const std::string& layerName) const { const LayerPtr& getLayer(const std::string& layerName) const {
auto it = layerMap_.find(layerName); auto it = layerMap_.find(layerName);

@ -112,7 +112,7 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap,
void CosSimVecMatLayer::forward(PassType passType) { void CosSimVecMatLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
CHECK_EQ(forward_.size(), 1) << "Only one forward function needed"; CHECK_EQ(forward_.size(), 1UL) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0); MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1); MatrixPtr inV1 = getInputValue(1);
@ -145,7 +145,7 @@ void CosSimVecMatLayer::forward(PassType passType) {
} }
void CosSimVecMatLayer::backward(const UpdateCallback& callback) { void CosSimVecMatLayer::backward(const UpdateCallback& callback) {
CHECK_EQ(backward_.size(), 1) << "Only one forward function needed"; CHECK_EQ(backward_.size(), 1UL) << "Only one forward function needed";
MatrixPtr inV0 = getInputValue(0); MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1); MatrixPtr inV1 = getInputValue(1);

Loading…
Cancel
Save