|
|
|
|
@ -32,16 +32,15 @@ DEFINE_string(dirname, "", "Directory of the train model.");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
void Train() {
|
|
|
|
|
CHECK(!FLAGS_dirname.empty());
|
|
|
|
|
void Train(std::string model_dir) {
|
|
|
|
|
framework::InitDevices(false);
|
|
|
|
|
const auto cpu_place = platform::CPUPlace();
|
|
|
|
|
framework::Executor executor(cpu_place);
|
|
|
|
|
framework::Scope scope;
|
|
|
|
|
|
|
|
|
|
auto train_program = inference::Load(
|
|
|
|
|
&executor, &scope, FLAGS_dirname + "__model_combined__.main_program",
|
|
|
|
|
FLAGS_dirname + "__params_combined__");
|
|
|
|
|
&executor, &scope, model_dir + "__model_combined__.main_program",
|
|
|
|
|
model_dir + "__params_combined__");
|
|
|
|
|
|
|
|
|
|
std::string loss_name = "";
|
|
|
|
|
for (auto op_desc : train_program->Block(0).AllOps()) {
|
|
|
|
|
@ -87,6 +86,10 @@ void Train() {
|
|
|
|
|
EXPECT_LT(last_loss, first_loss);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(train, recognize_digits) { Train(); }
|
|
|
|
|
TEST(train, recognize_digits) {
|
|
|
|
|
CHECK(!FLAGS_dirname.empty());
|
|
|
|
|
Train(FLAGS_dirname + "recognize_digits_mlp.train.model/");
|
|
|
|
|
Train(FLAGS_dirname + "recognize_digits_conv.train.model/");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|