|
|
|
@ -16,7 +16,6 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "NeuralNetwork.h"
|
|
|
|
|
#include "hl_gpu.h"
|
|
|
|
|
#include "paddle/gserver/layers/AgentLayer.h"
|
|
|
|
|
#include "paddle/utils/CustomStackTrace.h"
|
|
|
|
|
#include "paddle/utils/Logging.h"
|
|
|
|
|
#include "paddle/utils/Stat.h"
|
|
|
|
@ -28,6 +27,7 @@ limitations under the License. */
|
|
|
|
|
#ifndef PADDLE_MOBILE_INFERENCE
|
|
|
|
|
#include "MultiNetwork.h"
|
|
|
|
|
#include "RecurrentGradientMachine.h"
|
|
|
|
|
#include "paddle/gserver/layers/AgentLayer.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -192,9 +192,11 @@ void NeuralNetwork::init(const ModelConfig& config,
|
|
|
|
|
void NeuralNetwork::connect(LayerPtr agentLayer,
|
|
|
|
|
LayerPtr realLayer,
|
|
|
|
|
int height) {
|
|
|
|
|
#ifndef PADDLE_MOBILE_INFERENCE
|
|
|
|
|
AgentLayer* agent = dynamic_cast<AgentLayer*>(agentLayer.get());
|
|
|
|
|
CHECK_NOTNULL(agent);
|
|
|
|
|
agent->setRealLayer(realLayer, height);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void NeuralNetwork::connect(std::string agentLayerName,
|
|
|
|
|