Python trainer api (#193)
* Python trainer API and demo * Adding missing PaddleAPIPrivate.h * Adding api_train.sh * More comments * Bump up patch version to 0b3avx_docs
parent
46bd5f53e3
commit
cbe734b396
@ -0,0 +1,114 @@
|
|||||||
|
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import itertools
|
||||||
|
import random
|
||||||
|
|
||||||
|
from paddle.trainer.config_parser import parse_config
|
||||||
|
from py_paddle import swig_paddle as api
|
||||||
|
from py_paddle import DataProviderConverter
|
||||||
|
from paddle.trainer.PyDataProvider2 \
|
||||||
|
import integer_value, integer_value_sequence, sparse_binary_vector
|
||||||
|
|
||||||
|
def parse_arguments():
    """Parse the command-line options for the trainer demo.

    Returns an argparse.Namespace with: train_data, test_data, config,
    dict_file, seq, use_gpu, trainer_count and num_passes.
    """
    parser = argparse.ArgumentParser()
    # Table of (flag, add_argument keyword arguments); registered in order.
    option_specs = [
        ("--train_data",
         dict(type=str, required=False, help="train data file")),
        ("--test_data",
         dict(type=str, help="test data file")),
        ("--config",
         dict(type=str, required=True, help="config file name")),
        ("--dict_file",
         dict(required=True, help="dictionary file")),
        ("--seq",
         dict(default=1, type=int, help="whether use sequence training")),
        ("--use_gpu",
         dict(default=0, type=int, help="whether use GPU for training")),
        ("--trainer_count",
         dict(default=1, type=int, help="Number of threads for training")),
        ("--num_passes",
         dict(default=5, type=int, help="Number of training passes")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
||||||
|
|
||||||
|
# Word index assigned to out-of-vocabulary tokens.
UNK_IDX = 0


def load_data(file_name, word_dict):
    """Yield (word_id_list, label) pairs from a labelled text file.

    Each input line is a label and a space-separated comment joined by a
    single tab. Tokens absent from word_dict map to UNK_IDX; the label is
    converted to int.
    """
    with open(file_name, 'r') as input_file:
        for record in input_file:
            label, comment = record.strip().split('\t')
            token_ids = [word_dict.get(token, UNK_IDX)
                         for token in comment.split()]
            yield token_ids, int(label)
|
||||||
|
|
||||||
|
def load_dict(dict_file):
    """Build a word -> index mapping from a dictionary file.

    The word is the first whitespace-separated token on each line and its
    index is the zero-based line number. If a word repeats, the last
    occurrence wins.
    """
    with open(dict_file, 'r') as dict_fp:
        return {entry.strip().split()[0]: line_no
                for line_no, entry in enumerate(dict_fp)}
|
||||||
|
|
||||||
|
def main():
    """Train (and optionally test) a model through the low-level trainer API.

    Reads options from the command line, loads the dictionary and datasets
    into memory, builds a GradientMachine plus Trainer from the parsed
    config, and runs num_passes passes of mini-batch training; after each
    pass a test period is run when test data was provided.
    """
    options = parse_arguments()
    api.initPaddle("--use_gpu=%s" % options.use_gpu,
                   "--trainer_count=%s" % options.trainer_count)

    word_dict = load_dict(options.dict_file)
    # NOTE(review): --train_data is not required by the parser; open() will
    # fail here if it was omitted — confirm whether it should be required.
    train_dataset = list(load_data(options.train_data, word_dict))
    if options.test_data:
        test_dataset = list(load_data(options.test_data, word_dict))
    else:
        test_dataset = None

    trainer_config = parse_config(options.config,
                                  "dict_file=%s" % options.dict_file)
    # No need to have data provider for trainer; data is fed directly below.
    trainer_config.ClearField('data_config')
    trainer_config.ClearField('test_data_config')

    # create a GradientMachine from the model configuration
    model = api.GradientMachine.createFromConfigProto(
        trainer_config.model_config)
    # create a trainer for the gradient machine
    trainer = api.Trainer.create(trainer_config, model)

    # create a data converter which converts data to PaddlePaddle
    # internal format
    input_types = [
        integer_value_sequence(len(word_dict)) if options.seq
        else sparse_binary_vector(len(word_dict)),
        integer_value(2)]
    converter = DataProviderConverter(input_types)

    batch_size = trainer_config.opt_config.batch_size
    trainer.startTrain()
    for train_pass in xrange(options.num_passes):
        trainer.startTrainPass()
        random.shuffle(train_dataset)
        for pos in xrange(0, len(train_dataset), batch_size):
            # Slice the list directly: itertools.islice would re-scan the
            # list from the beginning for every batch, making each pass
            # quadratic in the dataset size. len(batch) also replaces the
            # explicit min(batch_size, len - pos) computation.
            batch = train_dataset[pos:pos + batch_size]
            trainer.trainOneDataBatch(len(batch), converter(batch))
        trainer.finishTrainPass()
        if test_dataset:
            trainer.startTestPeriod()
            for pos in xrange(0, len(test_dataset), batch_size):
                batch = test_dataset[pos:pos + batch_size]
                trainer.testOneDataBatch(len(batch), converter(batch))
            trainer.finishTestPeriod()
    trainer.finishTrain()
|
||||||
|
|
||||||
|
# Script entry point: run the training demo when executed directly.
if __name__ == '__main__':
    main()
|
@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Abort immediately if any command in the script fails.
set -e

# Note: if using trainer_config.emb.py, trainer_config.cnn.py
# or trainer_config.lstm.py, you need to change --seq to --seq=1
# because they are sequence models.
# Runs the Python trainer demo on the bundled data; both stdout and
# stderr are shown on the console and saved to train.log via tee.
python api_train.py \
    --config=trainer_config.lr.py \
    --trainer_count=2 \
    --num_passes=15 \
    --use_gpu=0 \
    --seq=0 \
    --train_data=data/train.txt \
    --test_data=data/test.txt \
    --dict_file=data/dict.txt \
    2>&1 | tee 'train.log'
|
@ -0,0 +1,68 @@
|
|||||||
|
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
// Include guard must precede everything else in the header so the whole
// file (including its own includes) is skipped on re-inclusion; the
// original placed it after the #include lines.
#pragma once

#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/trainer/TrainerConfigHelper.h"
|
||||||
|
|
||||||
|
// Pimpl backing store for the public GradientMachine wrapper: owns the
// underlying paddle::GradientMachine via shared_ptr.
struct GradientMachinePrivate {
  std::shared_ptr<paddle::GradientMachine> machine;

  // Reinterpret an opaque void* (as stored by the public API layer) back
  // to a reference of its concrete type. The caller must guarantee that
  // ptr really addresses a T; no checking is performed.
  template <typename T>
  inline T& cast(void* ptr) {
    return *(T*)(ptr);
  }
};
|
||||||
|
|
||||||
|
// Backing store for the optimization configuration. The config is either
// borrowed from a full trainer config (trainer_config) or held directly
// (config); getConfig() picks whichever is active.
struct OptimizationConfigPrivate {
  std::shared_ptr<paddle::TrainerConfigHelper> trainer_config;
  paddle::OptimizationConfig config;

  // Prefer the optimization section of the attached trainer config;
  // fall back to the standalone config when none is attached.
  const paddle::OptimizationConfig& getConfig() {
    if (trainer_config != nullptr) {
      return trainer_config->getOptConfig();
    } else {
      return config;
    }
  }
};
|
||||||
|
|
||||||
|
// Backing store for TrainerConfig: holds the shared trainer config helper.
struct TrainerConfigPrivate {
  std::shared_ptr<paddle::TrainerConfigHelper> conf;
  TrainerConfigPrivate() {}
};
|
||||||
|
|
||||||
|
// Backing store for ModelConfig; shares the same TrainerConfigHelper that
// also backs TrainerConfigPrivate.
struct ModelConfigPrivate {
  std::shared_ptr<paddle::TrainerConfigHelper> conf;
};
|
||||||
|
|
||||||
|
// Backing store for Arguments: a vector of paddle::Argument plus helpers
// for bounds-checked access and opaque-pointer casting.
struct ArgumentsPrivate {
  std::vector<paddle::Argument> outputs;

  // Bounds-checked element access: throws RangeError on an out-of-range
  // index instead of invoking undefined behavior.
  // NOTE(review): dynamic exception specifications (`throw(RangeError)`)
  // are deprecated in C++11 and removed in C++17 — consider dropping the
  // specifier if the toolchain is updated.
  inline paddle::Argument& getArg(size_t idx) throw(RangeError) {
    if (idx < outputs.size()) {
      return outputs[idx];
    } else {
      RangeError e;
      throw e;
    }
  }

  // Reinterpret an opaque void* as a shared_ptr<T>&. The caller must
  // guarantee that rawPtr really addresses a std::shared_ptr<T>.
  template <typename T>
  std::shared_ptr<T>& cast(void* rawPtr) const {
    return *(std::shared_ptr<T>*)(rawPtr);
  }
};
|
||||||
|
|
@ -0,0 +1,63 @@
|
|||||||
|
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from paddle.trainer.config_parser import parse_config
|
||||||
|
from paddle.trainer.config_parser import logger
|
||||||
|
from py_paddle import swig_paddle
|
||||||
|
import util
|
||||||
|
|
||||||
|
def _run_all_batches(feed_batch, trainer):
    """Feed every MNIST batch through feed_batch; return (total_cost, count).

    feed_batch is called as feed_batch(batch_size, data) for each batch;
    after each call the summed forward output of the batch is accumulated
    from the trainer. Extracted because the train and test loops in the
    original were duplicated verbatim.
    """
    num = 0
    cost = 0
    while True:
        batch_size = 1000
        # NOTE(review): the test period in the original also reads the
        # *train* data via loadMNISTTrainData — confirm whether a separate
        # test-set loader was intended.
        data, at_end = util.loadMNISTTrainData(batch_size)
        if at_end:
            break
        feed_batch(batch_size, data)
        outs = trainer.getForwardOutput()
        cost += sum(outs[0]['value'])
        num += batch_size
    return cost, num


def main():
    """Smoke-test the swig trainer API: two train passes, then one test period.

    Builds a GradientMachine and Trainer from ./testTrainConfig.py and logs
    the average cost per pass.
    """
    trainer_config = parse_config(
        "./testTrainConfig.py", "")
    model = swig_paddle.GradientMachine.createFromConfigProto(
        trainer_config.model_config)
    trainer = swig_paddle.Trainer.create(trainer_config, model)
    trainer.startTrain()
    for train_pass in xrange(2):
        trainer.startTrainPass()
        cost, num = _run_all_batches(trainer.trainOneDataBatch, trainer)
        trainer.finishTrainPass()
        logger.info('train cost=%f' % (cost / num))

    trainer.startTestPeriod()
    cost, num = _run_all_batches(trainer.testOneDataBatch, trainer)
    trainer.finishTestPeriod()
    logger.info('test cost=%f' % (cost / num))

    trainer.finishTrain()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: initialize PaddlePaddle in single-threaded CPU mode
# before running the smoke test.
if __name__ == '__main__':
    swig_paddle.initPaddle("--use_gpu=0", "--trainer_count=1")
    main()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue