|
|
|
@ -12,7 +12,6 @@ import paddle.trainer.PyDataProvider2 as dp
|
|
|
|
|
import numpy as np
|
|
|
|
|
import random
|
|
|
|
|
from mnist_util import read_from_mnist
|
|
|
|
|
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
|
|
|
|
|
from paddle.trainer_config_helpers import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -80,14 +79,13 @@ def main():
|
|
|
|
|
# enable_types = [value, gradient, momentum, etc]
|
|
|
|
|
# For each optimizer(SGD, Adam), GradientMachine should enable different
|
|
|
|
|
# buffers.
|
|
|
|
|
opt_config_proto = config_parser_utils.parse_optimizer_config(
|
|
|
|
|
optimizer_config)
|
|
|
|
|
opt_config_proto = parse_optimizer_config(optimizer_config)
|
|
|
|
|
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
|
|
|
|
|
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
|
|
|
|
|
enable_types = _temp_optimizer_.getParameterTypes()
|
|
|
|
|
|
|
|
|
|
# Create Simple Gradient Machine.
|
|
|
|
|
model_config = config_parser_utils.parse_network_config(network_config)
|
|
|
|
|
model_config = parse_network_config(network_config)
|
|
|
|
|
m = api.GradientMachine.createFromConfigProto(
|
|
|
|
|
model_config, api.CREATE_MODE_NORMAL, enable_types)
|
|
|
|
|
|
|
|
|
|