Remove unnecessary import in api_train.py

avx_docs
Yu Yang 8 years ago
parent 763a30fdde
commit 9b41b08ef3

@ -12,7 +12,6 @@ import paddle.trainer.PyDataProvider2 as dp
import numpy as np
import random
from mnist_util import read_from_mnist
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
from paddle.trainer_config_helpers import *
@ -80,14 +79,13 @@ def main():
# enable_types = [value, gradient, momentum, etc]
# For each optimizer(SGD, Adam), GradientMachine should enable different
# buffers.
opt_config_proto = config_parser_utils.parse_optimizer_config(
optimizer_config)
opt_config_proto = parse_optimizer_config(optimizer_config)
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
enable_types = _temp_optimizer_.getParameterTypes()
# Create Simple Gradient Machine.
model_config = config_parser_utils.parse_network_config(network_config)
model_config = parse_network_config(network_config)
m = api.GradientMachine.createFromConfigProto(
model_config, api.CREATE_MODE_NORMAL, enable_types)

Loading…
Cancel
Save