|
|
|
@ -64,6 +64,14 @@ GradientMachine* GradientMachine::createByModelConfig(
|
|
|
|
|
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GradientMachine::onPassEnd() { m->machine->onPassEnd(); }
|
|
|
|
|
|
|
|
|
|
void GradientMachine::prefetch(const Arguments& inArgs) {
|
|
|
|
|
auto& in =
|
|
|
|
|
m->cast<std::vector<paddle::Argument>>(inArgs.getInternalArgumentsPtr());
|
|
|
|
|
m->machine->prefetch(in);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GradientMachine::forward(const Arguments& inArgs,
|
|
|
|
|
Arguments* outArgs,
|
|
|
|
|
PassType passType) {
|
|
|
|
|