From cbe734b396a6f6844cafb94b0de8bd736360ef90 Mon Sep 17 00:00:00 2001 From: emailweixu Date: Thu, 27 Oct 2016 00:23:46 -0700 Subject: [PATCH 01/11] Python trainer api (#193) * Python trainer API and demo * Adding missing PaddleAPIPrivate.h * Adding api_train.sh * More comments * Bump up patch version to 0b3 --- CMakeLists.txt | 2 +- cmake/swig.cmake | 1 + demo/quick_start/api_train.py | 114 +++++++++++ demo/quick_start/api_train.sh | 29 +++ demo/quick_start/train.sh | 2 +- demo/quick_start/trainer_config.lr.py | 3 +- demo/sentiment/predict.py | 4 +- paddle/api/Arguments.cpp | 19 +- paddle/api/CMakeLists.txt | 3 + paddle/api/ConfigParser.cpp | 44 ++--- paddle/api/GradientMachine.cpp | 18 +- paddle/api/PaddleAPI.h | 40 ++-- paddle/api/PaddleAPIPrivate.h | 68 +++++++ paddle/api/ParameterOptimizer.cpp | 4 +- paddle/api/Trainer.cpp | 175 +++++++++-------- paddle/api/test/run_tests.sh | 2 +- paddle/api/test/testTrain.py | 3 +- paddle/api/test/testTrainer.py | 63 +++++++ paddle/py_paddle/dataprovider_converter.py | 2 +- paddle/py_paddle/util.py | 107 +++++++++-- paddle/trainer/Tester.cpp | 56 ++++-- paddle/trainer/Tester.h | 12 +- paddle/trainer/Trainer.cpp | 208 ++++++++++++--------- paddle/trainer/Trainer.h | 22 ++- paddle/trainer/TrainerInternal.cpp | 10 +- paddle/trainer/TrainerInternal.h | 4 +- proto/TrainerConfig.proto.m4 | 2 +- python/paddle/proto/__init__.py | 2 + 28 files changed, 707 insertions(+), 312 deletions(-) create mode 100644 demo/quick_start/api_train.py create mode 100755 demo/quick_start/api_train.sh create mode 100644 paddle/api/PaddleAPIPrivate.h create mode 100644 paddle/api/test/testTrainer.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4613155f77..e96ce28248 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8) project(paddle CXX C) set(PADDLE_MAJOR_VERSION 0) set(PADDLE_MINOR_VERSION 8) -set(PADDLE_PATCH_VERSION 0b2) +set(PADDLE_PATCH_VERSION 0b3) set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION}) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake") diff --git a/cmake/swig.cmake b/cmake/swig.cmake index f5c1bcc79b..160d7ee56a 100644 --- a/cmake/swig.cmake +++ b/cmake/swig.cmake @@ -27,6 +27,7 @@ function(generate_python_api target_name) COMMAND swig -python -c++ -outcurrentdir -I../ api/Paddle.swig && mv ${PROJ_ROOT}/paddle/swig_paddle.py ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py DEPENDS ${PROJ_ROOT}/paddle/api/Paddle.swig + ${PROJ_ROOT}/paddle/api/PaddleAPI.h WORKING_DIRECTORY ${PROJ_ROOT}/paddle COMMENT "Generate Python API from swig") add_custom_target(${target_name} ALL DEPENDS diff --git a/demo/quick_start/api_train.py b/demo/quick_start/api_train.py new file mode 100644 index 0000000000..5ae19b8d26 --- /dev/null +++ b/demo/quick_start/api_train.py @@ -0,0 +1,114 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import itertools +import random + +from paddle.trainer.config_parser import parse_config +from py_paddle import swig_paddle as api +from py_paddle import DataProviderConverter +from paddle.trainer.PyDataProvider2 \ + import integer_value, integer_value_sequence, sparse_binary_vector + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_data", + type=str, required=False, help="train data file") + parser.add_argument("--test_data", type=str, help="test data file") + parser.add_argument("--config", + type=str, required=True, help="config file name") + parser.add_argument("--dict_file", required=True, help="dictionary file") + parser.add_argument("--seq", + default=1, type=int, + help="whether use sequence training") + parser.add_argument("--use_gpu", default=0, type=int, + help="whether use GPU for training") + parser.add_argument("--trainer_count", default=1, type=int, + help="Number of threads for training") + parser.add_argument("--num_passes", default=5, type=int, + help="Number of training passes") + return parser.parse_args() + +UNK_IDX = 0 + +def load_data(file_name, word_dict): + with open(file_name, 'r') as f: + for line in f: + label, comment = line.strip().split('\t') + words = comment.split() + word_slot = [word_dict.get(w, UNK_IDX) for w in words] + yield word_slot, int(label) + +def load_dict(dict_file): + word_dict = dict() + with open(dict_file, 'r') as f: + for i, line in enumerate(f): + w = line.strip().split()[0] + word_dict[w] = i + return word_dict + +def main(): + options = parse_arguments() + api.initPaddle("--use_gpu=%s" % options.use_gpu, + "--trainer_count=%s" % options.trainer_count) + + word_dict = load_dict(options.dict_file) + train_dataset = list(load_data(options.train_data, word_dict)) + if options.test_data: + test_dataset = list(load_data(options.test_data, word_dict)) + else: + test_dataset = None + + trainer_config = parse_config(options.config, + "dict_file=%s" % options.dict_file) + # No need to have data provider for trainer + trainer_config.ClearField('data_config') + trainer_config.ClearField('test_data_config') + + # create a GradientMachine from the model configuratin + model = api.GradientMachine.createFromConfigProto( + trainer_config.model_config) + # create a trainer for the gradient machine + trainer = api.Trainer.create(trainer_config, model) + + # create a data converter which converts data to PaddlePaddle + # internal format + input_types = [ + integer_value_sequence(len(word_dict)) if options.seq + else sparse_binary_vector(len(word_dict)), + integer_value(2)] + converter = DataProviderConverter(input_types) + + batch_size = trainer_config.opt_config.batch_size + trainer.startTrain() + for train_pass in xrange(options.num_passes): + trainer.startTrainPass() + random.shuffle(train_dataset) + for pos in xrange(0, len(train_dataset), batch_size): + batch = itertools.islice(train_dataset, pos, pos + batch_size) + size = min(batch_size, len(train_dataset) - pos) + trainer.trainOneDataBatch(size, converter(batch)) + trainer.finishTrainPass() + if test_dataset: + trainer.startTestPeriod(); + for pos in xrange(0, len(test_dataset), batch_size): + batch = itertools.islice(test_dataset, pos, pos + batch_size) + size = min(batch_size, len(test_dataset) - pos) + trainer.testOneDataBatch(size, converter(batch)) + trainer.finishTestPeriod() + trainer.finishTrain() + +if __name__ == '__main__': + main() diff --git a/demo/quick_start/api_train.sh b/demo/quick_start/api_train.sh new file mode 100755 index 0000000000..40e9d0a09a --- /dev/null +++ b/demo/quick_start/api_train.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +# Note: if using trainer_config.emb.py, trainer_config.cnn.py +# or trainer_config.lstm.py, you need to change --seq to --seq=1 +# because they are sequence models. +python api_train.py \ + --config=trainer_config.lr.py \ + --trainer_count=2 \ + --num_passes=15 \ + --use_gpu=0 \ + --seq=0 \ + --train_data=data/train.txt \ + --test_data=data/test.txt \ + --dict_file=data/dict.txt \ + 2>&1 | tee 'train.log' diff --git a/demo/quick_start/train.sh b/demo/quick_start/train.sh index ea4e32249a..49806292a4 100755 --- a/demo/quick_start/train.sh +++ b/demo/quick_start/train.sh @@ -24,7 +24,7 @@ paddle train \ --config=$cfg \ --save_dir=./output \ --trainer_count=4 \ - --log_period=20 \ + --log_period=100 \ --num_passes=15 \ --use_gpu=false \ --show_parameter_stats_period=100 \ diff --git a/demo/quick_start/trainer_config.lr.py b/demo/quick_start/trainer_config.lr.py index 119e3849a4..c6059947f3 100644 --- a/demo/quick_start/trainer_config.lr.py +++ b/demo/quick_start/trainer_config.lr.py @@ -16,7 +16,7 @@ from paddle.trainer_config_helpers import * -dict_file = "./data/dict.txt" +dict_file = get_config_arg('dict_file', str, "./data/dict.txt") word_dict = dict() with open(dict_file, 'r') as f: for i, line in enumerate(f): @@ -63,7 +63,6 @@ if not is_predict: label = data_layer(name="label", size=2) # Define cross-entropy classification loss and error. - classification_cost(input=output, label=label) cls = classification_cost(input=output, label=label) outputs(cls) else: diff --git a/demo/sentiment/predict.py b/demo/sentiment/predict.py index c61628d34d..7d0baeabbb 100755 --- a/demo/sentiment/predict.py +++ b/demo/sentiment/predict.py @@ -46,8 +46,8 @@ class SentimentPrediction(): conf = parse_config(train_conf, "is_predict=1") self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config) self.network.loadParameters(self.model_dir) - slots = [integer_value_sequence(self.dict_dim)] - self.converter = DataProviderConverter(slots) + input_types = [integer_value_sequence(self.dict_dim)] + self.converter = DataProviderConverter(input_types) def load_dict(self): """ diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index 8f73e76260..6f51d55120 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -14,27 +14,10 @@ limitations under the License. */ #include "PaddleAPI.h" +#include "PaddleAPIPrivate.h" #include "paddle/parameter/Argument.h" -struct ArgumentsPrivate { - std::vector outputs; - - inline paddle::Argument& getArg(size_t idx) throw(RangeError) { - if (idx < outputs.size()) { - return outputs[idx]; - } else { - RangeError e; - throw e; - } - } - - template - std::shared_ptr& cast(void* rawPtr) const { - return *(std::shared_ptr*)(rawPtr); - } -}; - size_t Arguments::getSlotNum() const { return m->outputs.size(); } Arguments* Arguments::createArguments(size_t slotNum) { diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index fe0da76351..9b2d122a09 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -40,6 +40,8 @@ configure_file( generate_python_api(python_swig_sources) +file(GLOB PY_PADDLE_PYTHON_FILES ${PROJ_ROOT}/paddle/py_paddle/*.py) + # TODO(yuyang18) : make wheel name calculated by cmake add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel @@ -55,6 +57,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp paddle_trainer paddle_api paddle_cuda + ${PY_PADDLE_PYTHON_FILES} ) install(DIRECTORY ${PROJ_ROOT}/paddle/dist/ diff --git a/paddle/api/ConfigParser.cpp b/paddle/api/ConfigParser.cpp index c5ee784a0b..25d94f5a6a 100644 --- a/paddle/api/ConfigParser.cpp +++ b/paddle/api/ConfigParser.cpp @@ -14,17 +14,9 @@ limitations under the License. */ #include "PaddleAPI.h" +#include "PaddleAPIPrivate.h" #include "paddle/trainer/Trainer.h" -struct TrainerConfigPrivate { - std::shared_ptr conf; - TrainerConfigPrivate() : conf(std::make_shared()) {} -}; - -struct ModelConfigPrivate { - std::shared_ptr conf; -}; - struct ParameterConfigPrivate { paddle::ParameterPtr parameter; paddle::ParameterConfig config; @@ -39,19 +31,6 @@ struct ParameterConfigPrivate { } }; -struct OptimizationConfigPrivate { - std::shared_ptr trainer_config; - paddle::OptimizationConfig config; - - paddle::OptimizationConfig& getConfig() { - if (trainer_config != nullptr) { - return *trainer_config->mutable_opt_config(); - } else { - return config; - } - } -}; - TrainerConfig::TrainerConfig() : m(new TrainerConfigPrivate()) {} TrainerConfig::~TrainerConfig() { delete m; } @@ -59,10 +38,19 @@ TrainerConfig::~TrainerConfig() { delete m; } TrainerConfig* TrainerConfig::createFromTrainerConfigFile( const std::string& confPath) { LOG(INFO) << "load trainer config from " << confPath; - paddle::TrainerConfigHelper helper(confPath); - //! TODO(yuyang18): Make TrainerConfigPrivate to TrainerConfigHelper + auto conf = std::make_shared(confPath); auto retv = new TrainerConfig(); - *retv->m->conf = helper.getConfig(); + retv->m->conf = conf; + return retv; +} + +TrainerConfig* TrainerConfig::createFromProtoString( + const std::string& str) { + auto retv = new TrainerConfig(); + paddle::TrainerConfig trainerConfigProto; + auto conf = std::make_shared(trainerConfigProto); + CHECK(conf->getMutableConfig().ParseFromString(str)); + retv->m->conf = conf; return retv; } @@ -76,10 +64,6 @@ ModelConfig* TrainerConfig::getModelConfig() const { return retv; } -void* ModelConfig::getPaddleModelConfig() const { - return m->conf->mutable_model_config(); -} - ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {} ParameterConfig::~ParameterConfig() { @@ -132,8 +116,6 @@ OptimizationConfig* TrainerConfig::getOptimizationConfig() const { return opt_config; } -void* OptimizationConfig::getRawPtr() { return &m->getConfig(); } - OptimizationConfig* OptimizationConfig::createFromProtoString( const std::string& str) { auto conf = new OptimizationConfig(); diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp index 6f1d63575a..bef499c678 100644 --- a/paddle/api/GradientMachine.cpp +++ b/paddle/api/GradientMachine.cpp @@ -14,30 +14,22 @@ limitations under the License. */ #include "PaddleAPI.h" -#include "paddle/gserver/gradientmachines/GradientMachine.h" +#include "PaddleAPIPrivate.h" + #include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "Internal.h" std::vector GradientMachine::defaultParamTypes = { PARAMETER_VALUE, PARAMETER_GRADIENT, PARAMETER_MOMENTUM}; -struct GradientMachinePrivate { - std::shared_ptr machine; - - template - inline T& cast(void* ptr) { - return *(T*)(ptr); - } -}; - GradientMachine::GradientMachine() : m(new GradientMachinePrivate()) {} GradientMachine::~GradientMachine() { delete m; } GradientMachine* GradientMachine::createFromPaddleModelPtr( - void* confPtr, GradientMatchineCreateMode mode, + const void* confPtr, GradientMatchineCreateMode mode, const std::vector& types) { - auto& conf = *(paddle::ModelConfig*)(confPtr); + auto& conf = *(const paddle::ModelConfig*)(confPtr); std::vector realTypes; staticCastVector(&realTypes, types); auto machineRawPtr = paddle::GradientMachine::create(conf, mode, realTypes); @@ -66,7 +58,7 @@ GradientMachine* GradientMachine::createByConfigProtoStr( GradientMachine* GradientMachine::createByModelConfig( ModelConfig* conf, GradientMatchineCreateMode mode, const std::vector& types) { - auto confPtr = (paddle::ModelConfig*)conf->getPaddleModelConfig(); + auto confPtr = &conf->m->conf->getModelConfig(); return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); } diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index b3140617af..cf790f2f8e 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -446,7 +446,6 @@ struct OptimizationConfigPrivate; class OptimizationConfig { DISABLE_COPY_AND_ASSIGN(OptimizationConfig); OptimizationConfig(); - void* getRawPtr(); public: static OptimizationConfig* createFromProtoString(const std::string& str); @@ -462,6 +461,7 @@ private: friend class TrainerConfig; friend class ParameterOptimizer; + friend class Trainer; }; struct ParameterPrivate; @@ -515,8 +515,6 @@ public: virtual ~ModelConfig(); private: - void* getPaddleModelConfig() const; - ModelConfigPrivate* m; friend class TrainerConfig; friend struct TrainerConfigPrivate; @@ -539,6 +537,7 @@ public: static TrainerConfig* createFromTrainerConfigFile( const std::string& configPath); + static TrainerConfig* createFromProtoString(const std::string& str); ModelConfig* getModelConfig() const; @@ -546,6 +545,7 @@ public: private: TrainerConfigPrivate* m; + friend class Trainer; }; /** @@ -700,11 +700,12 @@ private: GradientMachinePrivate* m; static GradientMachine* createFromPaddleModelPtr( - void* confPtr, GradientMatchineCreateMode mode, + const void* confPtr, GradientMatchineCreateMode mode, const std::vector& types); // Not to use c++ 11 init-list, so we use static var as function default arg. static std::vector defaultParamTypes; + friend class Trainer; }; struct TrainerPrivate; @@ -712,6 +713,7 @@ class Trainer { private: TrainerPrivate* m; Trainer(); + Trainer(TrainerConfig* optConfig, GradientMachine* gm); DISABLE_COPY_AND_ASSIGN(Trainer); public: @@ -720,38 +722,42 @@ public: /// Create A Trainer By TrainerConfig. using paddle command line. static Trainer* createByCommandLine() throw(IOError); - /// Start Train. + static Trainer* create(TrainerConfig* optConfig, GradientMachine* gm) + throw(IOError); + + /// Start training void startTrain(); + + /// Finish training void finishTrain(); - /// Start Pass. + /// Start a pass. void startTrainPass(); - void finishTrainPass(); - void setBatchSize(size_t batchSize); + /// Finish a pass + void finishTrainPass(); /** * Train one batch, * - * @param batchSize -1 wiil use command line or batch size set before, - * otherwise use this batchSize for train. - * * @return true if all batch finished. */ - bool trainOneBatch(size_t batchSize = -1UL); + bool trainOneBatch(size_t batchSize); - bool prepareBatchData(size_t batchSize = -1UL); + void trainOneDataBatch(size_t batchSize, const Arguments& args); - void finishTrainOneBatch(); + void startTestPeriod(); + void testOneDataBatch(size_t batchSize, const Arguments& args); + void finishTestPeriod(); - void forwardOneBatch() throw(UnsupportError); + void forwardOneBatch(size_t batchSize); - Arguments* getNetworkOutput(); + Arguments* getForwardOutput(); Matrix* getLayerOutput(const std::string& layerName); }; -/// The N-Best results generated from one input sequence. +/// the N-Best results generated from one input sequence. class ISequenceResults { public: virtual ~ISequenceResults(); diff --git a/paddle/api/PaddleAPIPrivate.h b/paddle/api/PaddleAPIPrivate.h new file mode 100644 index 0000000000..93cdca8c4b --- /dev/null +++ b/paddle/api/PaddleAPIPrivate.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/gserver/gradientmachines/GradientMachine.h" +#include "paddle/trainer/TrainerConfigHelper.h" + +#pragma once + +struct GradientMachinePrivate { + std::shared_ptr machine; + + template + inline T& cast(void* ptr) { + return *(T*)(ptr); + } +}; + +struct OptimizationConfigPrivate { + std::shared_ptr trainer_config; + paddle::OptimizationConfig config; + + const paddle::OptimizationConfig& getConfig() { + if (trainer_config != nullptr) { + return trainer_config->getOptConfig(); + } else { + return config; + } + } +}; + +struct TrainerConfigPrivate { + std::shared_ptr conf; + TrainerConfigPrivate() {} +}; + +struct ModelConfigPrivate { + std::shared_ptr conf; +}; + +struct ArgumentsPrivate { + std::vector outputs; + + inline paddle::Argument& getArg(size_t idx) throw(RangeError) { + if (idx < outputs.size()) { + return outputs[idx]; + } else { + RangeError e; + throw e; + } + } + + template + std::shared_ptr& cast(void* rawPtr) const { + return *(std::shared_ptr*)(rawPtr); + } +}; + diff --git a/paddle/api/ParameterOptimizer.cpp b/paddle/api/ParameterOptimizer.cpp index e087defc60..b13761ab09 100644 --- a/paddle/api/ParameterOptimizer.cpp +++ b/paddle/api/ParameterOptimizer.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "PaddleAPI.h" +#include "PaddleAPIPrivate.h" #include "paddle/parameter/ParameterOptimizer.h" #include "Internal.h" #include @@ -60,10 +61,9 @@ ParameterOptimizer::~ParameterOptimizer() { ParameterOptimizer* ParameterOptimizer::create(OptimizationConfig* config) { CHECK(config != nullptr); - auto opt_config_ptr = (paddle::OptimizationConfig*)config->getRawPtr(); auto retOptimizer = new ParameterOptimizer(); retOptimizer->m->optimizer.reset( - paddle::ParameterOptimizer::create(*opt_config_ptr, false)); + paddle::ParameterOptimizer::create(config->m->getConfig(), false)); return retOptimizer; } diff --git a/paddle/api/Trainer.cpp b/paddle/api/Trainer.cpp index 95b578c8db..b61f36f740 100644 --- a/paddle/api/Trainer.cpp +++ b/paddle/api/Trainer.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "PaddleAPI.h" +#include "PaddleAPIPrivate.h" #include #include @@ -30,31 +31,17 @@ P_DECLARE_string(config); P_DECLARE_string(init_model_path); P_DECLARE_int32(start_pass); -struct TrainPassContext { - int64_t batchId; - int32_t batchSize; - real avgTestCost; - int64_t numAvgTests; - int passInnerId; - paddle::DataBatch data; - std::vector forwardOutput; -}; - struct TrainerPrivate : public paddle::Trainer { - void startTrain(); - void finishTrain(); - - void startTrainPass(); - void finishTrainPass(); - - bool _trainOneBatch(); - - bool _prepareBatchData(); - void _forwardOneBatch() throw(UnsupportError); - + bool _trainOneBatch(size_t batchSize); + bool forwardOneBatch(size_t batchSize); + void forwardOneDataBatch(const std::vector& inArgs); + void setBatchSize(size_t batchSize); + std::vector& getForwardOutput(); + + void startTestPeriod(); + void finishTestPeriod(); + void testOneDataBatch(const paddle::DataBatch& dataBatch); TrainerPrivate() : paddle::Trainer() {} - - TrainPassContext trainPassContext; }; Trainer::Trainer() : m(new TrainerPrivate()) { @@ -75,61 +62,76 @@ Trainer* Trainer::createByCommandLine() throw(IOError) { } } -void Trainer::startTrain() { m->startTrain(); } +Trainer::Trainer(TrainerConfig* config, GradientMachine* gm) + : m(new TrainerPrivate()) { + m->init(config->m->conf, /* testing= */false, gm ? gm->m->machine : nullptr); +} -void TrainerPrivate::startTrain() { - srand(this->config_->getConfig().start_pass() + 1); - this->dataProvider_->reset(); - this->trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); +Trainer* Trainer::create(TrainerConfig* config, GradientMachine* gm) + throw(IOError) +{ + auto retv = new Trainer(config, gm); + if (retv->m->getConfig().IsInitialized()) { + return retv; + } else { + retv->m->getConfig().CheckInitialized(); + throw IOError(); + } } -void Trainer::finishTrain() { m->finishTrain(); } +void Trainer::startTrain() { m->startTrain(); } -void TrainerPrivate::finishTrain() { - this->trainerInternal_.getGradientMachine()->finish(); -} +void Trainer::finishTrain() { m->finishTrain(); } void Trainer::startTrainPass() { m->startTrainPass(); } -void TrainerPrivate::startTrainPass() { - this->stats_.reset(); - this->trainPassContext.batchId = 0; - this->trainPassContext.batchSize = this->config_->getOptConfig().batch_size(); - this->trainPassContext.avgTestCost = 0; - this->trainPassContext.numAvgTests = 0; - this->trainPassContext.passInnerId = 0; - this->trainerInternal_.getParameterUpdater()->startPass(); - this->evaluator_->start(); -} - void Trainer::finishTrainPass() { m->finishTrainPass(); } -void TrainerPrivate::finishTrainPass() { - this->trainerInternal_.getGradientMachine()->onPassEnd(); - this->trainerInternal_.getParameterUpdater()->finishPass(); - evaluator_->finish(); +void Trainer::trainOneDataBatch(size_t batchSize, const Arguments& inArgs) { + paddle::DataBatch dataBatch; + dataBatch.getStreams() = inArgs.m->outputs; + dataBatch.setSize(batchSize); + m->trainOneDataBatch(dataBatch); } -void Trainer::setBatchSize(size_t batchSize) { - this->m->trainPassContext.batchSize = batchSize; +bool Trainer::trainOneBatch(size_t batchSize) { + return m->_trainOneBatch(batchSize); } -bool Trainer::trainOneBatch(size_t batchSize) { - if (batchSize == -1UL) { - this->setBatchSize(batchSize); +bool TrainerPrivate::_trainOneBatch(size_t batchSize) { + paddle::DataBatch dataBatch; + CHECK(dataProvider_) << "data_provider is not specified"; + int num = dataProvider_->getNextBatch(batchSize, &dataBatch); + if (num == 0) { + return false; } - return m->_trainOneBatch(); + trainOneDataBatch(dataBatch); + return false; } -bool TrainerPrivate::_trainOneBatch() { - if (this->_prepareBatchData()) { - return true; +void TrainerPrivate::startTestPeriod() { + if (!tester_) { + createTester(); } - this->trainerInternal_.trainOneBatch(this->trainPassContext.batchId, - this->trainPassContext.data); - return false; + tester_->startTestPeriod(); +} + +void Trainer::startTestPeriod() { m->startTestPeriod(); } + +void TrainerPrivate::testOneDataBatch(const paddle::DataBatch& dataBatch) { + tester_->testOneDataBatch(dataBatch, &forwardOutput_); +} + +void Trainer::testOneDataBatch(size_t batchSize, const Arguments& args) { + paddle::DataBatch dataBatch; + dataBatch.getStreams() = args.m->outputs; + dataBatch.setSize(batchSize); + m->testOneDataBatch(dataBatch); } +void TrainerPrivate::finishTestPeriod() { tester_->finishTestPeriod(); } +void Trainer::finishTestPeriod() { m->finishTestPeriod(); } + Matrix* Trainer::getLayerOutput(const std::string& layerName) { auto nn = std::dynamic_pointer_cast( this->m->getGradientMachine()); @@ -138,46 +140,37 @@ Matrix* Trainer::getLayerOutput(const std::string& layerName) { return Matrix::createByPaddleMatrixPtr(&m); } -bool Trainer::prepareBatchData(size_t batchSize) { - if (batchSize != -1UL) { - this->setBatchSize(batchSize); +void Trainer::forwardOneBatch(size_t batchSize) { m->forwardOneBatch(batchSize); } + +bool TrainerPrivate::forwardOneBatch(size_t batchSize) { + CHECK(dataProvider_) << "data_provider is not specified"; + paddle::DataBatch dataBatch; + int num = dataProvider_->getNextBatch(batchSize, &dataBatch); + if (num == 0) { + return false; } - return this->m->_prepareBatchData(); -} -bool TrainerPrivate::_prepareBatchData() { - int num = dataProvider_->getNextBatch(this->trainPassContext.batchSize, - &this->trainPassContext.data); - return num == 0; + forwardOneDataBatch(dataBatch.getStreams()); + return true; } -void Trainer::finishTrainOneBatch() { ++m->trainPassContext.batchId; } +void TrainerPrivate::forwardOneDataBatch( + const std::vector& inArgs) { -void Trainer::forwardOneBatch() throw(UnsupportError) { m->_forwardOneBatch(); } - -void TrainerPrivate::_forwardOneBatch() throw(UnsupportError) { - auto& dataBatch = this->trainPassContext.data; - - int64_t actualBatchSize = dataBatch.getSize(); - if (actualBatchSize == 0) { - return; - } - - const std::vector& inArgs = dataBatch.getStreams(); - std::vector& outArgs = this->trainPassContext.forwardOutput; - outArgs.clear(); - paddle::PassType passType = - this->trainerInternal_.getParameterUpdater()->startBatch(actualBatchSize); + std::vector& outArgs = forwardOutput_; if (config_->getOptConfig().use_sparse_remote_updater()) { - this->trainerInternal_.getGradientMachine()->prefetch(inArgs); - this->trainerInternal_.getParameterUpdater()->getParametersRemote(); + trainerInternal_.getGradientMachine()->prefetch(inArgs); + trainerInternal_.getParameterUpdater()->getParametersRemote(); } - this->trainerInternal_.getGradientMachine()->forward( - inArgs, &outArgs, passType); + trainerInternal_.getGradientMachine()->forward( + inArgs, &outArgs, paddle::PASS_TEST); +} + +Arguments* Trainer::getForwardOutput() { + return Arguments::createByPaddleArgumentVector(&m->getForwardOutput()); } -Arguments* Trainer::getNetworkOutput() { - return Arguments::createByPaddleArgumentVector( - &m->trainPassContext.forwardOutput); +std::vector& TrainerPrivate::getForwardOutput() { + return forwardOutput_; } diff --git a/paddle/api/test/run_tests.sh b/paddle/api/test/run_tests.sh index 1fc6fd5a8c..a4814f98f8 100755 --- a/paddle/api/test/run_tests.sh +++ b/paddle/api/test/run_tests.sh @@ -30,7 +30,7 @@ source .test_env/bin/activate pip --timeout 600 install ../../dist/*.whl -test_list="testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py" +test_list="testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py testTrainer.py" export PYTHONPATH=$PWD/../../../python/ diff --git a/paddle/api/test/testTrain.py b/paddle/api/test/testTrain.py index 7f79c2701e..7759118a3d 100644 --- a/paddle/api/test/testTrain.py +++ b/paddle/api/test/testTrain.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from py_paddle import swig_paddle, DataProviderWrapperConverter +from py_paddle import swig_paddle import paddle.trainer.config_parser -from paddle.trainer.PyDataProviderWrapper import DenseSlot, IndexSlot import numpy import util diff --git a/paddle/api/test/testTrainer.py b/paddle/api/test/testTrainer.py new file mode 100644 index 0000000000..da69a60f84 --- /dev/null +++ b/paddle/api/test/testTrainer.py @@ -0,0 +1,63 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer.config_parser import parse_config +from paddle.trainer.config_parser import logger +from py_paddle import swig_paddle +import util + +def main(): + trainer_config = parse_config( + "./testTrainConfig.py", "") + model = swig_paddle.GradientMachine.createFromConfigProto( + trainer_config.model_config) + trainer = swig_paddle.Trainer.create(trainer_config, model) + trainer.startTrain() + for train_pass in xrange(2): + trainer.startTrainPass() + num = 0 + cost = 0 + while True: # Train one batch + batch_size = 1000 + data, atEnd = util.loadMNISTTrainData(batch_size) + if atEnd: + break + trainer.trainOneDataBatch(batch_size, data) + outs = trainer.getForwardOutput() + cost += sum(outs[0]['value']) + num += batch_size + trainer.finishTrainPass() + logger.info('train cost=%f' % (cost / num)) + + trainer.startTestPeriod() + num = 0 + cost = 0 + while True: # Test one batch + batch_size = 1000 + data, atEnd = util.loadMNISTTrainData(batch_size) + if atEnd: + break + trainer.testOneDataBatch(batch_size, data) + outs = trainer.getForwardOutput() + cost += sum(outs[0]['value']) + num += batch_size + trainer.finishTestPeriod() + logger.info('test cost=%f' % (cost / num)) + + trainer.finishTrain() + + +if __name__ == '__main__': + swig_paddle.initPaddle("--use_gpu=0", "--trainer_count=1") + main() diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index 6d8f5da3e2..dd2e146d11 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -63,7 +63,7 @@ class SparseBinaryScanner(IScanner): def scan(self, dat): self.extend_cols(dat) - self.__rows__.append(len(dat) + self.__rows__[-1]) + self.__rows__.append(len(self.__cols__)) self.__height__ += 1 def extend_cols(self, dat): diff --git a/paddle/py_paddle/util.py b/paddle/py_paddle/util.py index e6cf2710ef..53f67a861e 100644 --- a/paddle/py_paddle/util.py +++ b/paddle/py_paddle/util.py @@ -79,6 +79,20 @@ class __ParameterCallbackWrapper__(swig_paddle.UpdateCallback): else: return __ParameterCallbackWrapper__(callback).__disown__() +def __arguments_to_numpy__(i, arg): + assert isinstance(arg, swig_paddle.Arguments) + value = arg.getSlotValue(i) + if value is not None: + assert isinstance(value, swig_paddle.Matrix) + value = value.copyToNumpyMat() + ids = arg.getSlotIds(i) + if ids is not None: + assert isinstance(ids, swig_paddle.IVector) + ids = ids.copyToNumpyArray() + return { + "value": value, + "id": ids + } def __monkeypatch_gradient_machine__(): """ @@ -88,20 +102,6 @@ def __monkeypatch_gradient_machine__(): swig_paddle.GradientMachine.loadFromConfigFile = \ staticmethod(loadGradientMachine) - def __arguments_to_numpy__(i, arg): - assert isinstance(arg, swig_paddle.Arguments) - value = arg.getSlotValue(i) - if value is not None: - assert isinstance(value, swig_paddle.Matrix) - value = value.copyToNumpyMat() - ids = arg.getSlotIds(i) - if ids is not None: - assert isinstance(ids, swig_paddle.IVector) - ids = ids.copyToNumpyArray() - return { - "value": value, - "id": ids - } def __matrix_to_numpy__(m): if isinstance(m, swig_paddle.Matrix): @@ -126,7 +126,7 @@ def __monkeypatch_gradient_machine__(): :type paramTypes: list of int :return: paddle.GradientMachine """ - assert isinstance(protoObj, paddle.proto.ModelConfig_pb2.ModelConfig) + assert isinstance(protoObj, paddle.proto.ModelConfig) return swig_paddle.GradientMachine.createByConfigProtoStr( protoObj.SerializeToString(), createMode, paramTypes) @@ -460,13 +460,29 @@ def __monkey_patch_protobuf_objects__(): """ assert isinstance(protoObj, - paddle.proto.TrainerConfig_pb2.OptimizationConfig) + paddle.proto.OptimizationConfig) return swig_paddle.OptimizationConfig.createFromProtoString( protoObj.SerializeToString()) swig_paddle.OptimizationConfig.createFromProto = staticmethod( OptimizationConfig_createFromProto) + def TrainerConfig_createFromProto(protoObj): + """ + Create a new paddle.TrainerConfig from + proto.OptimizationConfig + + :param protoObj: proto.TrainerConfig + :return: paddle.TrainerConfig + """ + assert isinstance(protoObj, + paddle.proto.TrainerConfig) + return swig_paddle.TrainerConfig.createFromProtoString( + protoObj.SerializeToString()) + + swig_paddle.TrainerConfig.createFromProto = staticmethod( + TrainerConfig_createFromProto) + def __monkey_patch_parameter__(): def getBufs(self): @@ -483,9 +499,66 @@ def __monkey_patch_parameter__(): swig_paddle.Parameter.getBufs = getBufs +def __monkey_patch_trainer__(): + swig_paddle.Trainer.__create__ = staticmethod(swig_paddle.Trainer.create) + + def Trainer_create(config, model=None): + """ + Create a trainer for model with TrainerCOnfig trainer_config + trainer_config.model_config will be ignored when model is supplied. + Trainer.trainOneBatch() and Trainer.forwardOneBatch() can be used only + when trainer_config.data_config is set. + + A typical usage for Trainer is: + .. code-block:: python + trainer = Trainer.create(trainer_config, model) + for p in xrange(num_passes) + while True: + data = get_next_batch(batch_size) + if not data: + break + trainer.trainOneDataBatch(batch_size, data) + trainer.finishTrainPass() + trainer.finishTrain() + + The trainer will take care of logging, model saving, distributed + training, etc. + + :param config: trainer configuration + :type config: paddle.proto.TrainerConfig + :param model: the model to be trained + :type model: swig_paddle.GradientMachine + :return: a trainer + :rtype swig_paddle.Trainer + + """ + assert isinstance(config, paddle.proto.TrainerConfig) + if model is not None: + assert isinstance(model, swig_paddle.GradientMachine) + return swig_paddle.Trainer.__create__( + swig_paddle.TrainerConfig.createFromProto(config), model) + swig_paddle.Trainer.create = staticmethod(Trainer_create) + + swig_paddle.Trainer.__getForwardOutput__ = \ + swig_paddle.Trainer.getForwardOutput + + def getForwardOutput(self): + """ + Get the netword outputs from the previous trainOneBatch(), + trainOneDataBatch(), testOneDataPatch(), or forwardOneBatch() call. + + :return: list of dictionary with keys ['id', 'value'], each value is a + numpy.ndarray. + """ + outArgs = self.__getForwardOutput__() + return [__arguments_to_numpy__(i, outArgs) for i in xrange( + outArgs.getSlotNum())] + + swig_paddle.Trainer.getForwardOutput = getForwardOutput + def monkeypatches(): patches = [__monkeypatch_init_paddle__, __monkeypatch_gradient_machine__, __monkey_patch_protobuf_objects__, - __monkey_patch_parameter__] + __monkey_patch_parameter__, __monkey_patch_trainer__] for patch in patches: patch() diff --git a/paddle/trainer/Tester.cpp b/paddle/trainer/Tester.cpp index ccf06e1d84..b1bb75654a 100644 --- a/paddle/trainer/Tester.cpp +++ b/paddle/trainer/Tester.cpp @@ -71,24 +71,36 @@ Tester::Tester(const std::shared_ptr &config, parameterUpdater_)); } +void Tester::startTestPeriod() { + testEvaluator_->start(); + testContext_.cost = 0; + testContext_.numSamples = 0; + + parameterUpdater_->apply(); + if (intconfig_->prevBatchState) { + gradientMachine_->getState(*intconfig_->trainState); + gradientMachine_->setState(*intconfig_->testState); + } +} + +void Tester::testOneDataBatch( + const DataBatch& dataBatch, std::vector* outArgs) { + testContext_.cost += forwardOneBatch( + dataBatch, testEvaluator_.get(), outArgs); + testContext_.numSamples += dataBatch.getSize(); +} + void Tester::testOnePeriod() { DataBatch dataBatch; int64_t batchSize = config_->getOptConfig().batch_size(); - testEvaluator_->start(); - real cost = 0; - int64_t numSamples = 0; bool testAllData = intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod; - int batches = testAllData ? std::numeric_limits::max() : intconfig_->testPeriod; - parameterUpdater_->apply(); - if (intconfig_->prevBatchState) { - gradientMachine_->getState(*intconfig_->trainState); - gradientMachine_->setState(*intconfig_->testState); - } + std::vector outArgs; + startTestPeriod(); for (int i = 0; i < batches; ++i) { int num = testDataProvider_->getNextBatch(batchSize, &dataBatch); if (num == 0) { @@ -102,13 +114,17 @@ void Tester::testOnePeriod() { num = testDataProvider_->getNextBatch(batchSize, &dataBatch); } } - cost += testOneBatch(dataBatch, testEvaluator_.get()); - numSamples += num; + testOneDataBatch(dataBatch, &outArgs); } +} + +void Tester::finishTestPeriod() { testEvaluator_->finish(); - CHECK_GT(numSamples, 0) << "There is no samples in your test batch. Possibly " - "wrong implementation of DataProvidor.reset()"; - LOG(INFO) << " Test samples=" << numSamples << " cost=" << cost / numSamples + CHECK_GT(testContext_.numSamples, 0) + << "There is no samples in your test batch. Possibly " + "wrong implementation of DataProvidor.reset()"; + LOG(INFO) << " Test samples=" << testContext_.numSamples + << " cost=" << testContext_.cost / testContext_.numSamples << " Eval: " << *testEvaluator_; parameterUpdater_->restore(); if (intconfig_->prevBatchState) { @@ -128,9 +144,11 @@ int64_t Tester::testOneBatchById(int64_t batchId) { return 0; } + std::vector outArgs; + stats_ += std::pair{ actualBatchSize, - testOneBatch(dataBatch, testEvaluator_.get())}; + forwardOneBatch(dataBatch, testEvaluator_.get(), &outArgs)}; if (((batchId + 1) % intconfig_->logPeriod) == 0) { LOG(INFO) << " Batch=" << batchId + 1 << " " << stats_.getStats(false); @@ -139,7 +157,10 @@ int64_t Tester::testOneBatchById(int64_t batchId) { return actualBatchSize; } -real Tester::testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator) { +real Tester::forwardOneBatch(const DataBatch &dataBatch, + Evaluator *evaluator, + std::vector* pOutArgs) { + auto& outArgs = *pOutArgs; const std::vector& inArgs = dataBatch.getStreams(); if (intconfig_->loadsaveParametersInPserver) { REGISTER_TIMER("prefetch"); @@ -148,12 +169,11 @@ real Tester::testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator) { true /*after apply*/); } - std::vector outArgs; gradientMachine_->forward(inArgs, &outArgs, PASS_TEST); // write features if set this flag and outArgs is not empty std::string featFile = intconfig_->featFile; - if (!featFile.empty() && !outArgs.empty()) { + if (!featFile.empty() && outArgs.empty()) { size_t numOutputs = outArgs.size(); std::vector featMatrices; featMatrices.resize(numOutputs); diff --git a/paddle/trainer/Tester.h b/paddle/trainer/Tester.h index 9663b8def9..671ffc5220 100644 --- a/paddle/trainer/Tester.h +++ b/paddle/trainer/Tester.h @@ -68,6 +68,10 @@ public: * is training at same time. */ void testOnePeriod(); + void startTestPeriod(); + void finishTestPeriod(); + void testOneDataBatch(const DataBatch& dataBatch, + std::vector* outArgs); /** * Test for given data batch. @@ -75,7 +79,9 @@ public: * @param evaluator Evaluator * @return cost */ - real testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator); + real forwardOneBatch(const DataBatch& dataBatch, + Evaluator* evaluator, + std::vector* outArgs); /** @@ -99,6 +105,10 @@ protected: std::ofstream os_; std::vector cpuMat_; std::vector cpuVec_; + struct { + int64_t numSamples; + real cost; + } testContext_; private: /** diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 275150e12d..04535849eb 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -196,7 +196,8 @@ void Trainer::init(const std::shared_ptr &config, if (!dataProvider_ && config_->hasDataConfig()) { dataProvider_.reset(DataProvider::create(*config_, *config_, gpuData)); } - if (dataProvider_) { + if (!testDataProvider_) { + // No evaluator_ if there is testDataProvider but no dataProvider. evaluator_.reset(trainerInternal_.getGradientMachine()->makeEvaluator()); currentEvaluator_.reset( trainerInternal_.getGradientMachine()->makeEvaluator()); @@ -215,10 +216,7 @@ void Trainer::init(const std::shared_ptr &config, DataProvider::create(config_->getTestDataConfig(), *config_, gpuData)); } if (testDataProvider_) { - tester_.reset(new Tester(config_, createTesterConfig(), - trainerInternal_.getGradientMachine(), - trainerInternal_.getParameterUpdater(), - testDataProvider_)); + createTester(); } if (!testing && @@ -258,34 +256,25 @@ void Trainer::init(const std::shared_ptr &config, } } - // set current evaluator and evalutor trainerInternal_.setCurrentEvaluator(currentEvaluator_.get()); trainerInternal_.setEvaluator(evaluator_.get()); } void Trainer::train(size_t numPasses) { - srand(config_->getConfig().start_pass() + 1); - dataProvider_->reset(); - - if (this->testDataProvider_) { - this->testDataProvider_->reset(); - } - - trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); - + startTrain(); for (size_t i = 0; i < numPasses; ++i) { if (IGradientMachineMode::trainWholeDataInOneBatch(mode_)) { trainOnePassBatch(config_->getConfig().start_pass() + i); } else { - trainOnePass(config_->getConfig().start_pass() + i); + trainOnePass(); } if (i < numPasses - 1) { dataProvider_->reset(); } } - trainerInternal_.getGradientMachine()->finish(); + finishTrain(); } @@ -387,13 +376,30 @@ real Trainer::checkGradient() { return maxDiff; } -void Trainer::trainOnePass(int passId) { - this->stats_->reset(); - int64_t batchId = 0; - int32_t batchSize = config_->getOptConfig().batch_size(); - real avgTestCost = 0; - int64_t numAvgTests = 0; - int passInnerId = 1; +void Trainer::startTrain() { + trainPassContext_.passId = config_->getConfig().start_pass(); + srand(config_->getConfig().start_pass() + 1); + if (dataProvider_) { + dataProvider_->reset(); + } + + if (this->testDataProvider_) { + this->testDataProvider_->reset(); + } + + trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); +} + +void Trainer::finishTrain() { + trainerInternal_.getGradientMachine()->finish(); +} + +void Trainer::startTrainPass() { + stats_->reset(); + trainPassContext_.batchId = 0; + trainPassContext_.avgTestCost = 0; + trainPassContext_.numAvgTests = 0; + trainPassContext_.passInnerId = 1; trainerInternal_.getParameterUpdater()->startPass(); evaluator_->start(); @@ -401,81 +407,83 @@ void Trainer::trainOnePass(int passId) { trainerInternal_.getGradientMachine()->resetState(); trainerInternal_.getGradientMachine()->getState(testState_); } - while (true) { - DataBatch dataBatch; - - int num = 0; - { - REGISTER_TIMER("getTrainBatch"); - num = dataProvider_->getNextBatch(batchSize, &dataBatch); - } - if (num == 0) break; +} - if (averageEvaluator_) { - int64_t mod = batchId % FLAGS_average_test_period; - if (mod >= FLAGS_average_test_period - FLAGS_log_period) { - if (mod == FLAGS_average_test_period - FLAGS_log_period) { - averageEvaluator_->start(); - } - trainerInternal_.getParameterUpdater()->apply(); - if (FLAGS_prev_batch_state) { - trainerInternal_.getGradientMachine()->getState(trainState_); - } - avgTestCost += - tester_->testOneBatch(dataBatch, averageEvaluator_.get()); - if (FLAGS_prev_batch_state) { - trainerInternal_.getGradientMachine()->setState(trainState_); - } - numAvgTests += num; - trainerInternal_.getParameterUpdater()->restore(); +void Trainer::trainOneDataBatch(DataBatch& dataBatch) { + int num = dataBatch.getSize(); + if (averageEvaluator_) { + int64_t mod = trainPassContext_.batchId % FLAGS_average_test_period; + if (mod >= FLAGS_average_test_period - FLAGS_log_period) { + if (mod == FLAGS_average_test_period - FLAGS_log_period) { + averageEvaluator_->start(); } + trainerInternal_.getParameterUpdater()->apply(); + if (FLAGS_prev_batch_state) { + trainerInternal_.getGradientMachine()->getState(trainState_); + } + trainPassContext_.avgTestCost += + tester_->forwardOneBatch( + dataBatch, averageEvaluator_.get(), &forwardOutput_); + if (FLAGS_prev_batch_state) { + trainerInternal_.getGradientMachine()->setState(trainState_); + } + trainPassContext_.numAvgTests += num; + trainerInternal_.getParameterUpdater()->restore(); } - { - REGISTER_TIMER("TrainBatch"); - trainerInternal_.trainOneBatch(batchId, dataBatch); - } + } + { + REGISTER_TIMER("TrainBatch"); + trainerInternal_.trainOneBatch( + trainPassContext_.batchId, dataBatch, &forwardOutput_); + } - if (averageEvaluator_ && - batchId % FLAGS_average_test_period == FLAGS_average_test_period - 1) { - averageEvaluator_->finish(); - LOG(INFO) << " Averaged parameter:" - << " cost=" << avgTestCost / numAvgTests - << " Eval: " << *averageEvaluator_; - numAvgTests = 0; - avgTestCost = 0; - } + if (averageEvaluator_ && + trainPassContext_.batchId % FLAGS_average_test_period + == FLAGS_average_test_period - 1) { + averageEvaluator_->finish(); + LOG(INFO) << " Averaged parameter:" + << " cost=" << trainPassContext_.avgTestCost + / trainPassContext_.numAvgTests + << " Eval: " << *averageEvaluator_; + trainPassContext_.numAvgTests = 0; + trainPassContext_.avgTestCost = 0; + } - ++batchId; + ++trainPassContext_.batchId; - if (batchId % FLAGS_log_period == 0) { - FOR_TIMING(globalStat.setThreadInfo(true)); - FOR_TIMING(globalStat.printAllStatus()); - FOR_TIMING(globalStat.reset()); - } + if (trainPassContext_.batchId % FLAGS_log_period == 0) { + FOR_TIMING(globalStat.setThreadInfo(true)); + FOR_TIMING(globalStat.printAllStatus()); + FOR_TIMING(globalStat.reset()); + } - if (testDataProvider_ && FLAGS_test_period > 0 && - batchId % FLAGS_test_period == 0) { - tester_->testOnePeriod(); - } + if (testDataProvider_ && FLAGS_test_period > 0 && + trainPassContext_.batchId % FLAGS_test_period == 0) { + tester_->testOnePeriod(); + } - if (FLAGS_saving_period_by_batches > 0 && - batchId > FLAGS_saving_period_by_batches * passInnerId && - 0 == FLAGS_trainer_id) { - trainerInternal_.getParameterUpdater()->catchUpWith(); - if (testDataProvider_) { - tester_->testOnePeriod(); - } - paramUtil_->saveParametersOnePass(passId, passInnerId); - ++passInnerId; + if (FLAGS_saving_period_by_batches > 0 && + trainPassContext_.batchId + > FLAGS_saving_period_by_batches * trainPassContext_.passInnerId && + 0 == FLAGS_trainer_id) { + trainerInternal_.getParameterUpdater()->catchUpWith(); + if (testDataProvider_) { + tester_->testOnePeriod(); } + paramUtil_->saveParametersOnePass( + trainPassContext_.passId, trainPassContext_.passInnerId); + ++trainPassContext_.passInnerId; } +} - if (batchId == 0) { +void Trainer::finishTrainPass() { + if (trainPassContext_.batchId == 0) { // This means no more data from DataProvider return; } - trainerInternal_.finishTrainPass(passId, batchId); + trainerInternal_.finishTrainPass( + trainPassContext_.passId, trainPassContext_.batchId); FOR_TIMING(globalStat.setThreadInfo(true)); FOR_TIMING(globalStat.printAllStatus()); @@ -485,9 +493,30 @@ void Trainer::trainOnePass(int passId) { tester_->testOnePeriod(); } - if (passId % FLAGS_saving_period == 0 && FLAGS_trainer_id == 0) { - paramUtil_->saveParametersOnePass(passId); + if (trainPassContext_.passId % FLAGS_saving_period == 0 + && FLAGS_trainer_id == 0) { + paramUtil_->saveParametersOnePass(trainPassContext_.passId); } + ++trainPassContext_.passId; +} + +void Trainer::trainOnePass() { + startTrainPass(); + size_t batchSize = config_->getOptConfig().batch_size(); + while (true) { + DataBatch dataBatch; + + int num = 0; + { + REGISTER_TIMER("getTrainBatch"); + num = dataProvider_->getNextBatch(batchSize, &dataBatch); + } + if (num == 0) break; + CHECK_EQ(num, dataBatch.getSize()); + trainOneDataBatch(dataBatch); + } + + finishTrainPass(); } void Trainer::trainOnePassBatch(int passId) { @@ -582,6 +611,13 @@ void Trainer::clearGradient() { int Trainer::getBatchSize() { return config_->getOptConfig().batch_size(); } +void Trainer::createTester() { + tester_.reset(new paddle::Tester(config_, createTesterConfig(), + trainerInternal_.getGradientMachine(), + trainerInternal_.getParameterUpdater(), + testDataProvider_)); +} + void Trainer::test() { tester_->test(); } diff --git a/paddle/trainer/Trainer.h b/paddle/trainer/Trainer.h index 9bfd6d107a..4f4811a139 100644 --- a/paddle/trainer/Trainer.h +++ b/paddle/trainer/Trainer.h @@ -94,6 +94,11 @@ public: */ real checkGradient(); + void startTrain(); + void finishTrain(); + void startTrainPass(); + void finishTrainPass(); + void trainOneDataBatch(DataBatch& dataBatch); /** * given a dataBatch and the current parameter value @@ -144,11 +149,11 @@ public: protected: /** - * Train one pass of data. passId starts from 0 + * Train one pass of data. * * SGD Method. */ - void trainOnePass(int passId); + void trainOnePass(); /** * Train one pass in one batch. @@ -161,6 +166,8 @@ protected: */ void clearGradient(); + void createTester(); + private: std::unique_ptr createTesterConfig(); @@ -173,6 +180,17 @@ protected: MachineState trainState_; MachineState testState_; + struct TrainPassContext { + int64_t batchId; + real avgTestCost; + int64_t numAvgTests; + int passId; + int passInnerId; + }; + std::vector forwardOutput_; + + TrainPassContext trainPassContext_; + std::unique_ptr evaluator_; std::unique_ptr currentEvaluator_; std::unique_ptr averageEvaluator_; diff --git a/paddle/trainer/TrainerInternal.cpp b/paddle/trainer/TrainerInternal.cpp index 6029a4b2c1..e23e42927c 100644 --- a/paddle/trainer/TrainerInternal.cpp +++ b/paddle/trainer/TrainerInternal.cpp @@ -55,6 +55,8 @@ void TrainerInternal::init(const std::shared_ptr &config, gradientMachine_ = gradientMachine; if (!gradientMachine) { + CHECK(config_->getConfig().has_model_config()) + << "Missing model_config in trainer_config"; gradientMachine_.reset(GradientMachine::create( config_->getConfig().model_config(), intconfig_->mode, parameterUpdater_->getParameterTypes())); @@ -62,7 +64,8 @@ void TrainerInternal::init(const std::shared_ptr &config, } void TrainerInternal::trainOneBatch(int64_t batchId, - const DataBatch& dataBatch) { + const DataBatch& dataBatch, + std::vector* outArgs) { // true means updating parameter whenever gradient is ready during backward() bool doPipelineUpdate = (intconfig_->mode != GradientMachine::kSgdSparseCpuTraining) && @@ -84,7 +87,6 @@ void TrainerInternal::trainOneBatch(int64_t batchId, } const std::vector& inArgs = dataBatch.getStreams(); - std::vector outArgs; PassType passType = parameterUpdater_->startBatch(actualBatchSize); @@ -114,7 +116,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId, timer.start(); #endif REGISTER_TIMER("forwardBackward"); - forwardBackwardBatch(inArgs, outArgs, passType, updateCallback, + forwardBackwardBatch(inArgs, *outArgs, passType, updateCallback, doPipelineUpdate); #ifndef PADDLE_DISABLE_TIMER timer.stop(); @@ -132,7 +134,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId, real cost = 0; { REGISTER_TIMER("sumCost"); - cost = Argument::sumCosts(outArgs); + cost = Argument::sumCosts(*outArgs); } if (batchId % intconfig_->log_period == 0) { diff --git a/paddle/trainer/TrainerInternal.h b/paddle/trainer/TrainerInternal.h index 17011c4d2e..3a53aa1d17 100644 --- a/paddle/trainer/TrainerInternal.h +++ b/paddle/trainer/TrainerInternal.h @@ -81,7 +81,9 @@ public: * @param batchId current batch id * @param dataBatch data for the batch */ - void trainOneBatch(int64_t batchId, const DataBatch& dataBatch); + void trainOneBatch(int64_t batchId, + const DataBatch& dataBatch, + std::vector* outArgs); /** * showParameterStats diff --git a/proto/TrainerConfig.proto.m4 b/proto/TrainerConfig.proto.m4 index a42ff88d54..3b0e24f90b 100644 --- a/proto/TrainerConfig.proto.m4 +++ b/proto/TrainerConfig.proto.m4 @@ -130,7 +130,7 @@ message OptimizationConfig { }; message TrainerConfig { - required ModelConfig model_config = 1; + optional ModelConfig model_config = 1; optional DataConfig data_config = 2; required OptimizationConfig opt_config = 3; optional DataConfig test_data_config = 4; diff --git a/python/paddle/proto/__init__.py b/python/paddle/proto/__init__.py index 7f9e87eee6..cd6a59ecbb 100644 --- a/python/paddle/proto/__init__.py +++ b/python/paddle/proto/__init__.py @@ -12,3 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +from paddle.proto.TrainerConfig_pb2 import OptimizationConfig, TrainerConfig +from paddle.proto.ModelConfig_pb2 import ModelConfig From 45dca19a8d8744127331af0e0fd326e4afc740a2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 28 Oct 2016 11:07:31 +0800 Subject: [PATCH 02/11] Change contribute to paddle to fit new branching model (#275) * Change contribute to paddle to fit new branching model --- doc/build/contribute_to_paddle.md | 41 +++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/doc/build/contribute_to_paddle.md b/doc/build/contribute_to_paddle.md index bbdbb4d422..a9ab69c5f4 100644 --- a/doc/build/contribute_to_paddle.md +++ b/doc/build/contribute_to_paddle.md @@ -4,7 +4,7 @@ We sincerely appreciate your contributions. You can use fork and pull request workflow to merge your code. ## Code Requirements -- Your code mush be fully documented by +- Your code must be fully documented by [doxygen](http://www.stack.nl/~dimitri/doxygen/) style. - Make sure the compiler option WITH\_STYLE\_CHECK is on and the compiler passes the code style check. @@ -20,16 +20,30 @@ It's just that simple. ## Clone +Paddle is currently using [git-flow branching model](http://nvie.com/posts/a-successful-git-branching-model/). +The **develop** is the main branch, and other user's branches are feature branches. + Once you've created a fork, you can use your favorite git client to clone your repo or just head straight to the command line: ```shell # Clone your fork to your local machine -git clone https://github.com/USERNAME/Paddle.git +git clone --branch develop https://github.com/USERNAME/Paddle.git +``` +If your repository doesn't contain **develop** branch, just create it by your own. + +```shell +git clone https://github.com/USERNAME/Paddle.git Paddle +cd Paddle +git checkout -b develop # create develop branch. +git remote add upstream https://github.com/baidu/Paddle.git # add upstream to baidu/Paddle +git pull upstream develop # update to upstream ``` + Then you can start to develop by making a local developement branch + ```shell -git checkout -b MY_COOL_STUFF_BRANCH origin/master +git checkout -b MY_COOL_STUFF_BRANCH ``` ## Commit @@ -41,7 +55,7 @@ Commit your changes by following command lines: git status # add modified files git add xx -git commit -m "commit info" +env EDITOR=vim git commit # You can write your comments by vim/nano/emacs. ``` The first line of commit infomation is the title. The second and later lines are the details if any. @@ -63,7 +77,7 @@ git remote -v Update your fork with the latest upstream changes: ```shell -git pull --rebase upstream HEAD +git pull --rebase upstream develop ``` If there are no unique commits locally, git will simply perform a fast-forward. @@ -76,7 +90,7 @@ Now, your local master branch is up-to-date with everything modified upstream. ```shell # push to your repository in Github -git push origin HEAD +git push -u origin MY_COOL_STUFF_BRANCH # create remote branch MY_COOL_STUFF_BRANCH to origin. ``` ## Pull Request @@ -93,13 +107,24 @@ of conflict, you need to do the update manually. You need to do the following on your local repository: ```shell git checkout MY_COOL_STUFF_BRANCH -git pull --rebase upstream HEAD +git pull upstream develop # You may need to resolve the conflict according to the git prompt. # Make and test your code. -git push -f origin HEAD +git push origin MY_COOL_STUFF_BRANCH ``` Now your Pull Request is updated with the latest version. ## Revise your pull request When you revise your pull request according to reviewer's comments, please use 'git commit' instead of 'git commit --amend' to commit your changes so that the reviewers can see the difference between the new pull requrest and the old pull request. + +The possible commands are + +```shell +git checkout MY_COOL_STUFF_BRANCH +git pull upstream develop # update local to newest code base. +# May be some conflicts will occured. +# And develop your cool stuff +env EDITOR=vim git commit # add your revise log +git push origin MY_COOL_STUFF_BRANCH +``` From 0e1a22d0fcc80ea3931b7006c891b844b07f29a2 Mon Sep 17 00:00:00 2001 From: backyes Date: Fri, 28 Oct 2016 12:27:40 +0800 Subject: [PATCH 03/11] set test_period default value to 0 (#279) --- paddle/trainer/Trainer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 04535849eb..7fc48dd1fb 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -40,7 +40,7 @@ limitations under the License. */ #include "TrainerConfigHelper.h" P_DEFINE_string(config, "", "Trainer config file"); -P_DEFINE_int32(test_period, 1000, +P_DEFINE_int32(test_period, 0, "Run test every so many train batches." " 0 for testing after each pass." " If not 0, test log_period batches." From ca5a5ec480db806674043311b2c59fd318ad4f41 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Fri, 28 Oct 2016 12:51:35 +0800 Subject: [PATCH 04/11] Make Paddle --save_dir support a directory name (#277) * Also fix #243 --- paddle/trainer/ParamUtil.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/trainer/ParamUtil.cpp b/paddle/trainer/ParamUtil.cpp index dae8b44b6d..bb309a5497 100644 --- a/paddle/trainer/ParamUtil.cpp +++ b/paddle/trainer/ParamUtil.cpp @@ -89,6 +89,9 @@ void ParameterUtil::saveParameters(int passId, int passInnerId) { } std::string basePath = config_->getSaveDir(); + if (basePath.find('/') == std::string::npos) { + basePath = "./" + basePath; + } mkDirRecursively(basePath.c_str()); std::string saveDir = path::join(basePath, buf); From fc9ca53ae421ec2eac5648ec8de6a1e7a993bfd0 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Fri, 28 Oct 2016 12:54:13 +0800 Subject: [PATCH 05/11] fix interface bug of block_expand_layer and add unittest (#265) * fix interface bug of block_expand_layer and add unittest * auto compute num_channels * default value of num_channels is None * adjust input order of block_expand --- .../paddle/trainer_config_helpers/layers.py | 34 ++++++++++--------- .../tests/configs/check.md5 | 2 +- .../tests/configs/test_maxout.py | 21 +++++++++++- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index f8c32dc91f..7df9108ae8 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -54,7 +54,7 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", 'cross_entropy_with_selfnorm', 'cross_entropy', 'multi_binary_label_cross_entropy', 'rank_cost', 'lambda_cost', 'huber_cost', - # 'block_expand_layer', # TODO(yuyang18): this layer is not correct + 'block_expand_layer', 'maxout_layer', 'out_prod_layer', 'print_layer' ] @@ -3284,18 +3284,18 @@ convex_comb_layer = linear_comb_layer @wrap_name_default() @layer_support() def block_expand_layer(input, - channel=0, block_x=0, block_y=0, stride_x=0, stride_y=0, padding_x=0, padding_y=0, + num_channels=None, name=None, layer_attr=None): """ Expand feature map to minibatch matrix. - - matrix width is: block_y * block_x * channel + - matrix width is: block_y * block_x * num_channels - matirx height is: outputH * outputW .. math:: @@ -3307,7 +3307,7 @@ def block_expand_layer(input, The expand method is the same with ExpandConvLayer, but saved the transposed value. After expanding, output.sequenceStartPositions will store timeline. The number of time steps are outputH * outputW and the dimension of each - time step is block_y * block_x * channel. This layer can be used after + time step is block_y * block_x * num_channels. This layer can be used after convolution neural network, and before recurrent neural network. The simple usage is: @@ -3315,7 +3315,7 @@ def block_expand_layer(input, .. code-block:: python block_expand = block_expand_layer(input, - channel=128, + num_channels=128, stride_x=1, stride_y=1, block_x=1, @@ -3323,8 +3323,8 @@ def block_expand_layer(input, :param input: The input layer. :type input: LayerOutput - :param channel: The channel number of input layer. - :type channel: int + :param num_channels: The channel number of input layer. + :type num_channels: int|None :param block_x: The width of sub block. :type block_x: int :param block_y: The width of sub block. @@ -3344,16 +3344,18 @@ def block_expand_layer(input, :return: LayerOutput object. :rtype: LayerOutput """ + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters Layer(name=name, - input=Input(input.name, - block_expand=BlockExpand(channels=channel, - block_x=block_x, - block_y=block_y, - stride_x=stride_x, - stride_y=stride_y, - padding_x=padding_x, - padding_y=padding_y) - ), + inputs=Input(input.name, + block_expand=BlockExpand(channels=num_channels, + block_x=block_x, + block_y=block_y, + stride_x=stride_x, + stride_y=stride_y, + padding_x=padding_x, + padding_y=padding_y)), type=LayerType.BLOCK_EXPAND, **ExtraLayerAttribute.to_kwargs(layer_attr) ) diff --git a/python/paddle/trainer_config_helpers/tests/configs/check.md5 b/python/paddle/trainer_config_helpers/tests/configs/check.md5 index 88ce5c129e..d1b22b3490 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/check.md5 +++ b/python/paddle/trainer_config_helpers/tests/configs/check.md5 @@ -12,7 +12,7 @@ a5d9259ff1fd7ca23d0ef090052cb1f2 last_first_seq.protostr 8bb44e1e5072d0c261572307e7672bda test_grumemory_layer.protostr 1f3510672dce7a9ed25317fc58579ac7 test_hsigmoid.protostr d350bd91a0dc13e854b1364c3d9339c6 test_lstmemory_layer.protostr -6fa59551808ee7012bbd24f757e782d2 test_maxout.protostr +5433ed33d4e7414eaf658f2a55946186 test_maxout.protostr 251a948ba41c1071afcd3d9cf9c233f7 test_ntm_layers.protostr e6ff04e70aea27c7b06d808cc49c9497 test_print_layer.protostr 2a75dd33b640c49a8821c2da6e574577 test_rnn_group.protostr diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_maxout.py b/python/paddle/trainer_config_helpers/tests/configs/test_maxout.py index 079e2cf4c4..7c1fb04766 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/test_maxout.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_maxout.py @@ -25,6 +25,25 @@ pool = img_pool_layer(input=maxout, stride=2, pool_type=MaxPooling()) -fc = fc_layer(input=pool, size=384, bias_attr=False) +conv2 = img_conv_layer(input=pool, + filter_size = 3, + num_channels=32, + num_filters=128, + padding=1, + act=LinearActivation(), + bias_attr=True) + +maxout2 = maxout_layer(input=conv, + num_channels=128, + groups=4) + +block = block_expand_layer(input=maxout, + num_channels=32, + stride_x=1, + stride_y=1, + block_x=1, + block_y=6) + +fc = fc_layer(input=block, size=384, bias_attr=False) outputs(fc) From fa24cbdbe1b18649a9140152b971473ba99079f7 Mon Sep 17 00:00:00 2001 From: backyes Date: Fri, 28 Oct 2016 13:10:04 +0800 Subject: [PATCH 06/11] Support empty Param Block in ParameterSever (#244) * Because in cluster maybe use a lot machine to train a model, and some parameter size could be too small for ParameterServer. Then some of pservers could not have any ParamBlock. * Also, because ports_num or ports_num_for_sparse is too large, then give a warning in runtime. --- paddle/pserver/ParameterServer2.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/paddle/pserver/ParameterServer2.cpp b/paddle/pserver/ParameterServer2.cpp index 8f72c1988d..c8f37d0bf4 100644 --- a/paddle/pserver/ParameterServer2.cpp +++ b/paddle/pserver/ParameterServer2.cpp @@ -264,6 +264,15 @@ void ParameterServer2::setParameter(const SendParameterRequest& request, std::vector blockIds; blockIds.reserve(request.blocks_size()); int bufferIndex = 0; + + if (!request.blocks().size()) { + LOG(WARNING) + << "--ports_num or --ports_num_for_sparse might be too large, " + << "or total dense parameter size or sparse parameters size " + << "might be too small, this psever doesn't store any parameter."; + return; + } + for (const auto& block : request.blocks()) { /// block size for parameter(e.g. 128 for sparse row, 1K for dense) uint64_t blockSize = getParameterConfig(block).parameter_block_size(); From 212d3391029b49526eb617aa30c49119034c6724 Mon Sep 17 00:00:00 2001 From: zhouxiao-coder Date: Mon, 31 Oct 2016 16:36:17 +0800 Subject: [PATCH 07/11] Adding an introduction doc for Paddle to implement simplest linear regression. --- demo/introduction/README.md | 4 ++ demo/introduction/dataprovider.py | 24 +++++++ demo/introduction/evaluate_model.py | 36 ++++++++++ demo/introduction/train.sh | 21 ++++++ demo/introduction/trainer_config.py | 32 +++++++++ doc/index.md | 1 + doc/introduction/index.md | 101 ++++++++++++++++++++++++++ doc/introduction/parameters.png | 1 + doc_cn/index.rst | 2 +- doc_cn/introduction/index.md | 105 ++++++++++++++++++++++++++++ doc_cn/introduction/parameters.png | Bin 0 -> 44469 bytes 11 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 demo/introduction/README.md create mode 100644 demo/introduction/dataprovider.py create mode 100755 demo/introduction/evaluate_model.py create mode 100755 demo/introduction/train.sh create mode 100644 demo/introduction/trainer_config.py create mode 100644 doc/introduction/index.md create mode 120000 doc/introduction/parameters.png create mode 100644 doc_cn/introduction/index.md create mode 100644 doc_cn/introduction/parameters.png diff --git a/demo/introduction/README.md b/demo/introduction/README.md new file mode 100644 index 0000000000..bebf1d090d --- /dev/null +++ b/demo/introduction/README.md @@ -0,0 +1,4 @@ +This folder contains scripts used in PaddlePaddle introduction. +- use `bash train.sh` to train a simple linear regression model +- use `python evaluate_model.py` to read model parameters. You can see that `w` and `b` are very close to [2, 0.3]. + diff --git a/demo/introduction/dataprovider.py b/demo/introduction/dataprovider.py new file mode 100644 index 0000000000..be8c0bc891 --- /dev/null +++ b/demo/introduction/dataprovider.py @@ -0,0 +1,24 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer.PyDataProvider2 import * +import random + +# define data types of input: 2 real numbers +@provider(input_types=[dense_vector(1), dense_vector(1)],use_seq=False) +def process(settings, input_file): + for i in xrange(2000): + x = random.random() + yield [x], [2*x+0.3] + diff --git a/demo/introduction/evaluate_model.py b/demo/introduction/evaluate_model.py new file mode 100755 index 0000000000..8cfb843c42 --- /dev/null +++ b/demo/introduction/evaluate_model.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Print model parameters in last model + +Usage: + python evaluate_model.py +""" +import numpy as np +import os + +def load(file_name): + with open(file_name, 'rb') as f: + f.read(16) # skip header for float type. + return np.fromfile(f, dtype=np.float32) + +def main(): + print 'w=%.6f, b=%.6f from pass 29' % (load('output/pass-00029/w'), + load('output/pass-00029/b')) + +if __name__ == '__main__': + main() diff --git a/demo/introduction/train.sh b/demo/introduction/train.sh new file mode 100755 index 0000000000..06db8edd10 --- /dev/null +++ b/demo/introduction/train.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +paddle train \ + --config=trainer_config.py \ + --save_dir=./output \ + --num_passes=30 \ + 2>&1 |tee 'train.log' diff --git a/demo/introduction/trainer_config.py b/demo/introduction/trainer_config.py new file mode 100644 index 0000000000..3e3df55832 --- /dev/null +++ b/demo/introduction/trainer_config.py @@ -0,0 +1,32 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +# 1. read data. Suppose you saved above python code as dataprovider.py +data_file = 'empty.list' +with open(data_file, 'w') as f: f.writelines(' ') +define_py_data_sources2(train_list=data_file, test_list=None, + module='dataprovider', obj='process',args={}) + +# 2. learning algorithm +settings(batch_size=12, learning_rate=1e-3, learning_method=MomentumOptimizer()) + +# 3. Network configuration +x = data_layer(name='x', size=1) +y = data_layer(name='y', size=1) +y_predict = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b')) +cost = regression_cost(input=y_predict, label=y) +outputs(cost) + diff --git a/doc/index.md b/doc/index.md index df03a33fac..a4dffb0405 100644 --- a/doc/index.md +++ b/doc/index.md @@ -3,6 +3,7 @@ PaddlePaddle Documentation User Guide ---------- +* [Introduction](introduction/index.md) * [Quick Start](demo/quick_start/index_en.md) * [Build and Installation](build/index.rst) * [Contribute Code](build/contribute_to_paddle.md) diff --git a/doc/introduction/index.md b/doc/introduction/index.md new file mode 100644 index 0000000000..004ca07844 --- /dev/null +++ b/doc/introduction/index.md @@ -0,0 +1,101 @@ +# Introduction + +PaddlePaddle is a deep learning platform open-sourced by Baidu. With PaddlePaddle, you can easily train a classic neural network within a couple lines of configuration, or you can build sophisticated models that provide state-of-the-art performance on difficult learning tasks like sentiment analysis, machine translation, image caption and so on. + +## 1. A Classic Problem + +Now, to give you a hint of what using PaddlePaddle looks like, let's start with a fundamental learning problem - **simple linear regression** : you have observed a set of two-dimensional data points of `X` and `Y`, where `X` is an explanatory variable and `Y` is corresponding dependent variable, and you want to recover the underlying correlation between `X` and `Y`. Linear regression can be used in many practical scenarios. For example, `X` can be a variable about house size, and `Y` a variable about house price. You can build a model that captures relationship between them by observing real estate markets. + +## 2. Prepare the Data + +Suppose the true relationship can be characterized as `Y = 2X + 0.3`, let's see how to recover this pattern only from observed data. Here is a piece of python code that feeds synthetic data to PaddlePaddle. The code is pretty self-explanatory, the only extra thing you need to add for PaddlePaddle is a definition of input data types. + +```python +# dataprovider.py +from paddle.trainer.PyDataProvider2 import * +import random + +# define data types of input: 2 real numbers +@provider(input_types=[dense_vector(1), dense_vector(1)],use_seq=False) +def process(settings, input_file): + for i in xrange(2000): + x = random.random() + yield [x], [2*x+0.3] +``` + +## 3. Train a NeuralNetwork in PaddlePaddle + +To recover this relationship between `X` and `Y`, we use a neural network with one layer of linear activation units and a square error cost layer. Don't worry if you are not familiar with these terminologies, it's just saying that we are starting from a random line `Y' = wX + b` , then we gradually adapt `w` and `b` to minimize the difference between `Y'` and `Y`. Here is what it looks like in PaddlePaddle: + +```python +# trainer_config.py +from paddle.trainer_config_helpers import * + +# 1. read data. Suppose you saved above python code as dataprovider.py +data_file = 'empty.list' +with open(data_file, 'w') as f: f.writelines(' ') +define_py_data_sources2(train_list=data_file, test_list=None, + module='dataprovider', obj='process',args={}) + +# 2. learning algorithm +settings(batch_size=12, learning_rate=1e-3, learning_method=MomentumOptimizer()) + +# 3. Network configuration +x = data_layer(name='x', size=1) +y = data_layer(name='y', size=1) +y_predict = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b')) +cost = regression_cost(input=y_predict, label=y) +outputs(cost) +``` + +Some of the most fundamental usages of PaddlePaddle are demonstrated: + +- The first part shows how to feed data into PaddlePaddle. In general cases, PaddlePaddle reads raw data from a list of files, and then do some user-defined process to get real input. In this case, we only need to create a placeholder file since we are generating synthetic data on the fly. + +- The second part describes learning algorithm. It defines in what ways adjustments are made to model parameters. PaddlePaddle provides a rich set of optimizers, but a simple momentum based optimizer will suffice here, and it processes 12 data points each time. + +- Finally, the network configuration. It usually is as simple as "stacking" layers. Three kinds of layers are used in this configuration: + - **Data Layer**: a network always starts with one or more data layers. They provide input data to the rest of the network. In this problem, two data layers are used respectively for `X` and `Y`. + - **FC Layer**: FC layer is short for Fully Connected Layer, which connects all the input units to current layer and does the actual computation specified as activation function. Computation layers like this are the fundamental building blocks of a deeper model. + - **Cost Layer**: in training phase, cost layers are usually the last layers of the network. They measure the performance of current model, and provide guidence to adjust parameters. + +Now that everything is ready, you can train the network with a simple command line call: + ``` + paddle train --config=trainer_config.py --save_dir=./output --num_passes=30 + ``` + +This means that PaddlePaddle will train this network on the synthectic dataset for 30 passes, and save all the models under path `./output`. You will see from the messages printed out during training phase that the model cost is decreasing as time goes by, which indicates we are getting a closer guess. + + +## 4. Evaluate the Model + +Usually, a different dataset that left out during training phase should be used to evalute the models. However, we are lucky enough to know the real answer: `w=2, b=0.3`, thus a better option is to check out model parameters directly. + +In PaddlePaddle, training is just to get a collection of model parameters, which are `w` and `b` in this case. Each parameter is saved in an individual file in the popular `numpy` array format. Here is the code that reads parameters from last pass. + +```python +import numpy as np +import os + +def load(file_name): + with open(file_name, 'rb') as f: + f.read(16) # skip header for float type. + return np.fromfile(f, dtype=np.float32) + +print 'w=%.6f, b=%.6f' % (load('output/pass-00029/w'), load('output/pass-00029/b')) +# w=1.999743, b=0.300137 +``` + +
![](./parameters.png)
+ +Although starts from a random guess, you can see that value of `w` changes quickly towards 2 and `b` changes quickly towards 0.3. In the end, the predicted line is almost identical with real answer. + +There, you have recovered the underlying pattern between `X` and `Y` only from observed data. + + +## 5. Where to Go from Here + +- Build and Installation +- Quick Start +- Example and Demo + diff --git a/doc/introduction/parameters.png b/doc/introduction/parameters.png new file mode 120000 index 0000000000..f47e74c94f --- /dev/null +++ b/doc/introduction/parameters.png @@ -0,0 +1 @@ +../../doc_cn/introduction/parameters.png \ No newline at end of file diff --git a/doc_cn/index.rst b/doc_cn/index.rst index d2d50fbdb4..715da44fb4 100644 --- a/doc_cn/index.rst +++ b/doc_cn/index.rst @@ -3,7 +3,7 @@ PaddlePaddle文档 使用指南 -------- - +* `介绍 `_ * `快速入门 `_ * `编译与安装 `_ * `用户接口 `_ diff --git a/doc_cn/introduction/index.md b/doc_cn/introduction/index.md new file mode 100644 index 0000000000..164cb7d494 --- /dev/null +++ b/doc_cn/introduction/index.md @@ -0,0 +1,105 @@ +# 简介 + +PaddlePaddle 是起源于百度的开源深度学习平台。它是简单易用的:你可以通过简单的十数行配置搭建经典的神经网络模型;它也是高效强大的:PaddlePaddle可以支撑复杂集群环境下超大模型的训练,令你受益于深度学习的前沿成果。在百度内部,已经有大量产品线使用了基于PaddlePaddle的深度学习技术。 + +这份简短的介绍将像你展示如何利用PaddlePaddle解决一个经典的学习问题。 + +## 1. 一个经典的任务 + +让我们从一个基础问题开始:单变量的线性回归。问题假定观测到了一批二维空间上的点`(x, y) `,并且已知 `x` 和 `y` 之间存在着某种线性关系,我们的目标是通过观测数据还原这个线性关系。作为一个简单基础的模型,线性回归却有着广泛的应用场景。比如可以想象一个资产定价的简化场景,其中 `x` 对应于房屋的大小,`y` 对应于房屋价格。我们可以通过观察市场上房屋的情况获得二者之间的关系,从而为新房屋的定价提供参考。 + + +## 2. 准备数据 + +假设变量 `X` 和 `Y` 的真实关系为: `Y = 2X + 0.3`,这里展示如何使用观测数据还原这一线性关系。如下Python代码将随机产生2000个观测点,它们将被用作PaddlePaddle的输入。产生PaddlePaddle的输入数据和写一段普通的Python脚本几乎一样,你唯一需要增加的就是定义输入数据的类型。 + +```python +# -*- coding:utf-8 -*- +# dataprovider.py +from paddle.trainer.PyDataProvider2 import * +import random + +# 定义输入数据的类型: 2个浮点数 +@provider(input_types=[dense_vector(1), dense_vector(1)],use_seq=False) +def process(settings, input_file): + for i in xrange(2000): + x = random.random() + yield [x], [2*x+0.3] +``` + +## 3. 训练模型 + +为了还原 `Y = 2X + 0.3`,我们先从一条随机的直线 `Y' = wX + b` 开始,然后利用观测数据调整 `w` 和 `b` 使得 `Y'` 和 `Y` 的差距不断减小,最终趋于相同。这个过程就是模型的训练过程,而 `w` 和 `b` 就是模型的参数,即我们的训练目标。 + +在PaddlePaddle里,该模型的网络配置如下。 + +```python +# -*- coding:utf-8 -*- +# trainer_config.py +from paddle.trainer_config_helpers import * + +# 1. 定义数据来源,调用上面的process函数获得观测数据 +data_file = 'empty.list' +with open(data_file, 'w') as f: f.writelines(' ') +define_py_data_sources2(train_list=data_file, test_list=None, + module='dataprovider', obj='process',args={}) + +# 2. 学习算法。控制如何改变模型参数 w 和 b +settings(batch_size=12, learning_rate=1e-3, learning_method=MomentumOptimizer()) + +# 3. 神经网络配置 +x = data_layer(name='x', size=1) +y = data_layer(name='y', size=1) +# 线性计算单元: y_predict = wx + b +y_predict = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b')) +# 损失计算,度量 y_predict 和真实 y 之间的差距 +cost = regression_cost(input=y_predict, label=y) +outputs(cost) +``` +这段简短的配置展示了PaddlePaddle的基本用法: + +- 首先,第一部分定义了数据输入。一般情况下,PaddlePaddle先从一个文件列表里获得数据文件地址,然后交给用户自定义的函数(例如上面的`process`函数)进行读入和预处理从而得到真实输入。本文中由于输入数据是随机生成的不需要读输入文件,所以放一个空列表(`empty.list`)即可。 + +- 第二部分主要是选择学习算法,它定义了模型参数如何改变。PaddlePaddle提供了很多优秀的学习算法,但这里使用一个简单的基于momentum的算法就足够了,它每次读取12个数据进行计算和模型更新。 + +- 最后一部分是神经网络的配置。由于PaddlePaddle已经实现了丰富的网络单元(Layer),所以很多时候你需要做的只是声明正确的网络单元并把它们拼接起来。这里使用了三种网络单元: + - **数据层**:数据层 `data_layer` 是神经网络的入口,它读入数据并将它们传输到下游的其它单元。这里数据层有两个,分别对应于变量 `X` 和 `Y`。 + - **全连接层**:全连接层 `fc_layer` 是基础的计算单元,这里利用它建模变量之间的线性关系。计算单元是神经网络的核心,PaddlePaddle支持大量的计算单元和任意深度的网络连接,从而可以挖掘复杂的数据关系。 + - **回归损失层**:回归损失层 `regression_cost`是众多损失函数层的一种,它们在训练过程作为网络的出口,用来计算模型的表现,并指导模型参数的改变。 + +这样定义了网络结构并保存为`trainer_config.py`之后,运行训练命令即可: + ``` + paddle train --config=trainer_config.py --save_dir=./output --num_passes=30 + ``` + +PaddlePaddle将在观测数据集上迭代训练30轮,并将每轮的模型结果存放在 `./output` 路径下。从输出日志可以看到,随着轮数增加损失函数的输出在不断的减小,这意味着模型在不断的改进,直到逼近真实解:` Y = 2X + 0.3 ` + +## 4. 模型检验 + +训练完成后,我们希望能够检验模型的好坏。一种常用的做法是用模型对另外一组数据进行预测,然后评价预测的效果。但在这个例子中,由于已经知道了真实答案,我们可以直接观察模型的参数是否符合预期来进行检验。 + +PaddlePaddle将每个模型参数作为一个numpy数组单独存为一个文件,所以可以利用如下方法读取模型的参数。 + +```python +import numpy as np +import os + +def load(file_name): + with open(file_name, 'rb') as f: + f.read(16) # skip header for float type. + return np.fromfile(f, dtype=np.float32) + +print 'w=%.6f, b=%.6f' % (load('output/pass-00029/w'), load('output/pass-00029/b')) +# w=1.999743, b=0.300137 +``` +
![](./parameters.png)
+ +从图中可以看到,虽然 `w` 和 `b` 都使用随机值初始化,但在起初的几轮训练中它们都在快速逼近真实值,并且后续仍在不断改进,使得最终得到的模型几乎与真实模型重合。 + +这样,我们就完成了对单变量线性回归问题的解决:将数据输入PaddlePaddle,训练模型,最后验证结果。 + +## 5. 推荐后续阅读 + +- 安装/编译:PaddlePaddle的安装与编译文档。 +- 快速入门 :使用商品评论分类任务,系统性的介绍如何一步步改进,最终得到产品级的深度模型。 +- 示例:各种实用案例,涵盖图像、文本、推荐等多个领域。 diff --git a/doc_cn/introduction/parameters.png b/doc_cn/introduction/parameters.png new file mode 100644 index 0000000000000000000000000000000000000000..2ec67480951e21f0400bce1c34b3108dcd65c18c GIT binary patch literal 44469 zcmeGEWmr{P_dgE9ra`)-5u_XGMp9|%25CWR)1A^F2D#}}IwUvU-67rG-TW6v&;32W z_s`46>*A8V_KG>@nsba#j7gZXq6|7JF)9oU4Eh^c$#*a?AU_xw*b-z!;7H%1sXp)@ zEcl(wYnb9e(rw@aioL8h7zPH@_~{2WUo_7OI6>P=P0LA3L0-Vb?jx(Qsoi@sR=1D# zz}YY`LT&=UuOH2vj49nd+Sq~x+=QwAIYR*W{pm0p73DvtI9Ur*X(=dEO4vD?QF60# zv9eQ%pi)v&3OSmZ3%rw*{(Cv_Ntnvg$;n=Tjm_27mDTkXtDU0-8wWo>KN~wI8z(0V za0Uz5-PXz2jl~vB{pTkC-bd05Y~pBT?__0XOZjwPRgWXF(@={Zu=}dbg3wj~8)$ry}Aqnrt zJ4x*N%z|mNjsu;kasnq7>yr9u4N>j-_j$+x$N``L*~13J*9wV;-h*ojwx8Z0f8p$b z(lyntL%>Sci-NU7(7fxr6{A7q=djd9@R5Gl>>z2%_sEh`@(M+n(#`)l9PNi9=85&^ zjDL<}8NyyDVz;>%V1+hlc*!|yY^9Q}aDkErKTQe=xtHohpUv)2T`(Hq%8YeMN1=9w{p`nT@6dg;|i$1#-@^ZBlY@m}ua-7c}GdX`L_L64!! z2EOFhb#<5+m#Ij3McW>$`4pEsT>HwNog>y>AuRcz!~dV#+(3Td_(w;2v6P_~d!Hz_$P; z>e_`GpA9jff8r<>u!Ck@Tk9DVbj@vUa4B3MZN@y#_5LuHSBh!y9Oz2sZQF+W5VMVo z@HHr}4>3gFAmtT2%#B%TpCx$OL9*6=EH7Adv6L7p9zTy)?8Am)2AKyVrCHx8TfyIM z|5~UUvs-NR=DS||s{iQ2QqdU|8NRvP&Jjo;(*tM}98W@wW| zx*@Fkb$lZwLn>v3-|z)AL3p=F7^Ka>OpZ(y_B4OM>OZGKW1fV+@-~maYFfM#FL|G# zrP+&0@df6P8=dD6dWw<6^O3j=PKf#)%t=`K$)>2g&PCP7*-zQoFTXKq4(?m6zmb&{ z%aV?o6Z^a~QDE|Lcb%_WU5p##W$(r!C0?LW#O`%vuU4qVu@-2aN)rNe8|t{`mZzBV zRxziLmLF&75+aIo4L71^ie*%wlqTX$l~t}GS7^{6pxg8)D}mzxQXYW5vKVZy7!uhv z*Z^Su7u=s_190EsHk2pM}*Yi>g4Yp~j@^0+6YrW}$gsD!8P}A8R)5C8!dM_gjmY6P~Y2iQMSCa5-z`5wG3C974jh=oh6o zXUkrtHn5o$hps(!ZFn9Fu9fqKlq>M*BBMES@q?tx1wLM4eDsoJL)W95A&L}~=+)k) zSi`<}H_X#uETYR4&`}VsVrXQq4htEq%WY4@s|b4o`d}r}lB0^pMeR$P_Xw2ON8-wP zn4vcaX|GaIVH-LxV&62U26opTUX+JIJO`3CoSG)mio{pDVq6tED)Dx1smkf`n`X-U z25UD{cXiIg4Y$#X+p$?-BpSK{?*oR=4m7lSfh-$D0(<)qGE2rVeL}kLQ=(r_(DZP_ zH2pEuK>>zf@4h=`oImr{_whE3WVpvxpK;@qdVeKVS`w{!#pkr}?oXOlc7vxUpZR4S z!t-$k&mcE$%mH_6y3E9!SO)hx5u17m$m{TPrUjXQ^707WpnuZ=dmZ zw|oSyWs-5E%^jYdK95ez&Ag%qDn8l9Aa+rD9pXRFz`};74_A90jD!0OZ9PqsFUiGV zl~7WM48>dMoIcd?{x?^f|8^RSr?`OyZX*6meYx2`iQp}>zmc(c9}ywGGWL}PLJ($hN`QQmxOp(k1^jn{HJL^&au2dggS<`R1N7Ch8*8ES(@sI1gk0>MBxU0T zHzABhiV{(e zChuF9IiK78=~{4}T+BmK{AssYR;%=7WUsY*aoSSlxZ&!DlXycP{>c)?v_pBvZ7vfp>swnb)6(uaJwGwVN&q-=zsnLj)Fkr?V+2L&9e2Y8is%1A0I1-(T3iL{)YY?{aXW{YqR)+(;?9qAUc|b$3sn}#m@!>4{P!t zA1<3xLNoY&0!cSIEBY6y4bKK=F1lGx=Jr^gNqZ=E4Um7?&3f_XJa&t0^*vFuLMVKm z@6T4dX72XTrM1V1J|Gv{cfjIcid?1Fry`_xqtRq)b#{?^_Uq@Ny%aj>re^h?eC#B# zlLS(FOOD3oP+ANcYYmJ^XQab>$C*ki)<@vc%d0V>ol}OsQ5r=$%U|SJVf$LkMbF1o zwJI&K7ragt=zO$YCIijPwGW#feX6zEPS*R5`(F*E%fyfm&h*mgCKBup;x{@;v8}JJ zQ}hmzkgIfo+U7WLKC}7PzxL_5TK7M*o2l$cDR|_dwHWz>k0>KiwjM9sPdBd9d~UXF zXzrIoD5eRvMc$*FX#yemRPR2J7#07nr0Jo~cPLfZ zM|*XW?RD>BjvQf=7WF(08UgcBn)k(TRz3Iv-@Z5&AM6Cyy=1$_V?*Ej^TaoXP9ste zK(wj@W7j4=2rWR}!5hw0D|b`+t|@31sS3$F=MkUJBEAoH6zJDi`7XBx`;^Vq3iObM zv$FQpTTijN_V4z5j_xx3<#GN&(-zyu2fO+K2({OOPFqWaefbexM~&XMbwXP$UKHME zDw@2LhP7J{Sur9fzQ-Me2Ri4g20?j`08Z)KI8>xEH6ajlpXn$1%QW~jK$*C6bthEe z_J<9ZbsM@mot8vG_h(}--XDc-Oe_f{AjXs7#lm4T&>J>Fcj4#*5UD~T!O;%UlTJO( z1tAz@NwrQ4EVdAM_nq+{#m}gQk6Mr@mNt^@q%RNV^?j)B^oSC9Z|@vnB9Pxq36cDS z*o+b^hdQiw);ikGR-uD82Row41oZ^ef2#S1%DjB}vIg@Or`+&m5qOWJXID{hpUd8$ z5y)lr%zPzO$xv-O9at-LQ}^!Uc3dT=^+axKg2_T{xC5*8td1{}3X1*u4Ia+%yr=2~ zruM%ZHrqUZTeFc3rL!g#}52U7&=xq1o!Y;&NKD=t0{z~_FI$wD~*hNevC zE@kuO;Px48l$|2a1l3`Rt59u>!yC2FeJIHPn*!)y@ue#jqsRmaN}wmOlKEd;-#miS-qB$^n*C;Lsw6Gq4p^q|2O<1}0#YqGAVxl#^Sn^w zQ??R|Q(7czWyrdEKYI!aF0H9`U5;{AM|5Vyc6nl)RRu$pSznx@vcC8E#Gqr38#5gO z6&XvAM?aVa2}vLN_;9NT9ocyUE1NgA3|{6*E)vFazT1>gv204yxEk+FbP`=;uD=*x!fDycZBzIvSW6dxF?7ddG!aLn~1g{j&c=6zg9sE zt`YIG;mNI8Mc%qO-?cF5p|wzilMwIa(lQj-kSn0tU2)bij}|cJpO$%dU$gm5I1l0& zPQqXKu&^t4Q>G0gcSBZn^CyIe`Qs|XiXV5G2>gPq(|{W`hV+KaRW3qQyS#o0alY1Q zcy;+K>$w^l37qMfgZUK-9b_q*8$H!BDRMKa&!2S%YL=dc%0)g^pczw7O?sNH{+U7` z6>KC!Kc$Qq7UbA+{VCkUM+S!-*6k(k>A%#`ULGF$e#jFtB?+B!i#hf1E?rk1OK%dm z*$#OfK_PH+ti=)!m0krngj=nCo)x=~y^yd)LgyHi`*8`}sZRKl^%+`meP^}z0U*|F z40gT#mpcau!{V#8APn?~YAqe>=_^jL-)R+}<`Q|OY>^75KH8=vCgSQzgj0;3dHma3TnQt`q-Qk5BfM`+QBa^UtuxiB9}7n z&_!mezX+sB5+Da8>F(HcMiMPKTH^UMJwDXy-s}Nj23`}(eKm@=`InQR%N}yB_@jvT zP0cTl+5@(!ZT(8iSiHk`&q(>$b$_UNjow_|?3SzodC+AgHrBT3-X(tXDcZYp*;c{t zWy_)@z%akNivy;{Q%?H!9E9g)(rH?v=p|&o;0_C$|Dv3gwKIC6qS42pZJcYcMTmKIXx~9qFxONjj>F)bl_|9`A?Ob z_4Pw*%n#S&4Wnzpn*zq&FsL;UNz$qPd!TX00H!w0ZQ<;P@Q| z!B|S`$CjIqVjMYyeQ!Vf=N!JC-_y zQV8JpBeh@)~(jp(T(Wk9Qf*d%FtmoZrK+yYa2hW$^c8l`8lo!&~wf#1N+o6z)>G8OW z!WOF&0Q$r7%uxzSTn>RkmJ>|W_gD;Dmh)NMF;ouQqd5mPa`CJ+N3CcryH#=&of8F` zf{(@e^#YG!xOC(yufr9?ngJ zjzA0c??RBo?vB7`V{?dxMU|Qa*#3d^@?hetE|iL={KUQauuAc04~nApvO*AgygN;y z%yRSGNVJ%I1YzHD zyeXj6KXT$69OIEv1eRg!B8cf zmK@&{P;8>Dg#^S*N}`Hi5Z>J`K1Od0BwK7`Y;QTU8qWri0ZYZ$RjyF%&rAeDx}x8E z0PB57Bn-kaFld5Azry_tXGeN-3ZedVeJrq1G|<{I<+l3U;p!Wjb}^)|JvGNR907{p!bcFe2M^m zj;XNeAiCS6wnc5uRx}Nn!q7#q1n^dt?W^?u%w9?=JsWrDG@HqQ?a!egbsYzQ1zATVhyW;?G>F z1pUJFD@q!(M|q$eQ?p3Nlr8qdTk&SkbG+^#`rZ{&4yj73Kr2fO<2(HVDuj zXKT32?Xj+$mUF)X&5~#Ur}}r@-<2Jqz zL866eJBZWr|GV0?aMEGs#?fSBN5RqP{8F!(uJ4{>d6Ek6;WEwm8~Qu9!Rp6G9q)P* zV)+rvK%+ZG!Y^CV*o*)L(pF(1M&t$=NOyvO9H?LUahBjoEZ0qoXL5x&bP`#P<)T{_ z?M@bv_&i)&C|B-Jm%-_cxdM9ZRiEKQ0Kngtv%e}nv6WX5nVD>35*Xt#8nQgFPcj#v zMM@TEymM8m5iC=f23GKMp6XWA_kX8*vR`QZm7w3adS-b0NqnP>q`n;gpt8qI0I|L{ zvrn7kOO5!Q@OSdMRB^4SUFyy}Qwu<5R@~#rXbD6zS?i89dS%n#an9~@=T`b7&ibNQ zyQ0(k!+5?w09TLzB;BaH8vifICDE}w}vwmpS=@Wa9fSK0@B5kYXJdt1mGhtB!Ar1 z++5mGO!R^``-()wiw}S!fv1ClLsBT`w&w=cT~acXnFy2={d44sc5b%E;RhGmgMXGT zmIi!gxujO~dZyGcVVcQN8p;x7R7v9vYW2KJGH$@$ZZ&|!#tXG^0nvhpRuQswbFFg4 zqFw$HOo?1V!fxbu%Vjm%&P@VvP_p&Dc-MRQ-VR8;dRhS1^*`MaXo8Ma->pTPYA5Yag&%j`#i- z$p~O|9>#xsR~m9YjFM=AtXSR6T~;=aeCcW-M?pbx+y6DCm)2b8yj%C+6>s3dQ0KHA z0Ei{EA|5v7W=$=+wzVDFNKak>z%N_1(fFhKimDfObQ*qY6z$}?)!rV|>%*!Y9ZSYw zlJfbfWJxarYt-3Y=ytE7SaiA6xNR7KME!2HQ_Zi@lEvCl#Jr8HW_$p}WTebQ=F)$m??U!Gl?M(DJhd1n%_W_aP0{6r9Kl~@0%pV$B3vKW5?6B&0I~$Qn znKo$><(VDLlyL<*kO?>#gG=J#vLjt8t&o#og@($7y48^^8H*e6F!KF2 zlt4__9DLe_DY_>D3HTJf@?}y6ye#*R(p|c@cXKgoSC3J4SN=rT6NKE$UquwiuCe{> zV{0P4zr7rS20nXb@s~b`E#fbW1+-e9ZVV_oxd#*Wb^BB7T@v;>BqbPiU7FHAF1dLO zo;kOd+-!?Lo|nllYml$2;oSlvK5Ff|oVN$!OZ*frNww;vJE_!33m&^HQ@)wn73NQh zASXg*ZksR#{rv8)!8V>hlTGKjEdUSd59EW2z6y+Sc5$ANxj&^ekanL|($0~m3~*2TK#^u*UwbW`2Y$I9I2vEp%`lY_b>v7_YVKu6$?$ z7AkpTLUUms|Azc_tYM0)I?Rx8{7%7;^L4suj>Fb&QNBeJ{wc>@O2SkC(adt|QpgF= zFxUc<`s2#F{iStrEuxah^~)#0hwf~;>kxPuu37&4R0#u(T-am1)y8dzgalrvM1zTA z4Y>A*U1i`up^u&)UTtryac!LWG6@eUBi<9;R9z&IN=84sshV05W`{6N;g{Z0Qh+jx z_nlO8ZRB*2To-n;9&T(jnhp59eCG()))3gmYr@1X;k5EE^7=*(x*;F-4*RJmqW0-f zUBE8>rnpfpAB_*?oV*{_+)U%R_*U4OUBp)FE{?e1WuG` z%=_U>{l(a2`nv!*n@i=mG>lI;D4Wn@FXg~sQVN}OwCFalrxDficB#NWbr>hol;eWp z3(PWl%yu!4jk^RBQ1{bP zPrFZb7CH~IvyZjvd2%o`BJg-#3z!8`zTRR`hw&14wfUca%P6kJD6`#g$_VcW84IB8 zD9-xyqHz+7`YpGRR<4w#E1{w7B<1r7r%HZjAeP%2tu;N~AKJ`kaT%x_f(s;6;)_a8 zV`R~cj3uh>{=V!r>~yUY0iE-Ut4lGDcz*2e(|0HFOLF#`#FwRy>=!*0Qj}z?@!y;H z3?b4zrv@D}0SIWNUe}JiuQ~urUZu&$`-vu=>u-L^g>=OQ5-^O{$U}w=Oeg{a_V>r~h`hz;Og|W@`Zmctpx~ z$4N*OUc>0AHuH7dfDubj?tVQq&LUcZ(r9s==;9#P$L!a!AFb{F*AQSvhL3nbz$o>A z(L+(>8*&%&c}FgIF}}et{gc#nHp!PevTvgJJ|FP!Ug58en^Vg~JHuehMxjT`@#_v0MY?J*tK1DBggg}zES)h`?#u(o>!1d}w$DXe`X>~CR`X*`UY$-Q}41L-EFX|)JH8v+kt-z_pmAY z*wP+`dwkS}S$}Geqp!0;&4}wy4sfWWT;1<~X5P{jN)Qf@JHDfPRQ%LCasj6&I9`tn zt~@&5pPh=@%AvgN5`FLoYy~_>=K2!=8>rJm_E)ITTMoCLANyrBhG>!MPRoZC)R#Sy z7Vrb`CZ44H2nlP$QWMg9_Q-8zp`X*E-bc<-4&*zj_)>s?qV)d?D0|33UrU$OSO9y< zCCi;ZQ9A}kXG=p3{X^}r!xLlrmPLKZ=0on94RFD#3Fz(X`?@J^#;GZR(o~yq}$vg zT0leepkb=^EBkLtqPA>4DV=Ayep2C!oPQ_L%3HBu6_q%{zTrG23W(Q9PceOqhCt_t zTO^iO;bz?w-vStwZjeq^j(RMwdo9&R?_ppUI^}-fw#*0V)0p-G5E_p|-w;sM_Fkg7 z5TRp31OJV9+xGPm{ZfDvo}#eFnSGk_br5cZ^15o%$~#3CJMja+l>g&x`@dIsy5{Br z-))(C|2o}mg>E`l!ogTYua~I5hH_Bia)UV(!u*dUA;v+~oXawpWe$`I`oqA6)A8IQ z`pT=}57D*58y!RKG(G@{6d=+?Br|DNwah6U*kAOx{T1)gD0SsVpJcxeG|$W!m^r>P zzrg2!)p~84t=G}SpPB$nUIM(khl#b+1nYj@Wrn`SJ& z@QG=iV6?ns-c1E;;6Dn07(DJ0nl*rx6pv_Le(Q?=>g&Ws?sX*oMCz72@&XajUw1{5 zy)tydj_mu%3g7x>%?%+ToRnu`)G9$RR-sgE%*y0>v9V0O^w?-a`BTuP4!#a#c<}<3 zgM)*ZhWSEgu#{dNo3NGHPs$69*%Rnrc&c9`d+poS%dd@{U82mL>r;y%9IhgthD2_k z#p}BydWJc^VF62vp=^P(Hl==9@YCDSUpXD(`x=M>wMT;YXKghLX94%%f>pp}-qN60 z4u2Vlgo(I`o_)IUefB&Un?TWwA@XJ;4$!*5StiE7tdoR?^VWxUb{gv}PS zm93$irL7?@MZRTtwIZE~j$>OV5iwTLOCVgV0*D-*E@$KXgZ8tIo@8lb@(@*R)QW|S z6HeIY;o(o50gD5ke#>%wN$#X-*0v3ZHhS|i=!c(FLl7hv0g4Mr3u_8>XFgF=@g}EV zqDN4bG!nj^(42(9L(PtzWk%DTwG1hr|t4@3Y_6^GLvbM5&; zldqfsI;h)aof5k3WhyZOOzL85xUrhGFLDqVuk#7b`I#1gb+@i`L~y)sM+P)gwpXHY zN=OyLQ&F0~c)1qY0&i>Wt(YO=F0V~Hx2FCl zw#K{8j>F+HIC77p7iZt%PpnJsNuI2A&)Q&J-&{Q=3x5Pfv7$F0kk0ycp1VF=kd#{C zVo_kq`qt=>2`o8cs|<$0;DL#G4(21`^ArN*kwgI?woJ4^HbxkN20h&R-zKXo`94gXgLE(gw?^At=rE zO%fU+z%n>VfiaY={C#9rYJN|9$L9Kfez0fpfdgOY)r}R$>J={rnrR@xjSmx0yH0} zHPWj>dj0$pLNoVpPCFnrzgaxt0ZFG1qrgY9!zK@uf4Ana_|2;JQv)ebsR=nAGJ%`- zY%9{zkl1BUKk(%E9E+OmJPyzRJos7Kj_A?@6y=CnUT>+DU9GQi`K&UkH!IQSIlK;D z-{{)4oLB97@VU*lN~Agiib(06DnBai78lO9XUj}NNPhIym;n} zLqm@6mkI+zl`Jez+(t|b8~N@A6VK}SMbx+}^{!9q+(g+Z3nL4%-(19#jN4qHYQQY2*)XBW)W{9eD*$;mzF|?)# z27DeQ4ghui<1c?xc{!LSx`-lmF_ne?8DtL#dc}(r2&O;`3I_^Zr!yzVD2q37F9 z5a$Q_y@dv}{hfHJ0@`(lr6rduZ?mJTGj})R86l{PkBPmSMJ05rpU5AI)g3YHprk4y zstOp}#VLFaR0lv^joMGOp8el8x82k+OAA2#Qx>@1^)Q;mwx%Tt(cGKQ%>=M12E#RT z<%ujuofAYYCVIFBpW)%CEtjTWtU%UZ~su)yeN3I`3E*ZMBRV2n=QMqbBEZ|YRR6kRWVPV9%5_?%9H&-Bmp=6U zpe^GyX+xiguagezuR#YTQGs(X9TACL0`J`JP(DQNGetDcBuYBJ?*>ct+{z6xXY?hp zoz62fq^^@-liKirD3s>RbX#|T2rDnakB;%TR17(LrHQx_U4*9BA8uckI`>r996ltg zH0V&7Gv6Y_d-1^LtKgHc)jbOfYmR3ye9==)k|&eGN>y5C`5Mcm_?K;kfvQc|@B(+-;fHthwTO$gcHdXXjQ_LaGQMNtsmK z_S-g%M-8=jQLr>69cbB9;l1`1%0X7V>Tl`^QOJR|$|As<;reMa?{z-`S9!4x3K>R6Z&g+^DOdY8c8JTxbJu}@At74*sRiAT2G@v35-zmcNjyyz1vEkc>(R%rMF(k>%6RCu8Zy@)Po{~a@8Lt zytiLcZ00tFN|lQenPtie#%?0Icgr!w`8H}{zz<^i`e&K2M0BuS!&efiTrR&e!OvxE zhOVoM#!c)FokCh2Pd2#?y4)7H8oCzTmhDLg&6sRNfI^4lnsq>_6l=t`XUmUlRNCdS zY2mHVH1>RJwEY8cAm9Op+G>GDXma%^(mLDrX(e+|-^{*?!0Y=1chXH#rmk9bbA%kN z|HUncK-BPrG+hlL-3izIRc_zxD6{{yOWRx7*NkpXX$?BCa#}7;%1k}*U1<1po{Tbr zst#G6cGu%dW5ew81)`TbgkDj+(d5>8gP2!9+^bcW;^U>v%jO5)GwFDJ$dtw6~3ij(;y2p3Z1*YwGXhgTMLHloZbI7jP z8+98Ii)P}w?h+EG*Ogj@uC1wj)z@NHhUf+wr>{9muWhh>+x+StVc+v{RBD4B$I%&E z2-jI7PZ-xuh}rBYG06nhf(Tfa_1faUimQ5!lzGD6S0+o=I<9lFxD?7cI+p8_1$99` zF$D~}-}8@DD8AeW1@b%6`CVl#$bGGa9s3eaN6O@!fz9_0`Nr(Ja{lBu zLz?~xoGINqGgJ2^8A*+@(J{BBz<lN{j5KN=n7RzS|?6|H~9|D^&iNn1) zW=rYUEyh2iWFSeW1oWzWzvcmJNbcQ$rXBJMV`;)@^-#yXutJNd>CEqOkvAMEm%Isx zf;Sl$W;;gXs0ga)`IWw!jrI1@y=Yq@a7!H>rHA2-Ff@9@(bcO#CMfQ(X;tEYphs;a8pA>WrF7#qcT1r#rbKbZ!LzK_CRzI=h;Kj|B; zBdu?otlWRy#X@;upZB;f?%2YI6Tu#e9-`njHe2PrQ7z&tR`A&Gut@hw`EYU~99fwU zXQ^BZBG5k}(eG^@=#~v2 zFQ;M3QLYZ^4_yfO4BixwvG@b~C0>8j-Z%mZt~-ETAoEELzIVq+lg0WGfWyS!as?Qi zHD)-1&O2+9=j)e|R>G(KuWGh3BZf01kO0%`)f8xCrh7uRxMl@4(exhXNfT$*E^kK} z&|_m>YhLoy=>8p|?gIF%J1$zrK)&h@^V7V$KlFWUbKgwa%uF4>eyQgntu#(>EVKEQ z!QlPcS4LkJvsjP1=3inghDFV_HWjpP1Q~?-*!icBy}mc0T4pv!A*-gzb0}_UJ!0}{ zXANjVCbjxRWMpiUd=K5-qFv=~w_7S(l^-ThNg*xf%Yb`q1Z>+VHY6|duRIwk+ja#E z+jRWTeb9#ci&;LPKjIM!0At!|m5w2g0>E+awdZbOSLIb%?G?bF#Wn|1^y>f+*9n-@ zTZLctRP!D5jj(N7b4aB^DPLLd!X<>YVgNlu)&wO!R!S9k-(627t{g6h*tI>1xXZ;k zM6@WoJA90#WuDQ}$S!K``axl0aige7^XqIRcT*G!+c+kFObxuD7i-BU$#_J7|DY{a?yc=x21n|$>uV6?4L9)wl zQm+Yh&D!0=`a={T)Xp)66uE7?jIO2zMm))n1_mK)``8E z=O7$s4Sf!}k>3#^WmL@grLa-&wH7O*=dj?uNd`Ti)GbbopA^zLpC-N|izRK}3-!#Qb zpsF%Cj9;oBSad1w=y?vtW2^{c+*5jAx6@Qp_(5(!(Yoy#+o}-2AM^FJ5p}})wuuR~ zx|4wf`a<(QbWa|J+`nDuY0e>m1T3__aLFhnHQ=e36;QuTDp)baXmW{K!6Mo33WyRM zHomK_f!b2vokpsb)tr&a%V`HxHCAnr%o85lvWJXs91$4le!7aVewpR97(3Fa4=$!H zBAu~?fSJ)aK*kb&aBMubbD`n*Om#5?k)UXAI1*mlLWAEIq?9XwX+M(wA5GRF$|+uBVcsQ3k$wvcc6%v ziFYPmq1(pAyh&6CFPYmqjz7`KUPp+R&<7rHNIfWQ7YObVaJ?d+hQ#;p;s5w`Co#Zk z0M^*Doo_w*<9&(NZo9s{j$vZc9<%G-WtX6_W(khHSY`Fg^n~%b@%bA;qMz*(0@>=B%s<)H;!0Z78#h~v@nstU6kY{60MSs^ylt(A1T2zyTi;+XQ zECk2ZI9W1}2$u`q)T6!X(ua??I82i^otRmy(q60c32_f_@778nC8Wf+#1Y8fc2!NC zgIR`P>M0ORX8{-H#_@gAVWGBc{>`_ZL&;XIWdIXK?ptUa=n<3)+}(2BJoG`Wsz^=TG$RU(j# z4kbhmtUD-VQjqz61aIEm$>+&A8P*CRPI_zrI zw<>|49EHG#XeB1`MBaF0MNQB~oS+Urk>aW+>P^`cT@i@0@NN#Gc$n`Cy(c#CI+o)7$}gh? zltl%8wQ&4FszA+KN?;DaH*sm&g@6ARzzB+)Zxow$Eo?=+jZAR3daNToQG;S(*bZkz zd)-Gt`^$E1Sqc9B2RXw#6$gQ;+&!7Re8ny72{?}OAOf>gwry)l_d4cHS~~SLB>nHP zj?Hf@_4D_H66%NzXXRsQIDjBtl zj)^hG7hE>=$iGldAl#{%XKd8WzuCeYU*_ZSKttxv7 ztJh@ix=mR-Zs_gjk&qr0&k_Eg5&0s$g3)LAat7SmZM+|GmBkDF);-PeD1)v{LYZyZ znC2zo@Onb&>Q9G4`oAUXpZ!qo4>x;avO$zsZ^yS?^)44&vu}cBe~LAnc@H#@6Iy;U zhPbt(isAbLzF*fQWyCnsel zz?k@5raclG`uH2?VSq;rd~W5Q9gLxz@2W|8%or!O*W!JuEDVHYnR%h;Q+aI`wLmap zGlO2^2Q~35sd@LFweS`XbB-Y=Lp7Ah!ZwdUIB^if$zKe{bALbI6-I|R_)E!XslQ5k z2^n;Ax3fUvAEQ-@b}mTe30+a96DGg)VzQgroKBA<MdqbyL`dL#up0;UPAa&>ZA#X+?!8H< zwy>OJ+kW3QhLpTw_f7uLzLD$J{@^9PTVJ*k7LvF9y#(vWR}L$Z)ki`U&lM|sLS%hM z*jX;nNddVByKVK{7;TEgZk-b(B`((FYjv`Fdmkgxg|L|#@%1(}D1@R0AE|Gu)sLuY zX=A_MVe_+()^lE@ic#~duo9F{(wI_KihFya$cq=p<=L}2qjOB$p+D#b8b;&eU$d=J z2z6wED;Em5mD;#5H2FSbKfKTUJkmNkAQ7DT3F%#F5q9@?KanBgu+2BO7=o};q9PhO znOVxLB0Rgtc#jecR@1JOG-ENxD{}F#x2H{fX{?F(WMfx8$w3|V7j8HMyn)I?+vr9^ zpKFAtjle8wHb3gFRwCwaBh9kc5B94UU5X6#a1H?S;p9E9*@H*j!Z7LEGDK*;e(kPk(uVk_5JRx9)t4iDbVMtRV~#AcK?jM5A~aP zY$oz#@9h9|`a27yn(&pVCNOwJ)&L`u{QY)lpG>@7K)GAtBu*Azey`Al;JEh)8$m z3?)dHgrt&6w{(MaDIguv-2)86d(qGLx8Akp?>qNC=bR_^-ltF@^JvSdC$(nAe9Cht zb@#)%TC90Du`pNb`74%FYx8-4tB<%=e|;fGB};-zwQO0G6WFwOSqrcuuMXEl+8yR< zJDF0A*1Ckg@-N#ExdUmC$NR+9?(iW!**5mI;5`*dg9%o|Y>Tpd=XYQnq`kGm=0cS9 z%=V{q|3alrS_oTRF9&}vJ}#^E&i5@A?EtLLWabIIRrd=RQE%?ArLFHCLSA#jZJz ze0pX*UthMj+2cipN@LZXyMC2iX1Jg=>oScaf;F3olGJbym%$Sc1rWCV7s!7+GF?JA zZK+fggA|7tjMZ!!JHSgsmP@Ecbz&^EfEdBRF{yHKJ>=l5-}|Pr1UBWS+I4`wa&{Lx?yoed%?`sW@^u0Q-i#4lonDg_$xVZDz#@g=i z#bYwBM3J88t`iajXIS35(3cpIeu1 zO%w0*v_^@Qql0BhRn6pis9K#68gJ3`TWfl;fpy%qldRy$Y90$mM4R;QWs*N&x5Pdh zojAqCEkY_WsQY&Iu5p%QzC@3^a*W!rMFO^hJY%fL=`$yJ1oTS!n~&P|HoFT7Vn=F; z-p7fZM%YmlGR1>WOkR^e=pUpX$vPp1a{B~I-05G#@q&=OcJu7D|`Q|RsQcnmqPWn^FjlCA$|2&Q1J+F1scz%N~ z(oh@{%2%B$)^j>uD^_1CiuvpD;D8|^8417idpDF%wwW+3Y|{d$2V+NaiUe=g59~hS zirhR*x!#;}r>@^tRyph3iB6Hf#FsHp+6lc!5rOhsR${rA>)ayl%F}W%3XiH1?G#!HT<{pr1jWb5c4fswd-^woT6rP zOI|Qj1d=IEEs}k5p}tyOPb!A-Skm0uIX@aBa}&{J!BM}b{fqT)7shE=eXr1Gen`bx z^-Ow2k=^81WZPX8sLEzVNVI28ffcIc<>^_u8McISa$4SMRM5PInRftO@Z+j7f(JT$ zV9SSN-l5Gu3Fi`x6eLvYw8oDlr^RotClb9T6We-jU-4q%CRQ-->EGleD&ja7$?Wl%m-kMzG_hUlPG;gs z(^taiS@w|z!?(UE$Y+-MPSp`YJ6Bs#XUsVT=r*7Q4%IHB(K5xPeFhDYW-AdwNIL62 z-91KVHS>nfkBw(0Cdxy?2i8^VpDrHcD7Zd?=-KAB)h8Mp7e#=yj*6eoRVhd^Pr-JunHpIH>9(4_cFlqTwTwjO#oJB{GGUKeRiw z-%4gK)UdZ1HPJ9iqy_5?E zp#trDnDDZ=Ij^M{g59nMv`Ht2#F{aFe!sq}zewqijgcb5Uy(*7(ym21&5I4BiW&N9 zDlHt@rV%U6lRWhz5l4%+H!#eqCPNmZA5K2HW7$&9+LI#)5hXEq{3qvi{Q`PFz5{!u zf8MV&Y?u_?SJu1|_0?^n*-o}65;a%seqFm9`T`j~tLMwU1~1@!VvF?2%>%D}co1?d z2#q@ExgzGhD&05aS9DR+efL2${Zx9cna*J-$jtHg{Ri;)o(jyzX4XZ55QBxF`!zd- zfc1*)4}>?IW;cA}JW#T30jEfIZ14P&27>?Ha0dXqSVSUJ4Pgueu!FmscT@h)#8soV zmd;xc6b+S^28UgP+?i**AhJ=f{=s)Bk(faH_E(s)uEpV?INzcSaee`nR7oN6#zxZn zt4*%nIyXJ@iN>uw?U5`Py$3Uz+@qJR#Z9l23B2t$%UqgO4mko4Y*BS|@859IvFY4f zuEhtcq_0(&O8yA7ihpg3+oC^f9 zLIV4R8PT@REZoWiw`_cV!qZX&4;CiR-V<-SsBN$h^Mxv#0BzO{HGSb|DPE>rZZHW1 zK#0RU9Iy;(9it8=my0a&CX2P6QhY_6&Zg>uYUDFNF#jr{A2FqVmS^`lvm#ijMN0mf zEM2Jny}-_-3lu%zI&s7Ad3V3lYUHfwq}9T1W80@BvR6#dk7iBU5LsoR9_)qV z3vEhzrmqjbzWXi`Y;-k~Xqy*P%z}{}ryIy? zas491(I$?5j2cKNEx})q67g=IJpD2|ObN%-b1B$ZGr6VvW;)x=6$<`@T8U3*NFtY; zPp~pqXtYQttGn;Lb-&bItjv|t*xI-Dr;TN`a|iq6e7CGv%o0p-sno@lzT?yY5uLjc z`D}732NTT3FjdO6B>fOKAYE`#v;H|ZZ+tx$_w_5RqZUGpB2>*3XZP4uin#vGXeoO9 z6>Ht{cbyVWM32)Yga5q#>Bjj!doLT|I~?FRno({-N+nq&w40L(i{4W;t@st-Xk$$+ zP%ijaQA)spoAAh^bG38Jml{C&T$!{3ooPYL%Lk4~;<(s?VyF|K(+5K~`y~kOSAt?M z$tzCnOH_@Hq%~0Zz1dx_6GX(EjqWp*;3oGM*hsWoEj8$;d?(v@%a!WN8i&xl)4$9C%!WT4LhZbP;W4O;w;YsKz(#l<}YR{ z&1J*@i4(qqKPFbFS`%(mit`1iD$uK9yP0x6)#717Hnk&5kd+}vq9910!!e0?5+f8H z^gOTf+A6LEiDkq94PwNwD`*|@{ukvBgTk+A4YrT6vm3CIxP)m3Vp8dz9?-pqp?y)( zHj8Ep78hC`q-2SEdHl1AOCBvMi6DFwJX@rmo32+D$OXVGa-A`qr7g-! z&C6i`fj!ORNCoL#RV5>k*1FOI^D(>VKv6r?H$yI;QU$c57enNJtySLU;&BfzLm}s4 z9ex5D(P^$z&CCcMeNlE7eY;5ZRd}YJyd%4AG3b^&~9`X zi%Luas<|O1*@2Jrhb||GsUpz~^gN&gFc@;zoB+)j*xq)@L~gf+``!d;!JVJWtEfJO zxWy=};$W-y(=S@i8_h!;dl&hG@~>7a#7WxHY}b+|hKE`O7^9wQ!t-q$1|FMDF++7F zJc#_NV~gF?SO*g#WnwIVO^$-1qCf<4AF7tZ@DB7PRk%rckk07MMB=2ezg&F6aFFT7 z-E-f1CDFIUi$z(_(e<@;?v0iIV~JDs&cf|Spp(=#yta|C&UPORxLjD1|{%X3r;m`o)WrUZbGP1?)m4I^| zALmB%qQQmHc`K`9XOe1I7_t4mC1XbOJpK6WG0=kn#v@74XRB7ATIQAiprL3WMAFZ~ za$-Bd)iAHN&X^g~P|?E{;twC$87v2l7S1m0b1y{Pzfy#Y>|DIfcXH~vep~(S^_vA2 z1H6xKed;NymE@Dp*QY*w5?%@U0^|!TC;iBcF*J{yBhfA|w4zMqgzf)hA)!7HJ3F20 zv}2^Im{y1(*!TJ$!uGE6tnR1A^ZW3%OAIKQFAw^ckZrurzm&7i2U6M=>B!%2OH`L; zs0?$=;Xb2HQT#NPtOk{l9k&$Vs~kvOi}e#pPXSf5vE>C>K(Q^E@OK2Q*SMj5^$%DK zD$t#|p|dbiPBMVw&GEBjdIm;S>tII&zWz=Hu@{R#NuM7(Qh)IgtDw^5ie`b|LxRiZ zH)_Bq!A{zzQ<~gO)N8$iTk=iYlA6#ezGc(>%IF}PS8rc)&?qWO{fho-_H$6sHkk~Y zAA$Nctg^$-njFA10o|RJ`*ntN$d-oyxET_VJ$HNfH#nlw_~Q%R?(YY{ z$>e);X85eV066di%?K+Dz-~x)+NYBR9Ww(4rS*r=zsDty{yo*!v2>8(;<`s6WaK4T zSit!rriabi#CY>!{^AjQ&!7vkqHH$*`k2~;usHYfL^-J~kpCT0cQ|^ATgLN8u$w&& zjHOQMrS(#>fBfMeUf^kz%&?!Ipf&ol2?P-eV3QL;#|6so@Ka6e<+CR(RUS-DsdFC| zFTU_{UZC%HIwq@DgsRPxMr=0=@c@&vfm0_TcYee{lX!V;CZ^2JMzH7FK~@_idEbaJHkRF4Td`0P$8?y*?T7 z#0pCePV`OL7Z))w?_(rGYGAa1A^KoM59R;dC;<7bPVtFA-M(8{oaOI)r6 zT|~O{%0fV<^M7egSXofSG3R-7H;)WJ8u8XDRglZ2!z_l)H?k}C@9t635nnVD?zRmdXVC=K&hz>)@LeLc*Nq>&!=}h%1TwDYM+~b#|6V zTm9bdB3GsljaERECy0U3jHH9({8Ve{VG5EUE!Df60J?$JCU#M@UzYfgY)&OblVIst00*Cdj4=Pnif3EsbU8BZDrk}QVSPTpH;Ht2~2MjAvc z!wsFM5vmGCevjccUzTBnRfADdYZ3dur@meNwgR=w)1sCqZ!Rl+tnRPk2aeD~bTI_9 zR=O{z^*f`b5CplaTuRLE5Eds#))UplOJr90~U6Jj)ikF2Fa1V>)|CD{|vJ+NXh zW)h!+WmTm_#k5k(jG8gk@3)Wv0;doTwFrjc6|)6gblKddQD!@e=sbEimaC1c8Y(XU ze?}Y3ec_pF&u#-&Lz`AM?cC%9|I$qGIvsfZYpQUSgZ4ije-IC|8JH*9DVBKTp$|k1 zQL$J})B*AVZnT*1WDXfG88^lQhhNta$GPtXv5k;?aZUHMW`wv0R-Iq3^gJL?W2#+J%8?jT@<6;1f z1qqv_kFf!5=R_*T&FBL*CqfMfEVISl$+08$o$eXAo*YvHmjVbt-l}%b;5hTFSGg>W zE%DPp^6kyO(qI3*^v|tzWj{dD+GTa67@h4=z#-y7j5XHZomedFY-P2}w2?(G9c9mS z(KSj8u%A3hmI-Tp^K0G}4<0T%b@d1d~(Z30Qr>@U6a18h5Am3sW5OpS10qH>iGi4!xbv+sz$EOnF%X z=-KvZu=nv#SZZ2fY(K;TFWQm`+KC{1@gd~_w;f6}q){I6LyeCpnM(qDP`d2@Q$g;{ zh60G9JI-3H2NKAgb5nVgv3fe+4LXEmA8h#Dnv4YGB8B;>eEqqqQa=j zgWf1h^_nUvVP*U8n$Ff)@B(UO$ndM`nzu@+u4#2wJh(J!y-z2KepoM$MtmEmMZ5#J z>ld$9hHN@K-k_=|4Sti9HEVPQo%9zhu)dDhzbExIx`X}zsEp1#sZ{a z_lUbR0Y7rKtZ@RXbL#i5mg^eqRcwccnk1?K{>W<|U*LRZpz>I@yP*c_crv8p?g|y7 zb_Xe%oKQ$tnQqMfX#*1BgPYRW$4o=z*(3S%4>_4TOQ3PmU|8y<1$iCTt%+G|l*i}E zRP@H-RE@(#kn-v$1Y5H?gZ{Z%X;3+3j{ceMkL~5Yn3`R`0VV(y^Lk5+FUr)D#uFI- z0WRYON-5@dltq;irMCoTrjm$Hi}e;Pg~P@nKetaKKCp$&eL&YNdL=OUM7>M+gWhNN z-MEOw3lV39PXO|SV2>8vbG-GsP#)1Z}gNYzU zGhW>{7vI+}MXM~gXrA;W_hs+0SGDfEqEbtb+-x(%XI#93TQ(^}+OMU8OYese+;Hu9 zf_8NY%;9uo_y#v1L4TX+Y@wSCXSQ80kJp6V7Q`PgRuVL61;4 zHuu`|B%5}W#LcP1!H$1VIOr~YJ$trVp>LYvo>aJe2_Ht+Hd_uOSL`!Qf31=Z}XO! zhWoS3q)Bv^9>x1Mw2{?Y7v|$y6&U=pe_P^5aWuREGL%H#MQl6!fV-}nHp5lG5 zfnn`i;y{yxP+W0&xJy`l^xsD zZ0s@2G7K-rRBcxf>3iVOLSs#L=<+{HOeCi4U@W?&Ck)Mmy_*?$#^aoz(T8CRBD`-QX_wM|vjR&^_+Xlw@kD>wwwmefi`!vjjs|~KU7g>*P#>NC)zPrV(9v~Q)iT!JVxx4+Kj!KT)E z$|{J5{>wXaq}uv5cb))(L1PwxgwywxVC=WVrJy&gm?~AD8&&~EYgg&(nzy41 zSPII@-GF|RW}~a+qxx0?Aer<>6YgII^p}q>AojO{I3HL8 zRtq4wOG;=}2CZ=aqJO=nQ3O}}`GF?fGmrqn zzH!ZVUmYP?1wZ+{QS_!43R^$xcO!)ObUL^^nB5EvhYaB@&n#U9JYo0%u+ooXa7i5f zSy|5kL5>viXUxoK%l>DN=IR|XcKU_vDYC_U_`d0-{vz!;`4k$8a%IN5kV)_(1}b4^ zr&rNu;nWh69U!0k%AIO@+VBJ0$LC+*(y!YbbRWT-W}}O@+R+fc9@DwFkrE<=fgs%a zzZVDmk%rzoL9RDQ0Uoyw!a;8JY8lcWPhTzeF0?6v7|J#ZGz*afw|{c~tih+G^a~A( z!z0WHzg4_10eI_8&Jx=4%Iw-Vo1{J{#mlhaD!g4CZn5yME)%Q`7MK)XvGQ-$lmvXt zah-3?(|pCcPZ(g@0ih;+UQ zGzX;S&$0{h^Bn8Gnp8U=L1}jKCS6x%LC-XRiQRg{eN|oT1O-rLJAj7> zvfa4TQaNtPbttdg&x_(Aef_o<1L3b}<;LdF@jYsdgho(vOi1>TH1`!(_0h!e+D7+% z(kB#+j0~Q|pu-6MqMfjG%f(@T#D6BaHEObi6a~iq&VOb+>vOfJx4F37c(!yk>Tt7` z`{H8W|7i$i13Ut< zU#(9OilP8T2~3B2oNrI|L9UCB7Z}&Xp!I{QzwQCk#{m)_d5THQ%uD6QqRFYLl=Q`$ zqeDadOXQ1J+>q-i8@%G3TM~*}sfMlT05XMK5pbvF+ttKhKL(fICxkRBe^%-`3(#ZW zV#H+Wi-U?TJ-W|;gxR&>T>2xfrHqVnvvRVNv)gya`C|UQMgp`IeecyhgD?Qks$Js* z#Q61$nX*3ybO~=oq!zFN*vUS*#wS4(gq4Cc;uFD4CCBTVe6{vJ$(yl)jC0BgWss%Js=#)LhxHhTyDRa0=DU80p{Vq7=Asle1qQIH#;W{@ zSMm2oNz8a8%=iv&d_YR5*a@i$e4k)FVTyjSriGZVl5qhDHV|)#CyFysEgOkAJNAEX2@Jppma`K8`Xq*X>Ne-Nu7K z=?jawcB&&8(G;ps6(^ihzT+b^zpu5KA&{H|CO%k4lfle%IBs%BX1A6IN^i+!V$X^{ zP8l894iENM$!VuK$oTcFrhA`jF8cPiDA(jpf|(K#s{j3ux*R@U^ijXylm-uXMFQ|f zp_t>y+hlRzaq~_x%%5q{3;4#LgT83gy@Go=Hr+qWbcJupcgK9G6iH)09^p^0Fib2& z{M@s$x-33vI74~Z01tb>8D$pjctyyC`!Vb_m@mKhYp2ebsahacVZ_$hfnvCzFpX-% zXt#^HL_~gyS4U1f+ZHCpRGdkB&cOww0`tR28tvO<-mNRAp;I<$C=wt^m=Mz zCUfkv>rV)z%Dexsq3$U;1GsR$oHFJej2fEmY!QgE00NYymM_ z+zs~p?_>{NYW6ioaDt{ChUB|z<*fB3wQgY|-7VIce~*Vkt}ebArBpx8Khv2(z`8IU zoz*6dAj?KdUUn;SK7KV>e?_K9WzY4{a@BF3nsNe2p>3K@+ilHTtiLB#s*GL=q;IFv ziT%|B<(u!pH~F!;5)Cwa1i`XXdUTuCq7Z*8W?!u}lHLp9*rRD4I>bwk>L8c08V2r& z2iqy+Qh@hSjUEv(CqU`x4z8kMx999I6LRRaGU7dz9>i!l_mRZ43+CqN9Zt93bJ--h zM{n`1BPCNP9*kCefC5WA)Q_N%dRRJqy4+uFufq1m=5S})O;9Hfdm{;~y0d(^)od2X zR1*am0{2MZJHb&Jdfx&9B+_qZ9M}RC)xD~MgTVIbg38CwE9$GKoogH`00;P&X@00KqB>f#!Lu*=RG24oQKl@{P5qP zup-F}o3$)3s^oran#ulkR3T8~S9)C5JuLB5!C(;^G5I!GhlGxQG>*pjmf`vHx!UVI zSk2GxPyW5YGvEb2zvDGIN}x9j%fX#VY@qbM3xKKbq3UIQO-owu?&#(97AQmhR5aax z`GrbA@nSV}i8|kNwi}mLz7H&h{jjD~Zs2{CHM$}scYPj`rzoYcDR`hHDWc>sWJr{} z+@#}+@eWzeg`U}Ka+(<9h&?S=7lVpJve^=gT!&n$AxedOBBD*;T+l6GR1!v5Xg)V{ z1)hTmJcm7DN7N($714ILLLorzO$mE$OE%pHfrhKm#CWOiLKjk(yE^&?$XK%_r+7ox zxF0KFoGAcS=|h2mkH}L0F7m2M2bN@pd>lH=B_VyftYvQ}LhW^U@wVh@{-y&y>4EIs zn!n23RSj9-nzI|qG6km)*FPuG6@-LpCZwO>Bg6wD4y3qL8DU<(-Qj`o@^*}Pnd`)1 zaZvmsr#|&HvvSlboi_-grre_E{VWV<8`$9%!1&2vFCLbNUcF;!%!7O&l>l3sPYqxUP2*j)3~4*Ea-#MrA-8)Ttln=U!`%9jU$^(`HM6WA90(3^L())bU%D@J2?M2Knur9WtjG1*8y1jXpTVm%GHI$B~sP8?TBGLHjD8o#> z({dXBdZPeL{jnV5!)Ecs=^aq?6FYfn6NhHpgAKo12RZolo?@S5LKWf$h;*t4ym@N@ z{(RO%qxK(Tw-CO9ftkVtMVC&hC#0AgAwIy?1mfp8TzK=}j{1gDx;?cwa~UX-<&Jqpr#hAK8Ms)H$sQ}&y&qQHY23%2O=vFMx;LgXe&n8pD%(A zXxYAxVFzJS-095?4Cq|y6mS~BwjrZ9rxq=Z{$;A>1}_=SHcGqted(npPBOREBUXg; zz-Ei?U;jk^5E-EV9vKEDhKLLgy-n|3Z(KbAe_m3fd+eIgOrACXq8-Z-}9^cEi z_#g{*P_wK#d8w=Dd!Zi2*~K|Dvz`2;JTkvr*s)F}4-XQb2Lf;lfqt6NL(&42>1Nz! z{iA=QV@V8(y+=jsbK{EL?avMf_#?hm5tsb@kr2CXeESQN7{OEdl}~NP8r+nwj#6+XPvbv77oiRKxdhv5jUNs5CpF@3R>MehwIGUTkMD|F11uAg zF=mQu=stGpYmi6j_Nz*<`L3P%GFgA3iK1?H)vS8Q^lmuq1HB9ydMp7*aV+OWC*H$a zMF@dg?nRlch${i3&dCA&c}Mb&YsH;aQTO8W^8j%1coawShG+WpOc?$cZytlo{6SKKX9b@R_>4k65jEP&% za=QtNoFMw3V^D<|#lqy2myM(7f`}kUreMAQ_E>QGq-N@a;L}Ho`w6Z9Bi{C z$XfF?)as}Si+xkQH=?cVDI&Ix`a0lW6s4C3vY|gJB+R5#`xm!6EGsQ%&@su1!Ay)U zHu~+j*n~#EptWN6Qb&JY=LK(~-i!WLasKflOU84XX^bKw`#$3K=Esnh8|i#zmdP_x zX4%l+4=V`diM~K|23khh*HRC`EO`lF?||G7as02DoG*X#RUA!AImN7PxYo|NFZ?*$ zS{;hSB%^I5U>8TVPW>YrT>VKC~A0MbRr2E`_RzNG0%E-%e6{I{CxZ`xFKLTYpFAtC)haC1Y&m~+`JLIn9 z@_iQc+uf+!Op?QYiLM=A=IgL>M5Io^$*Xv)GH0&zXXBRtg)bE-F1X(t=yQ$CIx{HX zp3C94+)3REv673@frtYx)k{2|m1l{mNs2rF5T6{d2-d#Gwc%bn0+H`-u%++mAxPW9 zD%kJC)i*7NyrDMMH6{F~n}Tu)k~9UfsN3UTFb`Yy)LzQ^up)3Vaws3tJ17bVNMeW+ z)5hTFB_sExn-kbTA<)t9&xFV@Rmp4(wdnrS{hO6ezOa7(oz;0A6muqci72SRPKCI# zq;tJkp^;3hXt?8hg>q$y2WH}A(EFG@h6*~on4~+t3MY92h^Aae^FcKIrRaYt5M59J z#&P2z(gz`sb(nF`xk$iPmRJ18k6{Fia%L+XL2?S#tChFfy1H-RuoL|@e_`Lt{SQ-W zqN?whqpXX*YPsSvzw6D>y$Xbx=a%KvUro7Zr7{{snqJXtssy%EmY=iqA`4Oc@W-3a zlUq~I8R;9p8LD(ypXYaw2+8yxE5DXd6fHwK( zVN!>Tx+BHk?S5A?xw^SknRN!+4grEwgv0f~T|=}NE3(I_f34x2my2FSnZD`BpGT39Zzo&rR#XyUei}ELNJM99o zm4#s2?efcWAuhDmrrIwsQ4IJ(*xzK=%n1y5V=;zNWb2rm?CjViHIV`D_~!tpBRnEN z^mtH(XqWzhF+K6JFRxh&Wth~|KPZ7wo~F`vmZo{@Z35NA*M(ZE3UAe`P>HpJlK8#= z3ME}oHW3~3)cN_y<-2w4`&5O9uTfL}sNak2ERyx66EI)2=-0UmOZeVvJzNyBJ>O6v zoLTw^E%2f02U1Qj=yreHGCv%UQJr3?$0r)pc=C%FMO9fzX-;b+BO$@OC9J8va6UJf z&p=c_A!2*i{MMrS;I{Ug|G^v*Gx-qSQ44D31*MG#n@M|1ddDTb)PkkddQs(QQq;~N zcLF=I2ewVDPcUqa@mx<;6N1QJsbKOH^Wr zE({gBGm7hRRVsR>AJ=$!7Tr@&^%9b&G4|Jsu#!LRv)X|a%)J{H^Yzm14&!9V5<7|UOIe|D$kRIuxv(lIi!M~QT75Z5}WYBJSF zE8Emj&xvdw$pEG2k-nyx6uNOduqJ8cFTu5{yHnVjG+(-m; zTx^x$5jb=8_Q2IUFvUW8!zTBqfV#2OT*e5Xv5Fx=>|&MWL1F^A^H^WK-Oh`9(5mYm z^aXjO$Hzl~YZ-D)+XKk}HRabcSzRV@8)k>N?etd#^P2CSeO6J^Te>OnL)pN;Mv6Y$ z?x;)ZY@Uh}BJwj#-1J3Y$2^O)2!0Q{weQW~r>`_AF2s41xs^Lw&d&nKHp<->&IkHg*~hvd z++>k>mc^w0$215;QAw4@1|Hc&12M(3VSy#x24RU;xqN^7a8h9TD5yk@4tg=6g57a& z&X)Y1sNJ3qV>#9=g<(mc6QIg1W)Q-O|d~izYiI+3N-A0x+ zs_J$Qxkq2{!y9kXp@V$nTTtG4Oy^Tv8by{|IoeF3CGJLtcvX8|-xJenkY_1}leo4% zUY4y*>a_eGkOwCv&GCDjsC8LxE6B2f{x@v^5=kgjqUa*7PU+PLJ>*z(x8e&dG%E${ zh{AaXPx2P^o}RF7lg{$J*+7bH2n}%M2>EtnRB6O4`#h3;Bf+cy&X3exb4=wH<+o|6 zV^#AFn9I7OFA78{%O(@60`zbIrQgWu1N8OJ;WT}~U5qSi+{6;su@Qtn-xEpxbG5M~ zAPGIHYJx40d-Y)tt~OopRLJ^h0s(#vwRmt1+uCes{VNCdN49{(Yp!^%dLN7(F6Vn= z!r5w}`)NQzqv^3*(idKIx7Aa*&#$$yqmKc&L#!)@zFwp&gs&~LmMBj0l+3a9z?1Rj zr|eimjoC_c(qBIo-;n@hOx|yO-Y@^I&697dJ(}7zu@SRF^VG%5B@TK-p^#l77D}8E zR3>97Mx~e-NZmHfy{bm}LA$hLfp##Vt~0DbJ}y*pD2y92`#$-+Y4Wfdsa^Q)CCc4UYaGJ6}7A#NTsGsSJfJ=!!j7#6ji~0ySm5ZBvdzl+4d4lmxi$`~I zFZN^AJpvQbeczEq7p&Q}@8r3j^)-lGDx2+wDvRIXvs!E6B!~qsCb3nhO&$IBPLTP13aOq6#Qq zfOLI8{G%Zazn_pF77$_$DZ9+~D0*A{k{)if7V7(Xu`;2|-4kwMDnEV6X`Qgj3uuxe zBcI9-B+CDFBU0f0SR~YyDhuQ0c1777;=oI#gJds0lgEYX4?(b4i>2}x)9qFuTCx-A z4omB|Xl#7_eZ$Csa@ph0;DQZ_pC7JPWBAeYAVcfU%BC0RAllYDr`n1wnf5~mS(#UG zROspdTzo7wNEbvaI$d{`QMz5@aI@cJaxlPxOvsGX>W3UdAv|-$o*YXv$;`9$9H!+6 zEg7&lZd^afIzPMv=1*7&aqf`EH~n6{2Q2O~1$y$P%dzUrf###UkWTRV_R?a{&b3?R zf4&j;sK|$pVlbv`XZap}6ep=f*`_b$cWFRSOf4lBgJslM_*H?!m}&h&9EGSAvwa%r z7F%;|q|g1Ht#u6`J_6ldt`Al4#mvG1=jf4echgBK`qR(R!2u-nFc&vYx$vVG9JFMoH)+F0RbEph6+ z+o;SEiB_1pLLms|bNt@Ebm9DQEH$7!sD1Qx;f868{nP6z-fYS%ao-u`&T*In>mXZK zeV27lzwBpLZpzgAcS5@}s)AD+E@`BozJL8P=V8ps-w7d^KxatSyFZ#mUDikgcB(87 z|9a)LE*BBSrGaWM7WfTOnxXdIX5x^Bcao$s-5aN*$JEABI6FreyZvrgbYY@FH1K%y z^A@dBN5fevBQGyO0jj9Ws~i&03CoPs329G`=D!-wLJ6wVf$IoV^dzg4*p7Qe(&O)= zxY67*+o#mg<>Ju9vcK&veY*A8~!w%|I^k0CnV+>Sn9wMgzzi ztK9u6M;si-WAedp%JoL1I|{*`%_#L3-N`84?5zqMoP1lpD?yU$f<@;*Q#;2OQOd## z=i!(CS4(|<0rVOA)qws*)N$17xcTX_j&3}3U)YknmQFbCKsnY+c7#B=bQK@^>`7Df ziC4++CH?J>WJfMN9M3$YDWFo|)Xc!&i;`7_*u5NDXyZMkdgA@2%F|A`?Pg8JC}3pK z8&J%429*Aliq`aP&g4%>15u-d|5dga9uM72$IhJQz>4ojg{sJ0?NqZZ!j(lpA-aC^ zbww-48c@;GYa*FtT2q%<+c*|%GM&47CR6Vv42-RYMo-ON+B)HUICAp*otYI}e!}7-B1fw>ULha$GllZ7g^F|DvP_@PT!sPr`j(OeYnE&wJ-c1SQ3g zBQJQ-j#pWjE!xi%d)%TXUQ8);Hl+?0N=-PY^CB(6-VCUJ-{mJqnac^@5-gWn{3`KM z=jRLCMYJCS120ON9C7vwv|3SnXy;!GsgN`NEnSPLQDFA2`x?I%rxeXNdN7L9rp)GV zQTs(b!0zwkNKG&zBo|vw^sbPQ*cTOFzR(ubhwJ=EgUlBYDr~5^VYLg~XesKmiA$3B zL0&eRio||j@HQ=U6!q9|P7|ue-ZT|>+&)$u5+xVip{FI;AG^=tZjH~mrryicn*$p7 z7fFH)6s1a_z$^~va5p}&t}teY*Oo00rw1*}g3E_SvyhQRm922{4u=s8XOR%_b7KZ- z2?Ub8pHZlJn7wCShbeSmiTn8L5uyLy-DqY-K&x+>=fr?effoC@^()$R*?Et5{zn+oHsrh|(a(7|Gfc`ufggz5Z% z%v8)z(p45|yhc7G_sjIP>m>K5AjIr*Ik@F@HSaG`RMExX5mzi7jK4aFm-yU=qu*Bq zT97Lif@`*hyPYbUUW~EIV2lFyl!!-bX^t`&Xy1kZudS~kim5M;Dju|XVk#QbQ*wg} zW~(RRLIEyZf$i6HjIw0nlO1ux#>Ky=pyXjPAf3tb3=(xj*x|ly!_^za|9XC55gUDy zdhk2;d!&0N{uCaykwMv>jE(HmX{!(QMyYW%@3TtK-|S3Y0~eym#LvX7sdw=FtE+y# z0Z(h1r`Ew<#oZKCVZn+SC8SMrY1&u58H*pJaK^%>gaafIWZRQWl{o}~y5~EiY~ybi zKOzZuKStD2t~iMhag7qn9cZ+3#e!ZeV9y>MU zen!94@ODaKZESz=f0;w^FO=Z&bUux7BS?{!mtUif?BJuH{8+c3U&E4mBo>A|F9LWJ0)PEwvW z0&gs(tdlp^PI-&y8O51?oLS*+{NJhaeP}i_e*wl9KikqSU^a&Bi+n6jVcN^nq+^-H zKA#ZiJ3XJ3A3bli-6r9Wo5v-;?~X7pPU($bVBDTNHcB4vpr+ zJeO9T!By+bPg|wDpo{aJW#IDAqd@e9Kv4}1d0jAnP1KPWD#^Jx0s9G)NX*kIRqt_-Y;uBWU9xJ<<#vYIrKH7vp9TVVTn ztUnAwrqDd_Pg38BL8w+5`%*SA|K-7fs1KF?_zvj77W$vH`_Rn$&&(1T8dngx`Y$7K*rW{HHJrAJt-#d}q*)e3j*T&g|dFGWD(Kw2>QHuSL zmsTOtEqfg|CYgHl(HZd+*?Bs`L_tMmt)=Zf#((Lz^hcnN9J|2^1j;&=cO|Zp*F6Lf zk?dtHbx=Po?RPc91G{N3akC$v2G{!eN0o3^5xRBa^tl6u1mCJ3mzJxHM<_ksVn5t} zbtBzN33xR~$y(B{Enc&-#r0(}?KyA7=$UOc*ya1IYJCTIwFK1Ey|&BZ&gB11;dCCp zIVBHRf4DC%JM0|G-lM$Fvi=bx%=ZFI`~JCY{?`{vz?B3PKS!+AAFGNg&SAHKAnwdY zHjE>>;!e~*dvmwjV61z7J&sP1n~Mh9oe{B0J(tE4Fq<0E46mP^yOSy2*9D@^zc=(c z7~do~AaI!f>tlopJ@osLUjczSFiYPVZ+YdrNz8CZ5S|25Vp0&yn> z>TD=Cu$(OB02AT6d!c1sJjuhO;FdhZ*j0f^bM5@R(wriZjqn)K>T5A_Q1i$`znlur`QXf5pmw%^&H|?$@g~#48&{m7tk!JKvn*44eBMXttMvs`YUxamu(b_ z<*k7_@ZSLr@^Ne~;c%z?n*JG^h z0YHI^y#?F0_OQbo%7O>!@%%Da&j>V3a+HbIY@GU>W|l5z_cx)_%4h?$EF{SIK6$B) zJH4z)_1@vqq3d?(ig!$KEPsoL*kUDAowrbtLPF@u+dtsMI-ivIjaj&-e*IgJt7Tex ztm@iCk)dIY_jr=7qE)@3IuWIczMV z<8Q<1|22f)nHt4{K+wI_ActSnFPnF+<&3v#*|~uu+u7oc{`uNa;)k_&dszEe$Kf^) z6zrz_CtIS+et1H+_B#pmT8{cb+xCkrx=+ul@60lx3azq*o=GR5Jj%2Hk{gmYkzQ zAM3f`@eDQNvw^BaHm#1%b`CKHg~wV6=QRw&{0DSDK_QK!Lc>LrdDzMpb+93f(7CjuPKkdXFl4>pRT z{4b>oVPWj8vnDIjC29fA_ZIhky<3?P(`EC{X{5;6wEz2TVF14I$NKkNq@?(L>FbVS z&3Ip3v-Q3XL4(dB<>n)kHFYmSAH6D_*lXl?QAX!Xf-s#ony&`$EKhfq);kSTK86|g zjdot`d9K!dlzbzrTkk?K3a}U{-N5~gY*@~B#1_#ot%)DzCf~+g|8~)VKPHd{uo~86 zb68FG1g3G~BxJ)r2zZBJnQ`-#;J590b(qewL;mvu=Hzh9qD}buN*mM>M-eub$AMR*=KQIkFjyj-eRPPT1HDe%9BUNCI!oc zi_l2ctRm40(}FY5!IqMPVr|vxt#QcwehS*{8-7PuX}cMbzYx#T_HIWr^BrHY9q#`o zUmvB4o)W3+f^OP)?}Tl+%ySB^9Ndd>XN5Pc=OY|6S+dr~kp4?yrCt)`PIdDN&I5ZM zm%DID^dYCqiP`(DMpTi7)p72=GG`Z*oZR8Hzjad6rP=C7p;wWcsoWPXX9U^VxW$1h z21FlKNWZ&F+sO!%ZR?G2@f%-SCJ9&!81dmJ%yS;AJWMnRI>qAdgTO(K?J|92$vpYD zED1=UIEN*lX|gzV9zAh&tvMNLA7z6%Gf`xb;&OmLC=hJ!Z((DeD%-bBWU*Qb&$R zs^4or^-T8Xt~EEFXnaB>)u`Ho{MqY)8wHNR@$bU9rxXZ80iiVL#W{Yh=3I-`mVDr zMo@IT2#SuH*NW20h-jU`DgS`;)O81>xk*`l`svBSX6;0r&*TLY-;HL$AYt`Z%j7a9 z$6s$F+8*83|MZbgC-yp5v1t0Q64@f+B5i~45hlUL6?y)S&V7VhiZ70BB3Y_|}5hMq{)ALBo~qQ6!RvH(l=X7|}x#%_Au8!-**UPV2F ze%F*f9H=J-^vLuS-w@Pqg+dx69+T}?Iz|k(W+(9yl|4ys+U>~W$hbRjZgEh8YEr5~ zl-D|V#6G4`<@b|7cA?Scouvl=6}4q{=q_HoL{w^XDpiUkRz71lkM1efNQRqrvgOAH zALrw@WjnZ-2!z51mj-zb=w3s!;j~m1rF4xHcT1{y?D=mE><>)kST*li(h8#X`3*(Y zTcKuRaUC4HJlR@!qSUcU5!~ktqtU-g?jQ*S-!;lDNQlx9VSwK#sl_f2ahZx-KV>6n zCJ3E$G~CpKr9?C>?kk{};GM>tSA?R^9*C?*p?~qe3UKgupk0Vn<uM(SF(9~XnkUXY4)YM%Z&xcsdprDsiL&($$q1!X-&)Yv z+Ro>z8Ec#k#ziuoQLgW+Xk=P z2cdFDZgW8*-8Eo%ce3z-5tkZmvMlkSZ1VtVXO5JwUC%hvx{6Kb7afY^=fnSy@qrRL zuyy`Q$sHiA*vo8b%sCIFff$44r5@e*P$evPXhHfE*DLMqxO}8@&53f-Ig#Qbb5ynb zIZK$9Ojw#Pm2(}6&|hd>whUBkD)ufPk=;3y57mZ-wDWmW%Awl<T9{DfaoA4TU$}y^3H^r?&gK_~GjbVBi2cO_Qb)(^#4FTm5SE}5-!9JT- zXdHqJifd@MNFH`74Y;ng^Y+qppVpa@B1HRp67|iA9F%aK$>TD+L6(Mo5t*YggF%N$ zX#4Bf&`Jr?7GE)C)l|_)#9>S$!{4v0bS@g8fi#I&wlXTSrIo|2UOzSuZU?6gSKL}9 z>TqXT?JYXL9QWf31;??grE2W?D?JBSTUX9CTGfZ+9_wxC|J3#Tigs~>5zPs%Jzcu| zOO5kt{hfpA7i(SZ9@8F<7#HPhb2=S*YO$1LeQcx@D$`-lBtG7Md?9MIt8Y3w{0Pg9 z$B$0CF2zvKO_Zvd+fF~Yl`w1L|B}NOkF6xW|7S*L6o?HNuf6=C;M#m+GO`w!VOjv| zhFhAN330}B-vGutAJ}{W3E%JWo|8Mr6pzjXzHdN7`7JOZ_p2aLJHRU#$)!(IJ_o8c zvGl@;TWqv}JytnMke8Q&bg++aZv-yq=YqFIs-Ga?Tj^o@|YP=Q?ZeO^S*U$hEQjwsqb5H$|-(y z8K-=1==_Qw^P2Se<7+d zVvU$iM!e8#T_=u-jn7o-%mZfoafFsg<742?=mumtEqmWe`X%sB7QN4>n}$i#M`cfh zki>7?yj`L}$-9(Gc!OBV@76uNp<*vIk|^9LiI<8C85yG(cn$y%Ge#6IRhQPQihbDTE zri7>%Xc>;`UaYxl6_Eg{Y{MQ{2_3~RJv|A+yYF_sU(;)tGRcX5EVb-ArNVyHwfRaQ zcT5ci3+#oE;ED2g3dEP%B`*`X+1wU}yhTeo@A*mc6d4r%dtIGON%)hodd~}yV z)OnchC`FEq$;7|pyH;O%Q{zc;MdA?=u)+n*UWkZ@h{kr&Ei8emF=H`iosvzAV#MSL zJCG-Tk(wFEB zaB5yG@71;25LdW2bcPu@EQvp_fzfIc+j~-TS;L@?!Lp{ji6m9p z!W)-1Unb8F4Gv|CeDUT1sn1@ut;tACgY4-fCx(F2r%nO{0L)(8hNf?e4HiMCqu-g2 z@HnvmUE9!CUU=_cUwOBH4^k#@-T_vG#U(HxVbL=lSZIsc5MYVdzWFXvk=@b9h4#`= zb%NDp&nJo50UOOsQ%InX1XANFkHcLD#bAXY8fU>1=(h z5BxcC74KoTAP9qkEVNn;ABpm03$pX#Ho9}zc8khXI0rCNuE(9@IE%;+zzi+Sg`T(4 z!ma0N&wdZUv$Ky^2(4iKddN?v4Uib3;A_MdiQ%3vg(NQn%->jVqypc3nUL)pm~^q5 zgN?}uqCSXRj%A$sE{on*YXNF{ZQ5@m9Yaif#`5OHM!AWHb6&4oFchZg6m%vf^j83d zrwDC6mRKrek0zO~nySamR0qgJP%#APY5a+$kL!N@#M~A15kb0&ADabtH5YNj@xl4OThNWt^a|QX? zjT}<7a!&Nc7Flq3j^xXcuaA#QV!wak!0kKVTrxscbQm{c366G{4;CQ~#1@a_rr*9( zk`eV00YWeYTb=`Men*6hlPb03r+$hQOuUlp`mlW~xBEPIOMMm08=Z-Zd;LaZOJCV< z)Ai41OM)iw@QpjWZ3o${j|BOFaUjr^_S3*>spcLGa6#SuYmvwSf@+b$t^vN0ncYj2 zo6qBM8TVUIFPIa#xA*5~d)(6BDURc<3GTRTX}z65l%wxt@8$(w@xJ`0?+4t3Vg>Fr zH$B)CNLyyxbP0TaLNj%bvcDioKh%T60>4?o@^H&sRBy*w8w!QKb`0&xmA2m06kz?m zGe-@qf~u!Al%tCXc&Tk2!rpJ~9r%^sl&RyJyZa}GRv;qf20@E*7^tMRnYVT7gi#e@ zC72dCr{c|8)6WSNRWh(I`{*ob4n2g&m6lTMOl))FanmE+JCEAuXImSby3vmbw<@)% zQ7tvB5S5eHx&jeG1H{RREbOk->=Tl0snN9Gcu}=_<8jF^;Z*(KTBzs^4DW_5UKX#g zN|-H_hBS{pnHE{t?+d^Os|!@O>Js7}*{m@%|IVu(?@W&~ph7d-1rqsfTphXu?bY;~ z!}%OSt1VB2%?`XByZRDmtvtw*B`#bPYpaR5xa(|n(GUYM@0B;jRs-wl=-q*!#o<(z zKi_ab-H-xRVCjhC1nGJAZXQcA>$R!S1#68{B5bH1KXg+|d)}&fPJY%KFE&f1wp;Q< zW?NGGV=#^svT57<5D{C*N!QelSu9k*3i){cJi6st+fgy}erCEO4Yy*LXfK;Dbz%(a zN|-;{*#QT8(i+hiwLSrp9RZ>>c<-E`|eY$@ZS;z7?-%56zh6q1D5DSrWA-NMNn{-|N z>;m+#Sh6D-Scj)RnNkXmWzTCTk@<0do8aWM2i`LF#0DhBr^*Bpjk3M?T0cUI)H>3F zw~7+eT5_w?XL`q7ln=YRWrLMQX88Wz4aXm5H*nKo;!o4z%uAS*dx%dUi7zWjRXOE; zwhgM8^vmGn7J*q>aH8VV5uUB`mE*xxNe0(|mWzNy0mr#+T;^$;ADMsefD5i@@NiYB zMju8*#R72A`P{KC5K+*B8j*d=n>6XsP#lSp>f!v)# zc?Q#=#%E0(SH%o5`PoV{=Am~b%7YnP=wI+>wg~z z8?Tyuq|o`FcgZL6`oCOIftfyo6e%ol438K7d(1+Yj?*&6zZpI!+tys;xWHZYoB#BC zh~O_LH<=EBYDjYO?UUAobgtlE{~k1FU<9WEAv3MD^`+k^vf~+Q1Z1fHG2UXP;{R~d zdN@1kzfaUup3tTkK9>;_zc z;i3hQ-Bo26TeIeU8X`;Run`nQ!My;zCv-=p_iG`Eqa~WWld$082}zdmf7969Y@a06LO_0njSnQ9-O*L zkSGX^1^$h!e%nnY9NEQpba2b+2z-Mqqz|H|U|{t!S=ejHPdt4@0pJU9LL=h1LX+Az z>B>q4YEjO4QQskYJA`prKC-wQ8@CqtkzKv8Sn8d4Bugrkm4zi*dD_6W@ zYJ8TLXx|SHw0paw5csJjE^^+L^;D+v^5HqaTV7b4MjTDC@!0ZPkeInLrjb`q);{5p zm4yioUMp`JWJ;)S$UuMpi?}#K z{EN-<>^`yyyGBCaBjj2MQEv85(qm*4i5v^zVTThBGTB zPi^5{ltA;Rk7@QQGZC~r83AK6N6!Necnm9@okFJ5@n~MuMG7orc@HgE6{8cvMOpjn zbCK+tW|MVUP0A@J=cSN&cWBDGLm#ZGwP9TO`beDL5aQ7Wtdj4_n2cywS3`xa^m4l9 zh}C9IdHE#dO|@KAUaSNmv!ijr@OX`99_^E{Fzprq&dpdmbVRwVapebuhK63Me|or| z%>(Bb^7^FLI3}l&7h2B5q70bM*oo@LyzcJq8Q9@$LNSxaABL32AUcWr+?iHwUgyYr zjJI>k&vA;2j{wQz0gj2U9b=~BN6Rk88`PR_M7hdnE@kxip~VX|{RgKiP~klvnIHHhvsm4FJ^`d`C%QeUwAbo?%~aBWQ~<#(YNO(bP%l9 zhq|>f?gGd1XVV}-G`~|YE}ByhG~aAh!vMd!EGvC@bMY1pMW%A&R8gXUNhAQxj?~T6 zgYL99uuFSQs1xU+7Jx~6y-(ACU!b!-t!d?%*f)B$>f!S;A47qJ5QH(7Z9^yY|)dL=HNAcK_a~T z!7R0P`in0Ng1_Lz=pDy^ynUIC{tr>|($eFrE*ne0m*hRv6stM*R`T(I_>?z_u#ZWL zjobOs3o}^0i!Zh>d)!m_^X-0nE(}qH?YHK9YIZ4xC?-KMw@LjLus+!6D>+jtYmzxG zFOFw(%qVZqL&yMjYA_=hzOqUnbs@rO;gdRsVM$?jEXr|xqS_kjM^)PV`~_DA!m@yT z@PiBuw#M1vekxsmEL(1U6L71J?nx2fC>PXM3Ws~m-%oK^)AUpfKrQe2%{}9F1-c-X zz?sQf+z65@9cEV^uW;v1q)*t&(5J(}QkxX-QJX5cB599BGm#9G%Cnn>tI%jhmRr22^rCIJvY`4(Ki)>rxcD& zdk5|tlQJ-A2PR*<#GOl_H1M9(a|kPLJlVU?R;h?G6eyeIR9>W`c|L8*RD`YEfjJFG z1J&$-Ze=K$OT(V^HrvS6*_lJE6rsNT#XvJZyZeO>h1KodyNm~ZBeb%PKE-ZMyt$gP z&5|l0{j)&9W-Q?$k897MyyLqZ!SLB*9ackNe1o^TqhmBg5;Rm~wwA zp~hT|2WmZ6PP5;;)e$#X>{2cK?11bx5$S%+%7%$_=$f>%?JaAOAVI@52afyiC<)Sy-Q}d?Mh;I`w~fbf_7h&jl*!SM z_RxC-5P~=SqRU$YQ)y#3?ciRnd}EhB?w;tS+TFyAvBW>F+4rrlTv9T)rjomFA__F`(n|34XaDR3ZZFX zrzu&Jt}4dl^_1UvAsd*Wx_H{b!NeXpT~=O+Z!b{0TjS21I%4yfIa zqfZapb9i#)2_ePIfEi;XtsgYt@P^zW`o2F%dsY4Frv1mgKF{lK-(cQIxzP7Tter*9I9dM^J|kXk9{hjf3@{Q#7` z-_od+;~sEYNy*GaU(_D=X#jSR*-wh8$d7lvIXK@~FRLvPKN_R>j^atR%=s}rroJ&9 zu-66M95!-EPLy>*6^x^w6hOGwCIdiyAlWM|05?AE+QnbFnk*8%gGLrtkbBps=Xob; z88OTGSJC?HvppzU2eS4VvLTqtLfJZuX`z>XNe=30Sv{R!Ff$INjOdIwPr8W|Zt_@0 zRw_Gj=IRBENvg;a))a~pdvX;)5mR*+bK%87LE=#@nTIRvmI0emr;}pov!^D(>FSO6 z&9&96igJiKhS^q+cI|NGdtUam<640-ZgJxwba#-?tiy9qRN6VI_e{PZ$ui}GlN3 zsZ2YY`em5YCp&v>&5`1b4WkjKH{ZFU*0A>&7%K=E+9T* zfSWrSTmJ^@{1xWp@_@Q1J7f00Vh9j55l~p7L>&GJiFMdPV78ni^sk1^NuZp|%$E4@ z@9>W^a2do#57a{K{%HuJzyH)TqGGA58vkktVlqPJ_I9h`zY{XY!Z!@?p}Xe)8WK3V zlQ#;936%aFTq+O^(Sl@NEcw7c4MA+FLl9Qtmi!wY@f-IF*aT*P3IDH#AmRh6k1<{O j{~Ea`= Date: Mon, 31 Oct 2016 01:40:02 -0700 Subject: [PATCH 08/11] Add default cuda system path (#192) * DYLD_LIBRARY_PATH is disable after Mac OS X 10.11 * fix clang + gpu compile error on Mac OS * fix some words and errors in build docs --- cmake/flags.cmake | 51 ++++++++++++++++---- doc/build/build_from_source.md | 16 +++---- paddle/cuda/src/hl_dso_loader.cc | 81 +++++++++++++++++++++++--------- 3 files changed, 108 insertions(+), 40 deletions(-) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index dbad6be3f4..e087770991 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -21,12 +21,6 @@ function(safe_set_flag is_c src_list flag_name) endif() if(${safe_name}) set(${src_list} "${${src_list}} ${flag_name}" PARENT_SCOPE) - if(is_c) - set(CUDA_NVCC_FLAGS - --compiler-options;${flag_name} - ${CUDA_NVCC_FLAGS} - PARENT_SCOPE) - endif() endif() endfunction() @@ -40,6 +34,20 @@ macro(safe_set_cxxflag src_list flag_name) safe_set_flag(OFF ${src_list} ${flag_name}) endmacro() +# helper macro to set nvcc flag +macro(safe_set_nvflag flag_name) + string(REPLACE "-" "_" safe_name ${flag_name}) + string(REPLACE "=" "_" safe_name ${safe_name}) + CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name}) + if(${safe_name}) + set(CUDA_NVCC_FLAGS + --compiler-options;${flag_name} + ${CUDA_NVCC_FLAGS}) + endif() +endmacro() + + CHECK_CXX_SYMBOL_EXISTS(UINT64_MAX "stdint.h" UINT64_MAX_EXISTS) if(NOT UINT64_MAX_EXISTS) set(CMAKE_REQUIRED_DEFINITIONS -D__STDC_LIMIT_MACROS) @@ -63,20 +71,43 @@ set(COMMON_FLAGS -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter + -Wno-unused-function + -Wno-error=literal-suffix + -Wno-error=unused-local-typedefs) + +set(GPU_COMMON_FLAGS + -fPIC + -fno-omit-frame-pointer + -Wnon-virtual-dtor + -Wdelete-non-virtual-dtor + -Wno-unused-parameter + -Wno-unused-function -Wno-error=literal-suffix -Wno-error=unused-local-typedefs -Wno-error=unused-function # Warnings in Numpy Header. ) +if (APPLE) + # On Mac OS X build fat binaries with x86_64 architectures by default. + set (CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architectures for OSX" FORCE) +else() + set(GPU_COMMON_FLAGS + -Wall + -Wextra + -Werror + ${GPU_COMMON_FLAGS}) +endif() + + foreach(flag ${COMMON_FLAGS}) safe_set_cflag(CMAKE_C_FLAGS ${flag}) safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag}) endforeach() -# On Mac OS X build fat binaries with x86_64 architectures by default. -if (APPLE) - set (CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architectures for OSX" FORCE) -endif () +foreach(flag ${GPU_COMMON_FLAGS}) + safe_set_nvflag(${flag}) +endforeach() + # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # So, don't set these flags here. diff --git a/doc/build/build_from_source.md b/doc/build/build_from_source.md index f7db0a9b92..7727c8c378 100644 --- a/doc/build/build_from_source.md +++ b/doc/build/build_from_source.md @@ -153,12 +153,12 @@ As a simple example, consider the following: - **Only CPU** ```bash - cmake .. -DWITH_GPU=OFF -DWITH_DOC=OFF + cmake .. -DWITH_GPU=OFF ``` - **GPU** ```bash - cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF + cmake .. -DWITH_GPU=ON ``` - **GPU with doc and swig** @@ -171,7 +171,7 @@ Finally, you can build PaddlePaddle: ```bash # you can add build option here, such as: -cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX= +cmake .. -DWITH_GPU=ON -DCMAKE_INSTALL_PREFIX= # please use sudo make install, if you want to install PaddlePaddle into the system make -j `nproc` && make install # set PaddlePaddle installation path in ~/.bashrc @@ -246,7 +246,7 @@ easy_install pip ```bash sudo tar -xzf cudnn-7.5-osx-x64-v5.0-ga.tgz -C /usr/local - sudo chmod a+r /usr/local/cuda/include/cudnn.h /usr/local/cuda/lib64/libcudnn* + sudo chmod a+r /usr/local/cuda/include/cudnn.h /usr/local/cuda/lib/libcudnn* ``` 2. Then you need to set DYLD\_LIBRARY\_PATH, PATH environment variables in ~/.bashrc. @@ -273,12 +273,12 @@ As a simple example, consider the following: - **Only CPU** ```bash - cmake .. -DWITH_GPU=OFF -DWITH_DOC=OFF + cmake .. -DWITH_GPU=OFF ``` - **GPU** ```bash - cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF + cmake .. -DWITH_GPU=ON ``` - **GPU with doc and swig** @@ -291,9 +291,9 @@ Finally, you can build PaddlePaddle: ```bash # you can add build option here, such as: -cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX= +cmake .. -DWITH_GPU=ON -DCMAKE_INSTALL_PREFIX= # please use sudo make install, if you want to install PaddlePaddle into the system -make -j `nproc` && make install +make -j `sysctl -n hw.ncpu` && make install # set PaddlePaddle installation path in ~/.bashrc export PATH=/bin:$PATH ``` diff --git a/paddle/cuda/src/hl_dso_loader.cc b/paddle/cuda/src/hl_dso_loader.cc index eee9984e07..91c60d85a1 100644 --- a/paddle/cuda/src/hl_dso_loader.cc +++ b/paddle/cuda/src/hl_dso_loader.cc @@ -46,63 +46,100 @@ static inline std::string join(const std::string& part1, const std::string& part return ret; } -static inline void GetDsoHandleWithSearchPath( +static inline void GetDsoHandleFromDefaultPath( + std::string& dso_path, void** dso_handle, int dynload_flags) { + LOG(INFO) << "Try to find cuda library: " << dso_path + << "from default system path."; + // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH + *dso_handle = dlopen(dso_path.c_str(), dynload_flags); + + // DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to + // bring System Integrity Projection (SIP), if dso_handle + // is null, search from default package path in Mac OS. + #if defined(__APPLE__) or defined(__OSX__) + if (nullptr == *dso_handle) { + dso_path = join("/usr/local/cuda/lib/", dso_path); + *dso_handle = dlopen(dso_path.c_str(), dynload_flags); + if (nullptr == *dso_handle) { + if (dso_path == "libcudnn.dylib") { + LOG(FATAL) << "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n" + << "For instance, sudo tar -xzf cudnn-7.5-osx-x64-v5.0-ga.tgz -C " + << "/usr/local \n sudo chmod a+r /usr/local/cuda/include/cudnn.h " + << "/usr/local/cuda/lib/libcudnn*"; + } + } + } + #endif +} + +static inline void GetDsoHandleFromSearchPath( const std::string& search_root, - const std::string& dso_path, + const std::string& dso_name, void** dso_handle) { int dynload_flags = RTLD_LAZY | RTLD_LOCAL; *dso_handle = nullptr; - std::string dlPath = dso_path; + std::string dlPath = dso_name; if (search_root.empty()) { - // default search xxx.so from LD_LIBRARY_PATH - *dso_handle = dlopen(dlPath.c_str(), dynload_flags); + GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); } else { // search xxx.so from custom path - dlPath = join(search_root, dso_path); + dlPath = join(search_root, dso_name); *dso_handle = dlopen(dlPath.c_str(), dynload_flags); - // then, search xxx.so from LD_LIBRARY_PATH - if (nullptr == *dso_handle) { - *dso_handle = dlopen(dso_path.c_str(), dynload_flags); + // if not found, search from default path + if (nullptr == dso_handle) { + LOG(WARNING) << "Failed to find cuda library: " << dlPath; + dlPath = dso_name; + GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); } } CHECK(nullptr != *dso_handle) - << "For Gpu version of PaddlePaddle, it couldn't find CUDA library: " - << dlPath.c_str() << ". Please make sure you already specify its path. " - << "Note: for training data on Cpu using Gpu version of PaddlePaddle, " - << "you must specify libcudart via export LD_LIBRARY_PATH for Linux or " - << "export DYLD_LIBRARY_PATH for MAC OS."; + << "Failed to find cuda library: " << dlPath << std::endl + << "Please specify its path correctly using one of the following ideas: \n" + + << "Idea 1. set cuda and cudnn lib path at runtime. " + << "http://www.paddlepaddle.org/doc/ui/cmd_argument/argument_outline.html \n" + << "For instance, issue command: paddle train --use_gpu=1 " + << "--cuda_dir=/usr/local/cudnn/lib --cudnn_dir=/usr/local/cudnn/lib ...\n" + + << "Idea 2. set environment variable LD_LIBRARY_PATH on Linux or " + << "DYLD_LIBRARY_PATH on Mac OS. \n" + << "For instance, issue command: export LD_LIBRARY_PATH=... \n" + + << "Note: After Mac OS 10.11, using the DYLD_LIBRARY_PATH is impossible " + << "unless System Integrity Protection (SIP) is disabled. However, @Idea 1" + << "always work well."; } void GetCublasDsoHandle(void** dso_handle) { #if defined(__APPLE__) || defined(__OSX__) - GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle); #else - GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle); #endif } void GetCudnnDsoHandle(void** dso_handle) { #if defined(__APPLE__) || defined(__OSX__) - GetDsoHandleWithSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle); #else - GetDsoHandleWithSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle); #endif } void GetCudartDsoHandle(void** dso_handle) { #if defined(__APPLE__) || defined(__OSX__) - GetDsoHandleWithSearchPath("", "libcudart.dylib", dso_handle); + GetDsoHandleFromSearchPath("", "libcudart.dylib", dso_handle); #else - GetDsoHandleWithSearchPath("", "libcudart.so", dso_handle); + GetDsoHandleFromSearchPath("", "libcudart.so", dso_handle); #endif } void GetCurandDsoHandle(void** dso_handle) { #if defined(__APPLE__) || defined(__OSX__) - GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle); #else - GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle); #endif } From 12945b2c9017c88d2581253a6df83d9e35e2b804 Mon Sep 17 00:00:00 2001 From: backyes Date: Mon, 31 Oct 2016 23:42:59 +0800 Subject: [PATCH 09/11] Add glog header path to include (#295) --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index e96ce28248..527064e310 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -135,6 +135,7 @@ endif() if(WITH_GLOG) add_definitions(-DPADDLE_USE_GLOG) + include_directories(${LIBGLOG_INCLUDE_DIR}) endif() if(WITH_GFLAGS) From 45c81a414f912b93c6ad34fa4045e8e80fd396f6 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Wed, 2 Nov 2016 09:56:20 +0800 Subject: [PATCH 10/11] Add job=time in trainer, refine cudnn_conv to reduce gpu memory and speed up training. (#218) * Add benchmark for PaddlePaddle, tensorflow and caffe * ConvProjection to reduce memory for goolenet * Add unit test for ConvProjection. 1. unit test in test_LayerGrad. 2. compare the ConvPorjection and CudnnConvLayer, also compare the concat_layer+img_conv_layer and concat_layer_conv_projection. * Reduce cudnn_conv memory and add benchmark document. 1. Use TmpMatrix as the workspace in cudnn_conv to reduce gpu memory. It reduce lots of memory. 2. Add benchmark document. 3. fix smallnet_mnist_cifar.py in paddle. * Add job=time and refine cudnn_conv to reduce gpu memroy and speed up * Refine cudnn_conv and shared biases operation in concat_layer and mixed_layer. * follow comments * follow comments * Use unique_ptr to prevent memory leaks in CudnnConvLayer. --- doc/ui/cmd_argument/argument_outline.md | 7 +- doc/ui/cmd_argument/detail_introduction.md | 4 + paddle/cuda/include/hl_device_functions.cuh | 19 ++ paddle/cuda/include/hl_matrix.h | 36 +++ paddle/cuda/include/stub/hl_matrix_stub.h | 13 + paddle/cuda/src/hl_cuda_cudnn.cc | 7 +- paddle/cuda/src/hl_cuda_matrix.cu | 87 ++++++ paddle/cuda/src/hl_cuda_sparse.cuh | 18 -- paddle/gserver/layers/ConcatenateLayer.cpp | 33 ++- paddle/gserver/layers/ConvBaseLayer.cpp | 51 ++-- paddle/gserver/layers/ConvBaseLayer.h | 24 +- paddle/gserver/layers/ConvProjection.cpp | 210 ++++++++++++++ paddle/gserver/layers/ConvProjection.h | 125 +++++++++ paddle/gserver/layers/CudnnConvLayer.cpp | 256 +++--------------- paddle/gserver/layers/CudnnConvLayer.h | 73 +---- paddle/gserver/layers/ExpandConvLayer.cpp | 39 ++- paddle/gserver/layers/ExpandConvLayer.h | 10 +- paddle/gserver/layers/MixedLayer.cpp | 20 +- paddle/gserver/layers/MixedLayer.h | 1 + paddle/gserver/tests/LayerGradUtil.cpp | 6 +- paddle/gserver/tests/LayerGradUtil.h | 3 +- paddle/gserver/tests/img_conv_a.conf | 39 +++ paddle/gserver/tests/img_conv_b.conf | 32 +++ paddle/gserver/tests/test_LayerGrad.cpp | 39 +++ paddle/gserver/tests/test_NetworkCompare.cpp | 9 + paddle/math/Matrix.cpp | 52 ++++ paddle/math/Matrix.h | 28 ++ paddle/math/tests/test_matrixCompare.cpp | 56 ++++ paddle/trainer/CMakeLists.txt | 1 + paddle/trainer/Trainer.h | 1 + paddle/trainer/TrainerBenchmark.cpp | 71 +++++ paddle/trainer/TrainerMain.cpp | 2 + proto/ModelConfig.proto.m4 | 5 +- python/paddle/trainer/config_parser.py | 68 ++++- .../paddle/trainer_config_helpers/layers.py | 105 ++++++- .../paddle/trainer_config_helpers/networks.py | 165 ++++++++++- .../tests/configs/check.md5 | 9 +- .../tests/configs/generate_protostr.sh | 2 +- .../tests/configs/test_bi_grumemory.py | 10 + 39 files changed, 1340 insertions(+), 396 deletions(-) create mode 100644 paddle/gserver/layers/ConvProjection.cpp create mode 100644 paddle/gserver/layers/ConvProjection.h create mode 100644 paddle/gserver/tests/img_conv_a.conf create mode 100644 paddle/gserver/tests/img_conv_b.conf create mode 100644 paddle/trainer/TrainerBenchmark.cpp create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py diff --git a/doc/ui/cmd_argument/argument_outline.md b/doc/ui/cmd_argument/argument_outline.md index 98dadc270d..d6cc2c6ed7 100644 --- a/doc/ui/cmd_argument/argument_outline.md +++ b/doc/ui/cmd_argument/argument_outline.md @@ -183,7 +183,7 @@ It looks like there are a lot of arguments. However, most of them are for develo -GPUgpu_id +GPUgpu_id √√√√ @@ -207,6 +207,11 @@ It looks like there are a lot of arguments. However, most of them are for develo √√√√ + +cudnn_conv_workspace_limit_in_mb +√√√√ + + RNN beam_size diff --git a/doc/ui/cmd_argument/detail_introduction.md b/doc/ui/cmd_argument/detail_introduction.md index 0d0362d022..07608e5edf 100644 --- a/doc/ui/cmd_argument/detail_introduction.md +++ b/doc/ui/cmd_argument/detail_introduction.md @@ -163,6 +163,10 @@ - Choose path to dynamic load NVIDIA CUDA library, for instance, /usr/local/cuda/lib64. [Default]: LD_LIBRARY_PATH - type: string (default: "", null) +* `--cudnn_conv_workspace_limit_in_mb` + - Specify cuDNN max workspace limit, in units MB, 4096MB=4GB by default. + - type: int32 (default: 4096MB=4GB) + ## NLP: RNN/LSTM/GRU * `--rnn_use_batch` - Whether to use batch method for calculation in simple RecurrentLayer. diff --git a/paddle/cuda/include/hl_device_functions.cuh b/paddle/cuda/include/hl_device_functions.cuh index 88d950d6c1..159c26f443 100755 --- a/paddle/cuda/include/hl_device_functions.cuh +++ b/paddle/cuda/include/hl_device_functions.cuh @@ -48,5 +48,24 @@ inline __device__ double paddleAtomicAdd(double* address, double val) { } } // namespace paddle +/** + * @brief sum reduction + * + * @param[in,out] smem input data, better to use __shared__ memory. + * @param[in] tid thread index. + * @param[in] threads the total thread number used to reduce, + * such as, blockDim.x. + * + * @return smem[0]: the sum of each elements in smem. + */ +__device__ __forceinline__ +void simpleReduce(real* smem, int tid, int threads) { + for (unsigned int s = threads / 2; s > 0; s >>= 1) { + if (tid < s) { + smem[tid] += smem[tid + s]; + } + __syncthreads(); + } +} #endif /* HL_DEVICE_FUNCTIONS_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index 1741979047..71e8f8e3a6 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -229,4 +229,40 @@ extern void hl_cossim_derivative(real* grad, int input2_height, real scale); +/** + * @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel]. + * + * @param[in] A_d input matrix (M x N). + * @param[in] B_d input matrix (1 x channel). + * @param[in] channel width of B. + * @param[in] dimM height of A. + * @param[in] dimN width of A. + * @param[in] scale scalar used for addition. + * + */ +extern void hl_matrix_add_shared_bias(real* A_d, + real* B_d, + const int channel, + const int dimM, + const int dimN, + real scale); + +/** + * @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel]. + * + * @param[in] B_d input matrix (1 x channel). + * @param[in] A_d input matrix (M x N). + * @param[in] channel width of B. + * @param[in] dimM height of A. + * @param[in] dimN width of A. + * @param[in] scale scalar used for addition. + * + */ +extern void hl_matrix_collect_shared_bias(real* B_d, + real* A_d, + const int channel, + const int dimM, + const int dimN, + real scale); + #endif /* HL_MATRIX_H_ */ diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index f1f1020c84..e37b127543 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -101,4 +101,17 @@ inline void hl_cossim_derivative(real* grad, int input2_height, real scale) {} +inline void hl_matrix_add_shared_bias(real* A_d, + real* B_d, + const int channel, + const int dimM, + const int dimN, + real scale) {} + +inline void hl_matrix_collect_shared_bias(real* B_d, + real* A_d, + const int channel, + const int dimM, + const int dimN, + real scale) {} #endif // HL_MATRIX_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index b215c0f6e3..7810d0d100 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -20,6 +20,11 @@ limitations under the License. */ #include "hl_thread.ph" #include "hl_dso_loader.h" #include "paddle/utils/Logging.h" +#include "paddle/utils/CommandLineParser.h" + +P_DEFINE_int32(cudnn_conv_workspace_limit_in_mb, 4096, + "Specify cuDNN max workspace limit, in units MB, " + "4096MB=4GB by default."); namespace dynload { @@ -242,7 +247,7 @@ void hl_conv_workspace(hl_tensor_descriptor input, CHECK_NOTNULL(conv); // Specify workspace limit directly - size_t memoryLimitBytes = 8 * 1024 * 1024; + size_t memoryLimitBytes = (1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb; // cudnn convolution forward configuration cudnnTensorDescriptor_t fwd_src_desc = GET_TENSOR_DESCRIPTOR(input); diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 067e68c41e..3df9f63f9e 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -20,6 +20,7 @@ limitations under the License. */ #include "hl_sequence.h" #include "paddle/utils/Logging.h" #include "hl_device_functions.cuh" +#include "hl_gpu_matrix_kernel.cuh" DEFINE_MATRIX_UNARY_OP(Zero, a = 0); DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1*a + p2*b); @@ -673,3 +674,89 @@ void hl_cossim_derivative(real* grad, input1_height, input2_height, scale); CHECK_SYNC("hl_cossim_derivate failed"); } + +__global__ void KeMatrixAddSharedBias(real* A, + real* B, + const int channel, + const int M, + const int N, + real scale) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int dim = N / channel; + if (index < M * N) { + int i = index % N; + i = i / dim; + A[index] += scale * B[i]; + } +} + +void hl_matrix_add_shared_bias(real* A_d, + real* B_d, + const int channel, + const int dimM, + const int dimN, + real scale) { + const int blocks = 512; + const int grids = DIVUP(dimM * dimN, blocks); + KeMatrixAddSharedBias<<>> + (A_d, B_d, channel, dimM, dimN, scale); + CHECK_SYNC("hl_matrix_add_shared_bias failed"); +} + + +template +__global__ void KeMatrixCollectSharedBias(real *B, + real *A, + const int channel, + const int M, + const int N, + const int dim, + const int limit, + real scale) { + if (dim < limit) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < channel) { + real sum = 0.0; + for (int i = 0; i < M; ++i) { + for (int j = 0; j < dim; ++j) { + sum += A[i * N + index * dim + j]; + } + } + B[index] += scale * sum; + } + } else { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + __shared__ real smem[blockSize]; + real sum = 0.0; + for (int j = 0; j < ((dim * M + blockSize - 1) / blockSize); ++j) { + int n = j * blockSize + tid; + int m = n / dim; + int w = n % dim; + smem[tid] = (m < M && w < dim) ? A[m * N + bid * dim + w] : 0.0; + __syncthreads(); + simpleReduce(smem, tid, blockSize); + sum += smem[0]; + } + if (tid == 0) { + B[bid] += scale * sum; + } + } +} + +void hl_matrix_collect_shared_bias(real* B_d, + real* A_d, + const int channel, + const int dimM, + const int dimN, + real scale) { + const int dim = dimN / channel; + const int blocks = 256; + const int limit = 64; + int grids = (dimM * dim) < limit ? DIVUP(channel, blocks) : channel; + + KeMatrixCollectSharedBias + <<< grids, blocks, 0, STREAM_DEFAULT>>> + (B_d, A_d, channel, dimM, dimN, dim, limit, scale); + CHECK_SYNC("hl_matrix_collect_shared_bias failed"); +} diff --git a/paddle/cuda/src/hl_cuda_sparse.cuh b/paddle/cuda/src/hl_cuda_sparse.cuh index c3b98f4ebc..9cf2d5a843 100644 --- a/paddle/cuda/src/hl_cuda_sparse.cuh +++ b/paddle/cuda/src/hl_cuda_sparse.cuh @@ -908,24 +908,6 @@ int findIndex(int* indice, int num, int index) { return (end - 1); } -/** - * @brief sum reduction - * - * @param[in,out] smem input data, better to use __shared__ memory. - * @param[in] tid local thread index. - * @param[in] blockDimX the size of blockDim.x. - * - * note: return smem[0]: the sum of each elements of smem. - */ -__device__ __forceinline__ -void reduce(real* smem, int tid, int blockDimX) { - for (unsigned int s = blockDimX / 2; s > 0; s >>= 1) { - if (tid < s) { - smem[tid] += smem[tid + s]; - } - __syncthreads(); - } -} /** * @brief sum columns of csr sparse matrix (csr_val), then add to a_val. diff --git a/paddle/gserver/layers/ConcatenateLayer.cpp b/paddle/gserver/layers/ConcatenateLayer.cpp index 52a7cb6f77..bb6709b8df 100644 --- a/paddle/gserver/layers/ConcatenateLayer.cpp +++ b/paddle/gserver/layers/ConcatenateLayer.cpp @@ -97,7 +97,8 @@ void ConcatenateLayer::backward(const UpdateCallback& callback) { */ class ConcatenateLayer2 : public Layer { public: - explicit ConcatenateLayer2(const LayerConfig& config) : Layer(config) {} + explicit ConcatenateLayer2(const LayerConfig& config) : + Layer(config) {} ~ConcatenateLayer2() {} @@ -110,6 +111,8 @@ protected: std::vector> projections_; std::vector projOutput_; std::vector> projCol_; + bool sharedBias_; + std::unique_ptr biases_; }; REGISTER_LAYER(concat2, ConcatenateLayer2); @@ -119,7 +122,6 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap, /* Initialize the basic parent class */ if (!Layer::init(layerMap, parameterMap)) return false; - CHECK(!biasParameter_); CHECK_EQ(inputLayers_.size(), parameters_.size()); projections_.reserve(inputLayers_.size()); projCol_.reserve(inputLayers_.size()); @@ -137,6 +139,13 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap, } CHECK_EQ(getSize(), endCol); + /* initialize biases_ */ + if (biasParameter_.get() != NULL) { + sharedBias_ = config_.shared_biases(); + size_t psize = config_.bias_size(); + biases_ = std::unique_ptr(new Weight(1, psize, biasParameter_)); + } + return true; } @@ -154,8 +163,17 @@ void ConcatenateLayer2::forward(PassType passType) { projOutput_[i].grad = output_.grad->subColMatrix(startCol, endCol); } - for (size_t i = 0; i != inputLayers_.size(); ++i) { - projections_[i]->forward(&getInput(i), &projOutput_[i], passType); + { + AsyncGpuBlock block; + for (size_t i = 0; i != inputLayers_.size(); ++i) { + projections_[i]->forward(&getInput(i), &projOutput_[i], passType); + } + } + + /* add the bias-vector */ + if (biases_) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + output_.value->addBias(*(biases_->getW()), 1, sharedBias_); } /* activation */ { @@ -170,6 +188,13 @@ void ConcatenateLayer2::backward(const UpdateCallback& callback) { backwardActivation(); } + AsyncGpuBlock block; + if (biases_ && biases_->getWGrad()) { + REGISTER_TIMER_INFO("Concat2BpBiasTimer", getName().c_str()); + biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_); + biases_->getParameterPtr()->incUpdate(callback); + } + for (size_t i = 0; i != inputLayers_.size(); ++i) { if (projections_[i]) { projections_[i]->backward(callback); diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index 9ed9572139..040510b7ad 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -35,25 +35,12 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, filterSizeY_.push_back(conf.filter_size_y()); filterPixels_.push_back(filterSize_.back() * filterSizeY_.back()); channels_.push_back(conf.channels()); - imgSize_.push_back(conf.img_size()); - imgPixels_.push_back(imgSize_.back() * imgSize_.back()); + imgSizeH_.push_back(conf.img_size()); + imgSizeW_.push_back(conf.img_size()); groups_.push_back(conf.groups()); filterChannels_.push_back(conf.filter_channels()); - outputX_.push_back(conf.output_x()); - outputs_.push_back(outputX_.back() * outputX_.back()); - } - - /* initialize the weightList */ - CHECK(inputLayers_.size() == parameters_.size()); - for (size_t i = 0; i < inputLayers_.size(); i++) { - size_t height, width; - height = filterPixels_[i] * filterChannels_[i]; - width = numFilters_; - - // create a new weight - CHECK_EQ(parameters_[i]->getSize(), width * height); - Weight* w = new Weight(height, width, parameters_[i]); - weights_.emplace_back(w); + outputH_.push_back(conf.output_x()); + outputW_.push_back(conf.output_x()); } /* initialize the biases_ */ @@ -74,4 +61,34 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, return true; } +size_t ConvBaseLayer::calOutputSize() { + auto clearAndReserve = [this](IntV* vec) { + vec->clear(); + vec->reserve(this->inputLayers_.size()); + }; + clearAndReserve(&imgSizeH_); + clearAndReserve(&imgSizeW_); + clearAndReserve(&outputH_); + clearAndReserve(&outputW_); + size_t layerSize = 0; + for (size_t i = 0; i < inputLayers_.size(); i++) { + imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); + imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); + if (imgSizeH_[i] == 0) + imgSizeH_[i] = config_.inputs(i).conv_conf().img_size(); + if (imgSizeW_[i] == 0) + imgSizeW_[i] = config_.inputs(i).conv_conf().img_size(); + outputH_.push_back( + outputSize(imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i])); + outputW_.push_back( + outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i])); + CHECK_EQ(outputH_[i], outputH_[0]); + CHECK_EQ(outputW_[i], outputW_[0]); + } + getOutput().setFrameHeight(outputH_[0]); + getOutput().setFrameWidth(outputW_[0]); + layerSize = outputH_[0] * outputW_[0] * size_t(numFilters_); + return layerSize; +} + } // namespace paddle diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h index eaeaebf43b..316514acf1 100644 --- a/paddle/gserver/layers/ConvBaseLayer.h +++ b/paddle/gserver/layers/ConvBaseLayer.h @@ -43,19 +43,18 @@ protected: IntV filterSizeY_; /// The spatial dimensions of the convolution input. IntV channels_; - /// The spatial dimensions of input feature map. - IntV imgSize_; - /// The total pixel size of input feature map. - /// imgPixels_ = imgSizeX_ * imgSizeY_. - IntV imgPixels_; + /// The spatial dimensions of input feature map height. + IntV imgSizeH_; + /// The spatial dimensions of input feature map width. + IntV imgSizeW_; /// filterPixels_ = filterSizeX_ * filterSizeY_. IntV filterPixels_; /// filterChannels_ = channels_/groups_. IntV filterChannels_; - /// The spatial dimensions of output feature map. - IntV outputX_; - /// The spatial dimensions of output feature map. - IntV outputs_; + /// The spatial dimensions of output feature map height. + IntV outputH_; + /// The spatial dimensions of output feature map width. + IntV outputW_; /// Group size, refer to grouped convolution in /// Alex Krizhevsky's paper: when group=2, the first half of the /// filters are only connected to the first half of the input channels, @@ -80,6 +79,13 @@ public: virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + /** + * imgSizeH_ and imgSizeW_ will be set according to the previous input layers + * in this function. Then it will calculate outputH_ and outputW_ and set them + * into output argument. + */ + virtual size_t calOutputSize(); + Weight& getWeight(int idx) { return *weights_[idx]; } /** diff --git a/paddle/gserver/layers/ConvProjection.cpp b/paddle/gserver/layers/ConvProjection.cpp new file mode 100644 index 0000000000..d1ce53fe26 --- /dev/null +++ b/paddle/gserver/layers/ConvProjection.cpp @@ -0,0 +1,210 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#include "paddle/utils/Stat.h" +#include "ConvProjection.h" + +namespace paddle { + +REGISTER_PROJECTION(conv, ConvProjection); + +ThreadLocalD> ConvProjection::convMem_; + +ConvProjection::ConvProjection(const ProjectionConfig& config, + ParameterPtr parameter, bool useGpu) + : Projection(config, parameter, useGpu) { + + CHECK(useGpu); // only support GPU + getConvParams(); + initCudnn(); + + size_t height = filterH_ * filterW_ * channels_ / groups_; + size_t width = numFilters_; + weight_.reset(new Weight(height, width, parameter)); + weightOffset_ = height * width / groups_; +} + +void ConvProjection::getConvParams() { + const ConvConfig &conf = config_.conv_conf(); + paddingH_ = conf.padding_y(); + paddingW_ = conf.padding(); + + strideH_ = conf.stride_y(); + strideW_ = conf.stride(); + + filterH_ = conf.filter_size_y(); + filterW_ = conf.filter_size(); + + configImgH_ = conf.img_size(); + configImgW_ = conf.img_size(); + + channels_ = conf.channels(); + numFilters_ = config_.num_filters(); + + groups_ = conf.groups(); + CHECK_EQ(channels_ % groups_, 0); + CHECK_EQ(numFilters_ % groups_, 0); +} + +void ConvProjection::initCudnn() { + hl_create_filter_descriptor(&filterDesc_, channels_, numFilters_, + filterH_, filterW_); + hl_create_tensor_descriptor(&inputDesc_); + hl_create_tensor_descriptor(&outputDesc_); + hl_create_convolution_descriptor(&convDesc_, inputDesc_, filterDesc_, + paddingH_, paddingW_, strideH_, strideW_); + + // initialize all to default algorithms + fwdAlgo_ = 0; + bwdFilterAlgo_ = 0; + bwdDataAlgo_ = 0; + fwdLimitBytes_ = 0; + bwdDataLimitBytes_ = 0; + bwdFilterLimitBytes_ = 0; + workSpaceInBytes_ = 0; + + batchNum_ = 0; + isSelectAlgo_ = false; +} + +void ConvProjection::reshapeTensorDesc(int batchSize) { + hl_tensor_reshape(inputDesc_, batchSize, channels_, imageH_, imageW_, + channels_ * imageH_ * imageW_, imageH_ * imageW_, + imageW_, 1); + hl_reset_convolution_descriptor(convDesc_, inputDesc_, filterDesc_, + paddingH_, paddingW_, strideH_, strideW_); + + // The stride between two consecutive images in ConvProjection may not be 1, + // for example, in the case of layer ConcatenateLayer2 with two + // ConvProjection, the stride is the output_size of layer ConcatenateLayer2. + // So the calculation of nStride is different from CudnnConvLayer. + // In fact, only "nStride = out_->value->getStride()" is ok. + size_t nStride = numFilters_ * outputH_ * outputW_; + if (out_->value->isContiguous()) { + CHECK_EQ(nStride, out_->value->getWidth()); + } else { + nStride = out_->value->getStride(); + } + + hl_tensor_reshape(outputDesc_, batchSize, numFilters_, outputH_, outputW_, + nStride, outputH_ * outputW_, outputW_, 1); +} + +void ConvProjection::reshape(int batchSize) { + size_t width = calOutputSize(); + CHECK_EQ(width, out_->value->getWidth()); + + isSelectAlgo_ = (batchSize == batchNum_); + batchNum_ = batchSize; + + if (!isSelectAlgo_) { + reshapeTensorDesc(batchSize); + hl_conv_workspace(inputDesc_, outputDesc_, filterDesc_, + convDesc_, &fwdAlgo_, &fwdLimitBytes_, + &bwdDataAlgo_, &bwdDataLimitBytes_, + &bwdFilterAlgo_, &bwdFilterLimitBytes_); + + size_t maxWorkSpace = 0; + maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); + maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_); + workSpaceInBytes_ = maxWorkSpace; + + + VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_ + << " / " << bwdDataAlgo_ + << " / " << bwdFilterAlgo_; + } + + isSelectAlgo_ = true; +} + +void ConvProjection::forward() { + int batchSize = in_->value->getHeight(); + reshape(batchSize); + + void* workSpace = NULL; + if (workSpaceInBytes_ > 0) { + workSpace = getSpaceBytes(workSpaceInBytes_); + } + + for (int g = 0; g < groups_; ++g) { + REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str()); + + real *inputData = in_->value->getData() + g * inputOffset_; + real *wgtData = weight_->getW()->getData() + g * weightOffset_; + real *outData = out_->value->getData() + g * outputOffset_; + hl_convolution_forward(inputDesc_, inputData, outputDesc_, + outData, filterDesc_, wgtData, + convDesc_, workSpace, + fwdLimitBytes_, fwdAlgo_); + } +} + +void ConvProjection::backward(const UpdateCallback& callback) { + REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str()); + + void* workSpace = NULL; + if (workSpaceInBytes_ > 0) { + workSpace = getSpaceBytes(workSpaceInBytes_); + } + + for (int g = 0; g < groups_; ++g) { + real *outGrad = out_->grad->getData() + g * outputOffset_; + if (weight_->getWGrad()) { + real *inputData = in_->value->getData() + g * inputOffset_; + real *weightGrad = weight_->getWGrad()->getData() + g * weightOffset_; + hl_convolution_backward_filter( + inputDesc_, inputData, outputDesc_, outGrad, filterDesc_, + weightGrad, convDesc_, workSpace, bwdFilterLimitBytes_, + bwdFilterAlgo_); + } + + MatrixPtr preGrad = in_->grad; + if (NULL != preGrad) { + real *inputGrad = preGrad->getData() + g * inputOffset_; + real *wgtData = weight_->getW()->getData() + g* weightOffset_; + hl_convolution_backward_data( + inputDesc_, inputGrad, outputDesc_, outGrad, filterDesc_, + wgtData, convDesc_, workSpace, bwdDataLimitBytes_, + bwdDataAlgo_); + } + } + + weight_->getParameterPtr()->incUpdate(callback); +} + +void* ConvProjection::getSpaceBytes(size_t size) { + std::vector& convMem = *convMem_; + if (convMem.empty()) { + int numDevices = hl_get_device_count(); + convMem.resize(numDevices); + } + + int devId = hl_get_device(); + MemoryHandle** localMem = &(convMem[devId]); + if (NULL == *localMem || size > (*localMem)->getAllocSize()) { + *localMem = new GpuMemoryHandle(size); + } + return (*localMem)->getBuf(); +} + +ConvProjection::~ConvProjection() { + hl_destroy_tensor_descriptor(inputDesc_); + hl_destroy_tensor_descriptor(outputDesc_); + hl_destroy_filter_descriptor(filterDesc_); + hl_destroy_convolution_descriptor(convDesc_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/ConvProjection.h b/paddle/gserver/layers/ConvProjection.h new file mode 100644 index 0000000000..41a100ac3c --- /dev/null +++ b/paddle/gserver/layers/ConvProjection.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#pragma once + +#include "Projection.h" + +namespace paddle { + +/** + * @brief Convolution projection do the same calculation with CudnnConvLayer. + */ +class ConvProjection : public Projection { +public: + /** + * Constructor. + */ + ConvProjection(const ProjectionConfig& config, ParameterPtr parameter, + bool useGpu); + + ~ConvProjection(); + + virtual void forward(); + virtual void backward(const UpdateCallback& callback); + +protected: + void getConvParams(); + void initCudnn(); + + void reshapeTensorDesc(int batchSize); + void reshape(int batchSize); + + int outputSize(int imageSize, int filterSize, int padding, int stride) { + return (imageSize - filterSize + 2 * padding) / stride + 1; + } + + size_t calOutputSize() { + imageH_ = in_->getFrameHeight(); + imageW_ = in_->getFrameWidth(); + if (imageH_ == 0) imageH_ = configImgH_; + if (imageW_ == 0) imageW_ = configImgW_; + outputH_ = outputSize(imageH_, filterH_, paddingH_, strideH_); + outputW_ = outputSize(imageW_, filterW_, paddingW_, strideW_); + + const_cast(out_)->setFrameHeight(outputH_); + const_cast(out_)->setFrameWidth(outputW_); + + inputOffset_ = (channels_ / groups_) * imageH_ * imageW_; + outputOffset_ = (numFilters_ / groups_) * outputH_ * outputW_; + return outputH_ * outputW_ * numFilters_; + } + + static void* getSpaceBytes(size_t size); + + /// imageH_ and imageW_ is calculated from the input layer. + int imageH_, imageW_; + /// configImgH_ and configImgW_ is obtained from config. + int configImgH_, configImgW_; + int outputH_, outputW_; + int channels_, numFilters_; + int paddingH_, paddingW_; + int strideH_, strideW_; + int filterH_, filterW_; + /// One group offset of input data. + int inputOffset_; + /// One group offset of output data. + int outputOffset_; + /// One group offset of weight. + int weightOffset_; + int groups_; + + /// Cudnn tensor descriptor for input. + hl_tensor_descriptor inputDesc_; + /// Cudnn tensor descriptor for output. + hl_tensor_descriptor outputDesc_; + /// Cudnn tensor descriptor for filter. + hl_filter_descriptor filterDesc_; + /// Cudnn tensor descriptor for a convolution operation. + hl_convolution_descriptor convDesc_; + + /// Record the algorithm for forward convolution, which is obtained by cudnn + /// api to search the best suited algorithm. + int fwdAlgo_; + /// Record the algorithm for computing convolution gradient with respect to + /// filter coefficients. + int bwdFilterAlgo_; + /// Record the algorithm for computing convolution gradient with respect to + /// the output. + int bwdDataAlgo_; + /// Amount of GPU memory needed as workspace to be able to execute a + /// forward convolution with the specified algo. + size_t fwdLimitBytes_; + /// Amount of GPU memory needed as workspace to be able to execute a + /// backwardFilter with the specified algo. + size_t bwdDataLimitBytes_; + /// Amount of GPU memory needed as workspace to be able to execute a + /// backwardData with the specified algo. + size_t bwdFilterLimitBytes_; + /// Size of total work space. + size_t workSpaceInBytes_; + + /// Whether to call cuDNN api to choose conv algorithm. + bool isSelectAlgo_; + /// batchNum is used to record batch size. If the batch size is changed, + /// the selection algorithm will be called. + int batchNum_; + bool bias_; + + std::unique_ptr weight_; + static ThreadLocalD> convMem_; +}; + +} // namespace paddle diff --git a/paddle/gserver/layers/CudnnConvLayer.cpp b/paddle/gserver/layers/CudnnConvLayer.cpp index 0f932f960f..e77216f17c 100644 --- a/paddle/gserver/layers/CudnnConvLayer.cpp +++ b/paddle/gserver/layers/CudnnConvLayer.cpp @@ -22,215 +22,64 @@ REGISTER_LAYER(cudnn_conv, CudnnConvLayer); bool CudnnConvLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { - ConvBaseLayer::init(layerMap, parameterMap); + if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; CHECK(useGpu_) << "CudnnConvLayer only support gpu"; - maxGroups_ = 0; - for (size_t i = 0; i < inputLayers_.size(); i++) { - CHECK_EQ(channels_[i] % groups_[i], 0); - CHECK_EQ(numFilters_ % groups_[i], 0); - - hl_filter_descriptor filter; - hl_create_filter_descriptor(&filter, channels_[i] / groups_[i], - numFilters_ / groups_[i], filterSizeY_[i], - filterSize_[i]); - filterDesc_.push_back(filter); - - hl_tensor_descriptor input; - hl_create_tensor_descriptor(&input); - inputDesc_.push_back(input); - - hl_tensor_descriptor output; - int outputX = - outputSize(imgSize_[i], filterSize_[i], padding_[i], stride_[i]); - CHECK_EQ(outputX, outputX_[i]); - hl_create_tensor_descriptor(&output); - outputDesc_.push_back(output); + CHECK_EQ(inputLayers_.size(), parameters_.size()); + projections_.reserve(inputLayers_.size()); + projConf_.reserve(inputLayers_.size()); - hl_convolution_descriptor conv; - hl_create_convolution_descriptor(&conv, input, filter, paddingY_[i], - padding_[i], strideY_[i], stride_[i]); - convDesc_.push_back(conv); - - weightOffset_.push_back((numFilters_ / groups_[i]) * - (channels_[i] / groups_[i]) * filterPixels_[i]); - inputOffset_.push_back((channels_[i] / groups_[i]) * imgSize_[i] * - imgSize_[i]); - outputOffset_.push_back((numFilters_ / groups_[i]) * outputX_[i] * - outputX_[i]); - - // initialize all to default algorithms - fwdAlgo_.push_back(0); - bwdFilterAlgo_.push_back(0); - bwdDataAlgo_.push_back(0); - fwdLimitBytes_.push_back(0); - bwdFilterLimitBytes_.push_back(0); - bwdDataLimitBytes_.push_back(0); - - // cudnn streams per group equal to 1 - if (groups_[i] > maxGroups_) { - maxGroups_ = groups_[i]; - } - } - - workSpaceInBytes_ = 0; - workSpaceData_ = NULL; - for (int i = 0; i < maxGroups_; ++i) { - workSpace_.push_back(NULL); + numFilters_ = config_.num_filters(); + CHECK(config_.shared_biases()); + for (size_t i = 0; i < inputLayers_.size(); i++) { + ProjectionConfig* conf = new ProjectionConfig(); + conf->set_type("conv"); + conf->set_num_filters(numFilters_); + conf->set_allocated_conv_conf( + config_.mutable_inputs(i)->mutable_conv_conf()); + conf->set_input_size(getPrev(i)->getSize()); + conf->set_output_size(getSize()); + projConf_.emplace_back(conf); + projections_.emplace_back(Projection::create(*projConf_[i], + parameters_[i], useGpu_)); } if (biases_.get() && sharedBiases_) { hl_create_tensor_descriptor(&biasDesc_); + hl_create_tensor_descriptor(&outputDesc_); hl_tensor_reshape(biasDesc_, 1, numFilters_ / groups_[0], 1, 1); biasOffset_ = numFilters_ / groups_[0]; } - batchNum_ = 0; - isSelectAlgo_ = false; return true; } -void CudnnConvLayer::allocConvWorkSpace(size_t maxWorkSpace) { - size_t totalWorkSpace = maxWorkSpace * maxGroups_; - - if (totalWorkSpace > workSpaceInBytes_) { - if (workSpaceInBytes_ != 0) { - hl_free_mem_device(workSpaceData_); - } - // total amount of storage needed over all groups - workSpaceData_ = hl_malloc_device(totalWorkSpace); - - // update work space address for each group - for (int i = 0; i < maxGroups_; ++i) { - workSpace_[i] = reinterpret_cast(workSpaceData_) - + i * maxWorkSpace; - } - workSpaceInBytes_ = totalWorkSpace; - } -} - -void CudnnConvLayer::reshape(int batchSize) { - CHECK_NE(inputLayers_.size(), 0UL); - imageH_ = inputLayers_[0]->getOutput().getFrameHeight(); - imageW_ = inputLayers_[0]->getOutput().getFrameWidth(); - if (imageH_ == 0) imageH_ = imgSize_[0]; - if (imageW_ == 0) imageW_ = imgSize_[0]; - - for (size_t i = 1; i < inputLayers_.size(); i++) { - int imageH = inputLayers_[i]->getOutput().getFrameHeight(); - int imageW = inputLayers_[i]->getOutput().getFrameWidth(); - if (imageH) { - CHECK_EQ(imageH_, imageH) << "Inputs must have same height."; - } - if (imageW) { - CHECK_EQ(imageW_, imageW) << "Inputs must have same width."; - } - } - - outputH_ = outputSize(imageH_, filterSizeY_[0], paddingY_[0], strideY_[0]); - outputW_ = outputSize(imageW_, filterSize_[0], padding_[0], stride_[0]); - // check outputH & outputW - getOutput().setFrameHeight(outputH_); - getOutput().setFrameWidth(outputW_); - - // if the batchSize remains the same, set isSelectAlgo_ true. - // Otherwise, set isSelectAlgo_ false and select algo again. - isSelectAlgo_ = (batchSize == batchNum_); - batchNum_ = batchSize; - - size_t maxWorkSpace = 0; - for (size_t i = 0; i < inputLayers_.size(); i++) { - CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(), - (size_t)(channels_[i] * imageH_ * imageW_)); - - hl_tensor_reshape(inputDesc_[i], batchSize, channels_[i] / groups_[i], - imageH_, imageW_, channels_[i] * imageH_ * imageW_, - imageH_ * imageW_, imageW_, 1); - - hl_tensor_reshape(outputDesc_[i], batchSize, numFilters_ / groups_[i], - outputH_, outputW_, numFilters_ * outputH_ * outputW_, - outputH_ * outputW_, outputW_, 1); - - hl_reset_convolution_descriptor(convDesc_[i], inputDesc_[i], - filterDesc_[i], paddingY_[i], - padding_[i], strideY_[i], stride_[i]); - - inputOffset_[i] = (channels_[i] / groups_[i]) * imageH_ * imageW_; - outputOffset_[i] = (numFilters_ / groups_[i]) * outputH_ * outputW_; - - if (!isSelectAlgo_) { - hl_conv_workspace(inputDesc_[i], outputDesc_[i], filterDesc_[i], - convDesc_[i], &fwdAlgo_[i], &fwdLimitBytes_[i], - &bwdDataAlgo_[i], &bwdDataLimitBytes_[i], - &bwdFilterAlgo_[i], &bwdFilterLimitBytes_[i]); - - maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]); - maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]); - - VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i] - << " / " << bwdDataAlgo_[i] - << " / " << bwdFilterAlgo_[i]; - } - } - - if (!isSelectAlgo_) { - allocConvWorkSpace(maxWorkSpace); - } - - isSelectAlgo_ = true; -} - void CudnnConvLayer::forward(PassType passType) { Layer::forward(passType); - int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); - reshape(batchSize); - resetOutput(batchSize, outputH_ * outputW_ * numFilters_); + + int batchSize = getInput(0).getBatchSize(); + resetOutput(batchSize, calOutputSize()); for (size_t i = 0; i != inputLayers_.size(); ++i) { - REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str()); - for (int g = 0; g < groups_[i]; ++g) { - real *inputData = getInputValue(i)->getData() + inputOffset_[i] * g; - real *wgtData = weights_[i]->getW()->getData() + weightOffset_[i] * g; - real *outData = getOutputValue()->getData() + outputOffset_[i] * g; - hl_convolution_forward(inputDesc_[i], inputData, outputDesc_[i], - outData, filterDesc_[i], wgtData, - convDesc_[i], workSpace_[g], - fwdLimitBytes_[i], fwdAlgo_[i]); - } + projections_[i]->forward(&getInput(i), &getOutput(), passType); } if (biases_) { REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str()); - addBiases(); - } - - forwardActivation(); -} - -void CudnnConvLayer::addBiases() { - if (sharedBiases_) { + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + hl_tensor_reshape(outputDesc_, batchSize, numFilters_ / groups_[0], + outputH_[0], outputW_[0], numFilters_ * outputH_[0] * outputW_[0], + outputH_[0] * outputW_[0], outputW_[0], 1); + outputOffset_ = getOutputValue()->getWidth() / groups_[0]; for (int g = 0; g < groups_[0]; ++g) { real *biasData = biases_->getW()->getData() + biasOffset_ * g; - real *outData = getOutputValue()->getData() + outputOffset_[0] * g; + real *outData = getOutputValue()->getData() + outputOffset_ * g; hl_convolution_forward_add_bias(biasDesc_, biasData, - outputDesc_[0], outData); + outputDesc_, outData); } - } else { - LOG(FATAL) << "Not supported"; } -} -void CudnnConvLayer::bpropBiases() { - if (sharedBiases_) { - for (int g = 0; g < groups_[0]; ++g) { - real *biasGrad = biases_->getWGrad()->getData() + biasOffset_ * g; - real *outGrad = getOutputGrad()->getData() + outputOffset_[0] * g; - hl_convolution_backward_bias(biasDesc_, biasGrad, - outputDesc_[0], outGrad); - } - } else { - LOG(FATAL) << "Not supported"; - } + forwardActivation(); } void CudnnConvLayer::backward(const UpdateCallback &callback) { @@ -238,52 +87,23 @@ void CudnnConvLayer::backward(const UpdateCallback &callback) { if (biases_ && biases_->getWGrad()) { REGISTER_TIMER_INFO("CudnnConvBpBiasTimer", getName().c_str()); - bpropBiases(); + for (int g = 0; g < groups_[0]; ++g) { + real *biasGrad = biases_->getWGrad()->getData() + biasOffset_ * g; + real *outGrad = getOutputGrad()->getData() + outputOffset_ * g; + hl_convolution_backward_bias(biasDesc_, biasGrad, outputDesc_, outGrad); + } biases_->getParameterPtr()->incUpdate(callback); } for (size_t i = 0; i != inputLayers_.size(); ++i) { - REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str()); - for (int g = 0; g < groups_[i]; ++g) { - real *outGrad = getOutputGrad()->getData() + outputOffset_[i] * g; - if (weights_[i]->getWGrad()) { - real *inputData = getInputValue(i)->getData() + inputOffset_[i] * g; - real *weightGrad = - weights_[i]->getWGrad()->getData() + weightOffset_[i] * g; - hl_convolution_backward_filter( - inputDesc_[i], inputData, outputDesc_[i], outGrad, filterDesc_[i], - weightGrad, convDesc_[i], workSpace_[g], bwdFilterLimitBytes_[i], - bwdFilterAlgo_[i]); - } - - MatrixPtr preGrad = getInputGrad(i); - if (NULL != preGrad) { - real *inputGrad = preGrad->getData() + inputOffset_[i] * g; - real *wgtData = weights_[i]->getW()->getData() + weightOffset_[i] * g; - hl_convolution_backward_data( - inputDesc_[i], inputGrad, outputDesc_[i], outGrad, filterDesc_[i], - wgtData, convDesc_[i], workSpace_[g], bwdDataLimitBytes_[i], - bwdDataAlgo_[i]); - } - } - weights_[i]->getParameterPtr()->incUpdate(callback); + projections_[i]->backward(callback); } } CudnnConvLayer::~CudnnConvLayer() { - if (biasDesc_) { + if (biases_) { hl_destroy_tensor_descriptor(biasDesc_); - } - - for (size_t i = 0; i < inputDesc_.size(); i++) { - hl_destroy_tensor_descriptor(inputDesc_[i]); - hl_destroy_tensor_descriptor(outputDesc_[i]); - hl_destroy_filter_descriptor(filterDesc_[i]); - hl_destroy_convolution_descriptor(convDesc_[i]); - } - if (workSpaceInBytes_ != 0) { - hl_free_mem_device(workSpaceData_); - workSpaceInBytes_ = 0; + hl_destroy_tensor_descriptor(outputDesc_); } } diff --git a/paddle/gserver/layers/CudnnConvLayer.h b/paddle/gserver/layers/CudnnConvLayer.h index a6dadba10d..6390d96315 100644 --- a/paddle/gserver/layers/CudnnConvLayer.h +++ b/paddle/gserver/layers/CudnnConvLayer.h @@ -17,12 +17,13 @@ limitations under the License. */ #include "ConvBaseLayer.h" #include "paddle/math/Matrix.h" +#include "Projection.h" #include namespace paddle { /** - * @brief A subclass of ConvBaseLayer by cuDNN implementation. It only + * @brief A 2-dimension conv layer implemented by cuDNN. It only * supports GPU mode. We automatic select CudnnConvLayer for GPU * mode and ExpandConvLayer for CPU mode if you set type of "conv". * User also can specfiy type of "exconv" or "cudnn_conv" for @@ -31,81 +32,21 @@ namespace paddle { * The config file api is img_conv_layer. */ class CudnnConvLayer : public ConvBaseLayer { -private: - /// resize Cudnn workspace size - void allocConvWorkSpace(size_t maxWorkSpace); - protected: - int imageH_, imageW_, outputH_, outputW_; - /// Cudnn tensor descriptor for bias. + std::vector> projConf_; + std::vector> projections_; + hl_tensor_descriptor biasDesc_; - /// Cudnn tensor descriptor for input. - std::vector inputDesc_; - /// Cudnn tensor descriptor for output. - std::vector outputDesc_; - /// Cudnn tensor descriptor for filter. - std::vector filterDesc_; - /// Cudnn tensor descriptor for a convolution operation. - std::vector convDesc_; - /// One sample offset of input data. - IntV inputOffset_; - /// One sample offset of output data. - IntV outputOffset_; - /// One group offset of weight. - IntV weightOffset_; - /// One group offset of bias. + hl_tensor_descriptor outputDesc_; int biasOffset_; - - /// Save the algorithm for forward convolution, which is obtained by cudnn - /// api to search the best suited algorithm. - std::vector fwdAlgo_; - /// Save the algorithm for computing convolution gradient with respect to - /// filter coefficients. - std::vector bwdFilterAlgo_; - /// Save the algorithm for computing convolution gradient with respect to - /// the output. - std::vector bwdDataAlgo_; - /// Amount of GPU memory needed as workspace to be able to execute a - /// forward convolution with the specified algo. - std::vector fwdLimitBytes_; - /// Amount of GPU memory needed as workspace to be able to execute a - /// backwardFilter with the specified algo. - std::vector bwdFilterLimitBytes_; - /// Amount of GPU memory needed as workspace to be able to execute a - /// backwardData with the specified algo. - std::vector bwdDataLimitBytes_; - - /// Device work space address for each group. - std::vector workSpace_; - /// Max number of groups. - int maxGroups_; - /// Total work space address in device for all groups. - void* workSpaceData_; - /// Size of total work space. - size_t workSpaceInBytes_; - - /// Is or not select conv algorihtm. - bool isSelectAlgo_; - - /// batchNum is used to record batch size. If the batch size is changed, - /// the selection algorithm will be called. - int batchNum_; + int outputOffset_; public: explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {} ~CudnnConvLayer(); - /** - * Intialization. Initialize member variables and create tenor descriptor. - */ bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); - /** - * Reshape is done each forward. Reshape tensor decriptor - * inputDesc_, outputDesc_, convDesc_. And search the faster algo - * or the fastest algo within a given memeory limit. - */ - void reshape(int batchSize); void forward(PassType passType); void backward(const UpdateCallback& callback); void addBiases(); diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index df79c3e303..80a6a62b5c 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -37,32 +37,29 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, caffeMode_ = conf.caffe_mode(); } + /* initialize the weightList */ + CHECK(inputLayers_.size() == parameters_.size()); + for (size_t i = 0; i < inputLayers_.size(); i++) { + size_t height, width; + height = filterPixels_[i] * filterChannels_[i]; + width = numFilters_; + + // create a new weight + CHECK_EQ(parameters_[i]->getSize(), width * height); + Weight* w = new Weight(height, width, parameters_[i]); + weights_.emplace_back(w); + } + return true; } -size_t ExpandConvLayer::getSize() { +size_t ExpandConvLayer::getOutputSize() { CHECK_NE(inputLayers_.size(), 0UL); - imgSizeH_.clear(); - imgSizeW_.clear(); - outputH_.clear(); - outputW_.clear(); + size_t layerSize = ConvBaseLayer::calOutputSize(); subN_.clear(); - size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); i++) { - imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); - imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); - if (imgSizeH_[i] == 0) imgSizeH_[i] = imgSize_[i]; - if (imgSizeW_[i] == 0) imgSizeW_[i] = imgSize_[i]; - outputH_.push_back( - outputSize(imgSizeH_[i], filterSize_[i], padding_[i], stride_[i])); - outputW_.push_back( - outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i])); subN_.push_back(outputH_[i] * outputW_[i]); - CHECK(layerSize == 0 || subN_[i] * size_t(numFilters_) == layerSize); - layerSize = subN_[i] * numFilters_; } - getOutput().setFrameHeight(outputH_[0]); - getOutput().setFrameWidth(outputW_[0]); return layerSize; } @@ -119,7 +116,7 @@ void ExpandConvLayer::expandFwdOnce(MatrixPtr image, int inIdx, int startIdx) { } void ExpandConvLayer::addSharedBias() { - size_t mapW = getSize() / numFilters_; + size_t mapW = getOutputValue()->getWidth() / numFilters_; size_t mapH = getOutputValue()->getElementCnt() / mapW; MatrixPtr out = Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_); @@ -158,7 +155,7 @@ void ExpandConvLayer::forward(PassType passType) { * transOutValue correspond sample to one row */ int batchSize = inputLayers_[0]->getOutputValue()->getWidth(); batchSize = inputLayers_[0]->getOutputValue()->getHeight(); - resetOutput(batchSize, getSize()); + resetOutput(batchSize, getOutputSize()); MatrixPtr image = nullptr; for (size_t i = 0; i != inputLayers_.size(); ++i) { @@ -183,7 +180,7 @@ void ExpandConvLayer::forward(PassType passType) { } void ExpandConvLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) { - size_t mapW = getSize() / numFilters_; + size_t mapW = v->getWidth() / numFilters_; size_t mapH = v->getElementCnt() / mapW; MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_); diff --git a/paddle/gserver/layers/ExpandConvLayer.h b/paddle/gserver/layers/ExpandConvLayer.h index fc3d69b1b7..030a3ba397 100644 --- a/paddle/gserver/layers/ExpandConvLayer.h +++ b/paddle/gserver/layers/ExpandConvLayer.h @@ -37,14 +37,6 @@ protected: IntV subN_; /// subK_ = channels_ * filterPixels_ * groups_. IntV subK_; - /// The spatial dimensions of height of input feature map. - IntV imgSizeH_; - /// The spatial dimensions of width of input feature map. - IntV imgSizeW_; - /// The spatial dimensions of height of output feature map. - IntV outputH_; - /// The spatial dimensions of width of output feature map. - IntV outputW_; /// Expand one sample at a time. shape: /// (numChannels * filterPixels_, outputSizeH * outputSizeW) MatrixPtr expandInput_; @@ -58,7 +50,7 @@ public: bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); - size_t getSize(); + size_t getOutputSize(); /** * Create or resize expandInput_. diff --git a/paddle/gserver/layers/MixedLayer.cpp b/paddle/gserver/layers/MixedLayer.cpp index 054ddd3a22..26b1360290 100644 --- a/paddle/gserver/layers/MixedLayer.cpp +++ b/paddle/gserver/layers/MixedLayer.cpp @@ -41,9 +41,13 @@ bool MixedLayer::init(const LayerMap& layerMap, } operators_.emplace_back(Operator::create(operator_conf, useGpu_)); } + /* initialize biases_ */ if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + sharedBias_ = config_.shared_biases(); + size_t psize = config_.bias_size(); + biases_ = std::unique_ptr( + new Weight(1, psize, biasParameter_)); } return true; @@ -119,12 +123,6 @@ void MixedLayer::forward(PassType passType) { MatrixPtr outV = getOutputValue(); - /* add the bias-vector */ - if (biases_.get() != NULL) { - REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); - outV->addBias(*(biases_->getW()), 1); - } - for (size_t i = 0; i != inputLayers_.size(); ++i) { if (projections_[i]) { projections_[i]->forward(&getInput(i), &output_, passType); @@ -140,6 +138,12 @@ void MixedLayer::forward(PassType passType) { op->forward(ins, &output_, passType); } + /* add the bias-vector */ + if (biases_.get() != NULL) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + outV->addBias(*(biases_->getW()), 1, sharedBias_); + } + /* activation */ { REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str()); forwardActivation(); @@ -154,7 +158,7 @@ void MixedLayer::backward(const UpdateCallback& callback) { if (biases_ && biases_->getWGrad()) { REGISTER_TIMER_INFO("BpBiasTimer", getName().c_str()); - biases_->getWGrad()->collectBias(*getOutputGrad(), 1); + biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_); /* Increasing the number of gradient */ biases_->getParameterPtr()->incUpdate(callback); diff --git a/paddle/gserver/layers/MixedLayer.h b/paddle/gserver/layers/MixedLayer.h index 9bac1355bd..5842e51e1d 100644 --- a/paddle/gserver/layers/MixedLayer.h +++ b/paddle/gserver/layers/MixedLayer.h @@ -58,5 +58,6 @@ protected: /// the matrix size of projection state std::vector projectionStateMatrixSize_; std::unique_ptr biases_; + bool sharedBias_; }; } // namespace paddle diff --git a/paddle/gserver/tests/LayerGradUtil.cpp b/paddle/gserver/tests/LayerGradUtil.cpp index 552a6c5b41..bc7bee0e4b 100644 --- a/paddle/gserver/tests/LayerGradUtil.cpp +++ b/paddle/gserver/tests/LayerGradUtil.cpp @@ -669,12 +669,14 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize, void testProjectionGrad(ProjectionConfig conf, InputType inputType, size_t parameterSize, size_t batchSize, bool useGpu, - bool testState) { + bool testState, int biasSize, bool sharedBias) { TestConfig config; conf.set_name(conf.type()); config.layerConfig.set_type("mixed"); config.layerConfig.set_size(conf.output_size()); - config.biasSize = config.layerConfig.size(); + config.biasSize = biasSize == 0 ? config.layerConfig.size() : biasSize; + config.layerConfig.set_bias_size(config.biasSize); + config.layerConfig.set_shared_biases(sharedBias); config.inputDefs.push_back( {inputType, "layer_0", conf.input_size(), parameterSize}); *config.layerConfig.add_inputs()->mutable_proj_conf() = conf; diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h index 1e608dc062..3b9ec80395 100644 --- a/paddle/gserver/tests/LayerGradUtil.h +++ b/paddle/gserver/tests/LayerGradUtil.h @@ -217,7 +217,8 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize, void testProjectionGrad(ProjectionConfig conf, InputType inputType, size_t parameterSize, size_t batchSize, bool useGpu, - bool testState = false); + bool testState = false, int biasSize = 0, + bool sharedBias = false); void testOperatorGrad(TestConfig& config, OperatorConfig& operatorConf, size_t batchSize, bool useGpu, bool testState = false); diff --git a/paddle/gserver/tests/img_conv_a.conf b/paddle/gserver/tests/img_conv_a.conf new file mode 100644 index 0000000000..940589ed9a --- /dev/null +++ b/paddle/gserver/tests/img_conv_a.conf @@ -0,0 +1,39 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) +data = data_layer(name ="input", size=8*16*16) +conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) +conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + +concat = concat_layer(input=[conv1, conv2]) + +conv = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=True, + act=LinearActivation()) + +outputs(concat, conv) diff --git a/paddle/gserver/tests/img_conv_b.conf b/paddle/gserver/tests/img_conv_b.conf new file mode 100644 index 0000000000..8ca9c94541 --- /dev/null +++ b/paddle/gserver/tests/img_conv_b.conf @@ -0,0 +1,32 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) +data = data_layer(name ="input", size=8*16*16) +proj1 = conv_projection(input=data, filter_size=1, filter_size_y=1, + num_channels=8, num_filters=16, stride=1) +proj2 = conv_projection(input=data, filter_size=1, filter_size_y=1, + num_channels=8, num_filters=16, stride=1) +concat = concat_layer(input=[proj1, proj2], bias_attr=False, act=ReluActivation()) + +proj = conv_projection(input=data, filter_size=1, filter_size_y=1, + num_channels=8, num_filters=16, stride=1) + +with mixed_layer(bias_attr=True, act=LinearActivation()) as conv: + conv += proj + +outputs(concat, conv) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index eab9bf8414..bf2c2e0499 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -134,6 +134,45 @@ TEST(Projection, identity) { } } + +#ifndef PADDLE_ONLY_CPU +TEST(Projection, conv) { + const int NUM_FILTERS = 16; + const int FILTER_SIZE = 2; + const int FILTER_SIZE_Y = 3; + const int CHANNELS = 3; + const int IMAGE_SIZE = 16; + + ProjectionConfig conf; + conf.set_type("conv"); + conf.set_num_filters(NUM_FILTERS); + + ConvConfig* conv = conf.mutable_conv_conf(); + conv->set_filter_size(FILTER_SIZE); + conv->set_filter_size_y(FILTER_SIZE_Y); + conv->set_channels(CHANNELS); + conv->set_padding(0); + conv->set_padding_y(1); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_groups(1); + conv->set_filter_channels(conv->channels() / conv->groups()); + conv->set_img_size(IMAGE_SIZE); + int outputSize = (2 * conv->padding() + conv->img_size() - + conv->filter_size()) / conv->stride() + 1; + int outputSizeY = (2 * conv->padding_y() + conv->img_size() - + conv->filter_size_y()) / conv->stride_y() + 1; + conv->set_output_x(outputSize); + conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS); + conf.set_output_size(outputSize * outputSizeY * NUM_FILTERS); + + testProjectionGrad(conf, INPUT_DATA, + /* parameterSize */ NUM_FILTERS * CHANNELS * FILTER_SIZE * FILTER_SIZE_Y, + /* batchSize */ 100, true, false, NUM_FILTERS, true); +} +#endif + + TEST(Layer, concat) { TestConfig config; config.biasSize = 0; diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index b3ef530673..8d3eac5aca 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -236,6 +236,15 @@ TEST(Compare, img_pool) { compareNetwork(config_file_a, config_file_b); FLAGS_use_gpu = useGpu; } + +TEST(Compare, img_conv) { + std::string config_file_a = "./gserver/tests/img_conv_a.conf"; + std::string config_file_b = "./gserver/tests/img_conv_b.conf"; + bool useGpu = FLAGS_use_gpu; + FLAGS_use_gpu = true; + compareNetwork(config_file_a, config_file_b); + FLAGS_use_gpu = useGpu; +} #endif diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 843eabc97d..aaeae98f0d 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -340,6 +340,15 @@ void GpuMatrix::addBias(Matrix& b, real scale) { BaseMatrix::addBias(b, scale); } +void GpuMatrix::addSharedBias(Matrix& b, real scale) { + CHECK(b.getHeight() == 1) << "the Bias should be a vector"; + CHECK_LE(b.getWidth(), getWidth()); + CHECK_EQ(getWidth() % b.getWidth(), 0UL); + hl_matrix_add_shared_bias(getData(), b.getData(), b.getWidth(), + getHeight(), getWidth(), scale); +} + + void GpuMatrix::collectBias(Matrix& a, real scale) { CHECK_EQ(getHeight(), (size_t)1); CHECK_EQ(width_, a.getWidth()); @@ -354,6 +363,14 @@ void GpuMatrix::collectBias(Matrix& a, real scale) { } } +void GpuMatrix::collectSharedBias(Matrix& a, real scale) { + CHECK_EQ(getHeight(), (size_t)1); + CHECK_EQ(a.getWidth() % getWidth(), 0UL); + hl_matrix_collect_shared_bias(getData(), a.getData(), getWidth(), + a.getHeight(), a.getWidth(), scale); +} + + void GpuMatrix::sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { @@ -1983,6 +2000,24 @@ void CpuMatrix::addBias(Matrix& b, real scale) { } } +void CpuMatrix::addSharedBias(Matrix& b, real scale) { + CHECK_EQ(b.getHeight(), (size_t)1); + real* aData = getData(); + real* bData = b.getData(); + size_t numSamples = getHeight(); + size_t channel = b.getWidth(); + CHECK_EQ(getWidth() % channel, 0UL); + size_t dim = getWidth() / channel; + + for (size_t i = 0; i < numSamples; i++) { + for (size_t c = 0; c < channel; c++) { + for (size_t j = 0; j < dim; j++) { + aData[i * getStride() + c * dim + j] += scale * bData[c]; + } + } + } +} + void CpuMatrix::collectBias(Matrix& a, real scale) { CHECK_EQ(getHeight(), (size_t)1); CHECK_EQ(width_, a.getWidth()); @@ -2000,6 +2035,23 @@ void CpuMatrix::collectBias(Matrix& a, real scale) { } } +void CpuMatrix::collectSharedBias(Matrix& a, real scale) { + CHECK_EQ(getHeight(), (size_t)1); + real* B = getData(); + real* A = a.getData(); + size_t numSamples = a.getHeight(); + size_t channel = getWidth(); + CHECK_EQ(a.getWidth() % channel, 0UL); + size_t dim = a.getWidth() / channel; + for (size_t i = 0; i < numSamples; i++) { + for (size_t c = 0; c < channel; c++) { + for (size_t j = 0; j < dim; j++) { + B[c] += scale * A[i * channel * dim + c * dim + j]; + } + } + } +} + void CpuMatrix::sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 047c76a860..52cbed528c 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -343,11 +343,35 @@ public: LOG(FATAL) << "Not implemented"; } + virtual void addSharedBias(Matrix& b, real scale) { + LOG(FATAL) << "Not implemented"; + } + + virtual void addBias(Matrix& b, real scale, bool sharedBias) { + if (!sharedBias) { + addBias(b, scale); + } else { + addSharedBias(b, scale); + } + } + /// add each sample from a to this. virtual void collectBias(Matrix& a, real scale) { LOG(FATAL) << "Not implemented"; } + virtual void collectSharedBias(Matrix& a, real scale) { + LOG(FATAL) << "Not implemented"; + } + + virtual void collectBias(Matrix& a, real scale, bool sharedBias) { + if (!sharedBias) { + collectBias(a, scale); + } else { + collectSharedBias(a, scale); + } + } + virtual void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { LOG(FATAL) << "Not implemented"; @@ -1021,6 +1045,7 @@ public: /// add b to each sample of this. void addBias(Matrix& b, real scale); + void addSharedBias(Matrix& b, real scale); /** * @code @@ -1028,6 +1053,7 @@ public: * @endcode */ void collectBias(Matrix& a, real scale); + void collectSharedBias(Matrix& a, real scale); void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode); @@ -1341,9 +1367,11 @@ public: public: /// add b to each sample of this. void addBias(Matrix& b, real scale); + void addSharedBias(Matrix& b, real scale); /// add each sample of a to this. void collectBias(Matrix& a, real scale); + void collectSharedBias(Matrix& a, real scale); void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index ac160479a9..0ddf7e0dfc 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/math/SparseMatrix.h" #include #include "paddle/gserver/tests/TestUtil.h" +#include "paddle/utils/Stat.h" + using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -2071,6 +2073,60 @@ TEST(Matrix, MaxOutFwdBwd) { } } +void testAddSharedBias(int numSamples, int dim, int channel) { + MatrixPtr cpuData = std::make_shared(numSamples, dim); + MatrixPtr gpuData = std::make_shared(numSamples, dim); + + MatrixPtr cpuBias = std::make_shared(1, channel); + MatrixPtr gpuBias = std::make_shared(1, channel); + + cpuData->randomizeUniform(); + gpuData->copyFrom(*cpuData); + cpuBias->randomizeUniform(); + gpuBias->copyFrom(*cpuBias); + + cpuData->addSharedBias(*cpuBias, 1.0); + gpuData->addSharedBias(*gpuBias, 1.0); + + MatrixPtr check = std::make_shared(numSamples, dim); + check->copyFrom(*gpuData); + MatrixCheckErr(*cpuData, *check); +} + +void testCollectSharedBias(int numSamples, int dim, int channel) { + MatrixPtr cpuData = std::make_shared(numSamples, dim); + MatrixPtr gpuData = std::make_shared(numSamples, dim); + + MatrixPtr cpuBias = std::make_shared(1, channel); + MatrixPtr gpuBias = std::make_shared(1, channel); + + cpuData->randomizeUniform(); + gpuData->copyFrom(*cpuData); + cpuBias->randomizeUniform(); + gpuBias->copyFrom(*cpuBias); + + cpuBias->collectSharedBias(*cpuData, 1.0); + gpuBias->collectSharedBias(*gpuData, 1.0); + + MatrixPtr check = std::make_shared(1, channel); + check->copyFrom(*gpuBias); + MatrixCheckErr(*cpuBias, *check); +} + + +TEST(Matrix, sharedBias) { + for (auto numSamples : {1, 100, 520}) { + for (auto dim : {100 * 16, 100 * 32}) { + for (auto channel : {8, 16}) { + VLOG(3) << " numSamples=" << numSamples << " dim=" << dim + << " channel=" << channel; + testAddSharedBias(numSamples, dim, channel); + testCollectSharedBias(numSamples, dim, channel); + } + } + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/trainer/CMakeLists.txt b/paddle/trainer/CMakeLists.txt index 08b411d2cc..06c019f0a9 100644 --- a/paddle/trainer/CMakeLists.txt +++ b/paddle/trainer/CMakeLists.txt @@ -7,6 +7,7 @@ set(TRAINER_SOURCES Tester.cpp Trainer.cpp TrainerInternal.cpp + TrainerBenchmark.cpp ThreadParameterUpdater.cpp TrainerInternalConfig.cpp TrainerConfigHelper.cpp) diff --git a/paddle/trainer/Trainer.h b/paddle/trainer/Trainer.h index 4f4811a139..7762722456 100644 --- a/paddle/trainer/Trainer.h +++ b/paddle/trainer/Trainer.h @@ -99,6 +99,7 @@ public: void startTrainPass(); void finishTrainPass(); void trainOneDataBatch(DataBatch& dataBatch); + void time(); /** * given a dataBatch and the current parameter value diff --git a/paddle/trainer/TrainerBenchmark.cpp b/paddle/trainer/TrainerBenchmark.cpp new file mode 100644 index 0000000000..54862e95b4 --- /dev/null +++ b/paddle/trainer/TrainerBenchmark.cpp @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#undef PADDLE_DISABLE_TIMER + +#include "Trainer.h" +#include "paddle/utils/Stat.h" +#include "paddle/utils/Util.h" + +P_DECLARE_int32(test_period); + +P_DEFINE_bool(feed_data, false, "Wether to read data from DataProvider."); + +namespace paddle { + +void Trainer::time() { + startTrain(); + + trainerInternal_.getParameterUpdater()->startPass(); + evaluator_->start(); + + DataBatch dataBatch; + int32_t batchSize = config_->getOptConfig().batch_size(); + int32_t num = dataProvider_->getNextBatch(batchSize, &dataBatch); + CHECK_EQ(num, batchSize) << "The sample number is less than batch size " + << num << " != " << batchSize; + + CHECK(dataBatch.getSize()) << "No data from data provider"; + + std::vector outputs; + // burning time + LOG(INFO) << "Burning time..."; + for (int n = 0; n < 10; ++n) { + trainerInternal_.trainOneBatch(n, dataBatch, &outputs); + } + LOG(INFO) << "Burning time end."; + + for (int n = 0; n < FLAGS_test_period; n++) { + if (FLAGS_feed_data) { + REGISTER_TIMER("GetData"); + num = dataProvider_->getNextBatch(batchSize, &dataBatch); + } + + if (num != batchSize) { + break; + } + + { + REGISTER_TIMER("FwdBwd"); + trainerInternal_.trainOneBatch(n, dataBatch, &outputs); + } + } + globalStat.setThreadInfo(true); + globalStat.printSegTimerStatus(); + globalStat.reset(); + + finishTrain(); +} + +} // namespace paddle diff --git a/paddle/trainer/TrainerMain.cpp b/paddle/trainer/TrainerMain.cpp index 94266639f9..a486cc383a 100644 --- a/paddle/trainer/TrainerMain.cpp +++ b/paddle/trainer/TrainerMain.cpp @@ -103,6 +103,8 @@ int main(int argc, char** argv) { trainer.checkGradient(); } else if (FLAGS_job == "test") { trainer.test(); + } else if (FLAGS_job == "time") { + trainer.time(); } else { LOG(FATAL) << "Unknown job type: " << FLAGS_job; } diff --git a/proto/ModelConfig.proto.m4 b/proto/ModelConfig.proto.m4 index 70c1f8d563..79e76b6bf1 100644 --- a/proto/ModelConfig.proto.m4 +++ b/proto/ModelConfig.proto.m4 @@ -255,7 +255,7 @@ sinclude(`ModelConfigLayer.proto.m4') // (which is how convnets are usually trained). Setting this to // false will untie the biases, yielding a separate bias for // every location at which the filter is applied. - optional bool shared_biases = 8; + optional bool shared_biases = 8 [default = false]; // Valid values are ones that divide the area of the output // grid in this convolutional layer. For example if this layer @@ -379,6 +379,9 @@ sinclude(`ModelConfigLayer.proto.m4') // use to compute moving mean and variance. optional real moving_average_fraction = 47 [default = 0.9]; + + // bias size + optional uint32 bias_size = 48 [default = 0]; } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index fe8a5e5d48..e909894316 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -632,6 +632,44 @@ class ContextProjection(Projection): _total_pad = 0 +@config_class +class ConvProjection(Projection): + type = 'conv' + + def __init__( + self, + input_layer_name, + num_filters=None, + conv_conf=None, + **xargs): + super(ConvProjection, self).__init__(input_layer_name, **xargs) + + if num_filters is not None: + self.proj_conf.num_filters = num_filters + + parse_conv(conv_conf, + input_layer_name, + self.proj_conf.conv_conf) + # TODO: support rectangle input + self.proj_conf.output_size = (self.proj_conf.conv_conf.output_x ** 2) * num_filters + + def calc_output_size(self, input_layer_config): + return self.proj_conf.output_size + + def calc_parameter_size(self, input_size, output_size): + co = self.proj_conf.num_filters + ci = self.proj_conf.conv_conf.channels + fh = self.proj_conf.conv_conf.filter_size + fw = self.proj_conf.conv_conf.filter_size_y + return co * ci * fh * fw + + def calc_bias_size(self): + return self.proj_conf.num_filters + + def calc_parameter_dims(self, input_size, output_size): + return None + + # Define a operator for mixed layer @config_class class Operator(Cfg): @@ -2528,8 +2566,15 @@ class MixedLayer(LayerBase): record_operator_conf = self.config.operator_confs.add() record_operator_conf.CopyFrom(operator_conf) + psize = self.config.size + if isinstance(self.inputs[0], ConvProjection): + self.config.shared_biases = True + psize = 0 + for input in self.inputs: + psize += input.calc_bias_size() - self.create_bias_parameter(bias, self.config.size) + self.config.bias_size = psize + self.create_bias_parameter(bias, psize) if error_clipping_threshold is not None: self.config.error_clipping_threshold = error_clipping_threshold @@ -2547,8 +2592,10 @@ class ConcatenateLayer(LayerBase): self, name, inputs, + bias=False, **xargs): config_assert(inputs, 'inputs cannot be empty') + config_assert(not bias, 'ConcatenateLayer cannot support bias.') super(ConcatenateLayer, self).__init__( name, 'concat', 0, inputs=inputs, **xargs) size = 0 @@ -2567,10 +2614,19 @@ class ConcatenateLayer2(LayerBase): self, name, inputs, + bias=False, **xargs): config_assert(inputs, 'inputs cannot be empty') super(ConcatenateLayer2, self).__init__( name, 'concat2', 0, inputs=inputs, **xargs) + + if isinstance(self.inputs[0], ConvProjection): + for input_index in xrange(len(self.inputs) - 1): + input = self.inputs[input_index + 1] + config_assert(isinstance(input, ConvProjection), + "The first input of ConcatenateLayer2 is ConvProjection, " + "the other inputs should also be ConvProjection.") + size = 0 for input_index in xrange(len(self.inputs)): input_layer = self.get_input_layer(input_index) @@ -2596,6 +2652,16 @@ class ConcatenateLayer2(LayerBase): input.proj_conf.output_size) self.create_input_parameter(input_index, psize, dims) + psize = self.config.size + if isinstance(self.inputs[0], ConvProjection): + self.config.shared_biases = True + psize = 0 + for input in self.inputs: + psize += input.calc_bias_size() + + self.config.bias_size = psize + self.create_bias_parameter(bias, psize) + @config_layer('recurrent') class RecurrentLayer(LayerBase): def __init__( diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 7df9108ae8..9a23c02431 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -34,7 +34,7 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", "table_projection", "mixed_layer", "data_layer", "embedding_layer", "fc_layer", "grumemory", "pooling_layer", "lstmemory", "last_seq", "first_seq", - "cos_sim", "hsigmoid", + "cos_sim", "hsigmoid", "conv_projection", "regression_cost", 'classification_cost', "LayerOutput", 'img_conv_layer', 'img_pool_layer', 'batch_norm_layer', 'img_cmrnorm_layer', 'addto_layer', @@ -1984,7 +1984,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None, @wrap_act_default(act=IdentityActivation()) @wrap_name_default("concat") @layer_support() -def concat_layer(input, act=None, name=None, layer_attr=None): +def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): """ Concat all input vector into one huge vector. Inputs can be list of LayerOutput or list of projection. @@ -2043,10 +2043,14 @@ def concat_layer(input, act=None, name=None, layer_attr=None): layer_type = (LayerType.CONCAT_LAYER if is_concat_layer else LayerType.CONCAT_PROJ_LAYER) + if layer_type == LayerType.CONCAT_LAYER: + assert not bias_attr + Layer( name=name, type=layer_type, inputs=[x.name for x in input] if is_concat_layer else input, active_type=act.name, + bias=ParamAttr.to_bias(bias_attr), **ExtraLayerAttribute.to_kwargs(layer_attr) ) @@ -2950,6 +2954,103 @@ def conv_operator(img, filter, filter_size, num_filters, op.origin = [img, filter] return op +@wrap_param_attr_default() +def conv_projection(input, filter_size, num_filters, + num_channels=None, stride=1, padding=0, + filter_size_y=None, stride_y=None, padding_y=None, + groups=1, param_attr=None): + """ + ConvProjection with a layer as input. + It performs element-wise multiplication with weight. + + Different from img_conv_layer and conv_op, conv_projection is an Projection, + which can be used in mixed_layer and conat_layer. It use cudnn to implement + conv and only support GPU mode. + + The example usage is: + + .. code-block:: python + + proj = conv_projection(img=input1, + filter_size=3, + num_filters=64, + num_channels=64) + + :param input: input layer + :type input: LayerOutput + :param filter_size: The x dimension of a filter kernel. + :type filter_size: int + :param filter_size_y: The y dimension of a filter kernel. Since + PaddlePaddle now supports rectangular filters, + the filter's shape can be (filter_size, filter_size_y). + :type filter_size_y: int + :param num_filters: channel of output data. + :type num_filters: int + :param num_channel: channel of input data. + :type num_channel: int + :param stride: The x dimension of the stride. + :type stride: int + :param stride_y: The y dimension of the stride. + :type stride_y: int + :param padding: The x dimension of padding. + :type padding: int + :param padding_y: The y dimension of padding. + :type padding_y: int + :param groups: The group number. + :type groups: int + :param param_attr: Convolution param attribute. None means default attribute + :type param_attr: ParameterAttribute + :return: A DotMulProjection Object. + :rtype: DotMulProjection + """ + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters + + if filter_size_y is None: + if isinstance(filter_size, collections.Sequence): + assert len(filter_size) == 2 + filter_size, filter_size_y = filter_size + else: + filter_size_y = filter_size + + if stride_y is None: + if isinstance(stride, collections.Sequence): + assert len(stride) == 2 + stride, stride_y = stride + else: + stride_y = stride + + if padding_y is None: + if isinstance(padding, collections.Sequence): + assert len(padding) == 2 + padding, padding_y = padding + else: + padding_y = padding + + if param_attr.attr.get('initial_smart'): + # special initial for conv layers. + init_w = (2.0 / (filter_size ** 2 * num_channels)) ** 0.5 + param_attr.attr["initial_mean"] = 0.0 + param_attr.attr["initial_std"] = init_w + param_attr.attr["initial_strategy"] = 0 + param_attr.attr["initial_smart"] = False + + proj = ConvProjection(input_layer_name=input.name, + num_filters=num_filters, + conv_conf=Conv(filter_size=filter_size, + padding=padding, + stride=stride, + channels=num_channels, + filter_size_y=filter_size_y, + padding_y=padding_y, + stride_y=stride_y, + groups=groups), + **param_attr.attr) + + proj.origin = input + return proj + @wrap_name_default() @layer_support() diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 65512b327c..bce88f9362 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -29,7 +29,7 @@ __all__ = ['sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool", "img_conv_bn_pool", 'dropout_layer', 'lstmemory_group', 'lstmemory_unit', 'small_vgg', 'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru', 'simple_attention', - 'text_conv_pool', + 'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs', 'outputs'] @@ -811,22 +811,37 @@ def simple_gru(input, gru_layer_attr=None ): """ - simple_gru is also a recurrent layer group version Gated Recurrent Unit as - gru_group. The difference only lies in implemention details. + You maybe see gru_step_layer, grumemory in layers.py, gru_unit, gru_group, + simple_gru in network.py. The reason why there are so many interfaces is + that we have two ways to implement recurrent neural network. One way is to + use one complete layer to implement rnn (including simple rnn, gru and lstm) + with multiple time steps, such as recurrent_layer, lstmemory, grumemory. But, + the multiplication operation :math:`W x_t` is not computed in these layers. + See details in their interfaces in layers.py. + The other implementation is to use an recurrent group which can ensemble a + series of layers to compute rnn step by step. This way is flexible for + attenion mechanism or other complex connections. + + - gru_step_layer: only compute rnn by one step. It needs an memory as input + and can be used in recurrent group. + - gru_unit: a wrapper of gru_step_layer with memory. + - gru_group: a GRU cell implemented by a combination of multiple layers in + recurrent group. + But :math:`W x_t` is not done in group. + - gru_memory: a GRU cell implemented by one layer, which does same calculation + with gru_group and is faster than gru_group. + - simple_gru: a complete GRU implementation inlcuding :math:`W x_t` and + gru_group. :math:`W` contains :math:`W_r`, :math:`W_z` and :math:`W`, see + formula in grumemory. + The computational speed is that, grumemory is relatively better than gru_group, and gru_group is relatively better than simple_gru. - simple_gru does exactly the same calculation as the grumemory layer does. - Please see grumemory in layers.py for more detail about the maths. - The example usage is: .. code-block:: python - gru = gur_group(input=[layer1], - size=256, - act=TanhActivation(), - gate_act=SigmoidActivation()) + gru = simple_gru(input=[layer1], size=256) :param input: input layer name. :type input: LayerOutput @@ -863,6 +878,132 @@ def simple_gru(input, gru_layer_attr=gru_layer_attr) +@wrap_name_default('simple_gru2') +def simple_gru2(input, + size, + name=None, + reverse=False, + mixed_param_attr=None, + mixed_bias_attr=None, + gru_param_attr=None, + gru_bias_attr=None, + act=None, + gate_act=None, + mixed_layer_attr=None, + gru_cell_attr=None + ): + """ + simple_gru2 is the same with simple_gru, but using grumemory instead + Please see grumemory in layers.py for more detail about the maths. + simple_gru2 is faster than simple_gru. + + The example usage is: + + .. code-block:: python + + gru = simple_gru2(input=[layer1], size=256) + + :param input: input layer name. + :type input: LayerOutput + :param name: name of the gru group. + :type name: basestring + :param size: hidden size of the gru. + :type size: int + :param reverse: whether to process the input data in a reverse order + :type reverse: bool + :param act: type of the activiation + :type act: BaseActivation + :param gate_act: type of the gate activiation + :type gate_act: BaseActivation + :param gru_bias_attr: bias. False means no bias, None means default bias. + :type gru_bias_attr: ParameterAttribute|False + :param gru_layer_attr: Extra parameter attribute of the gru layer. + :type gru_layer_attr: ParameterAttribute|False + :return: the gru group. + :rtype: LayerOutput + """ + with mixed_layer(name='%s_transform' % name, + size=size * 3, + bias_attr=mixed_bias_attr, + layer_attr=mixed_layer_attr) as m: + m += full_matrix_projection(input=input, param_attr=mixed_param_attr) + + return grumemory(name=name, + size=size, + input=m, + reverse=reverse, + bias_attr=gru_bias_attr, + param_attr=gru_param_attr, + act=act, + gate_act=gate_act, + layer_attr=gru_cell_attr) + + +@wrap_name_default("bidirectional_gru") +def bidirectional_gru(input, size, name=None, return_seq=False, + fwd_mixed_param_attr=None, fwd_mixed_bias_attr=None, + fwd_gru_param_attr=None, fwd_gru_bias_attr=None, + fwd_act=None, fwd_gate_act=None, + fwd_mixed_layer_attr=None, fwd_gru_cell_attr=None, + + bwd_mixed_param_attr=None, bwd_mixed_bias_attr=None, + bwd_gru_param_attr=None, bwd_gru_bias_attr=None, + bwd_act=None, bwd_gate_act=None, + bwd_mixed_layer_attr=None, bwd_gru_cell_attr=None, + + last_seq_attr=None, first_seq_attr=None, + concat_attr=None, concat_act=None): + """ + A bidirectional_gru is a recurrent unit that iterates over the input + sequence both in forward and bardward orders, and then concatenate two + outputs to form a final output. However, concatenation of two outputs + is not the only way to form the final output, you can also, for example, + just add them together. + + The example usage is: + + .. code-block:: python + + bi_gru = bidirectional_gru(input=[input1], size=512) + + :param name: bidirectional gru layer name. + :type name: basestring + :param input: input layer. + :type input: LayerOutput + :param size: gru layer size. + :type size: int + :param return_seq: If set False, outputs of the last time step are + concatenated and returned. + If set True, the entire output sequences that are + processed in forward and backward directions are + concatenated and returned. + :type return_seq: bool + :return: LayerOutput object. + :rtype: LayerOutput + """ + args = locals() + + fw = simple_gru2(name='%s_fw' % name, input=input, size=size, + **dict((k[len('fwd_'):], v) for k, v in args.iteritems() + if k.startswith('fwd_'))) + + bw = simple_gru2(name="%s_bw" % name, input=input, size=size, + reverse=True, + **dict((k[len('bwd_'):], v) for k, v in args.iteritems() + if k.startswith('bwd_'))) + + if return_seq: + return concat_layer(name=name, input=[fw, bw], layer_attr=concat_attr, + act=concat_act) + else: + fw_seq = last_seq(name="%s_fw_last" % name, input=fw, + layer_attr=last_seq_attr) + bw_seq = first_seq(name="%s_bw_last" % name, input=bw, + layer_attr=first_seq_attr) + return concat_layer(name=name, input=[fw_seq, bw_seq], + layer_attr=concat_attr, act=concat_act) + + @wrap_name_default("bidirectional_lstm") def bidirectional_lstm(input, size, name=None, return_seq=False, fwd_mat_param_attr=None, fwd_bias_param_attr=None, @@ -893,7 +1034,7 @@ def bidirectional_lstm(input, size, name=None, return_seq=False, .. code-block:: python - lstm_step = bidirectional_lstm(input=[input1], size=512) + bi_lstm = bidirectional_lstm(input=[input1], size=512) :param name: bidirectional lstm layer name. :type name: basestring @@ -907,7 +1048,7 @@ def bidirectional_lstm(input, size, name=None, return_seq=False, processed in forward and backward directions are concatenated and returned. :type return_seq: bool - :return: lstm layer name. + :return: LayerOutput object accroding to the return_seq. :rtype: LayerOutput """ args = locals() diff --git a/python/paddle/trainer_config_helpers/tests/configs/check.md5 b/python/paddle/trainer_config_helpers/tests/configs/check.md5 index d1b22b3490..72dfdad7bd 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/check.md5 +++ b/python/paddle/trainer_config_helpers/tests/configs/check.md5 @@ -1,10 +1,11 @@ 86c0815275a9d5eb902e23c6a592f58a img_layers.protostr a5d9259ff1fd7ca23d0ef090052cb1f2 last_first_seq.protostr 9c038249ec8ff719753a746cdb04c026 layer_activations.protostr -5913f87b39cee3b2701fa158270aca26 projections.protostr +34e04043cbb12931c47fa44ec50eeffc projections.protostr 7334ba0a4544f0623231330fc51d390d shared_fc.protostr -8b8b6bb128a7dfcc937be86145f53e2f shared_lstm.protostr +bb8e233b05b8e07f9ed386b7aee4f2c6 shared_lstm.protostr 6b39e34beea8dfb782bee9bd3dea9eb5 simple_rnn_layers.protostr +f98e79e1630d5eb827c300e64836d269 test_bi_grumemory.protostr 0fc1409600f1a3301da994ab9d28b0bf test_cost_layers.protostr 6cd5f28a3416344f20120698470e0a4c test_cost_layers_with_weight.protostr 144bc6d3a509de74115fa623741797ed test_expand_layer.protostr @@ -15,7 +16,7 @@ d350bd91a0dc13e854b1364c3d9339c6 test_lstmemory_layer.protostr 5433ed33d4e7414eaf658f2a55946186 test_maxout.protostr 251a948ba41c1071afcd3d9cf9c233f7 test_ntm_layers.protostr e6ff04e70aea27c7b06d808cc49c9497 test_print_layer.protostr -2a75dd33b640c49a8821c2da6e574577 test_rnn_group.protostr +fded24727338fb8ce44d9951ed8aea08 test_rnn_group.protostr 67d6fde3afb54f389d0ce4ff14726fe1 test_sequence_pooling.protostr f586a548ef4350ba1ed47a81859a64cb unused_layers.protostr -8122477f4f65244580cec09edc590041 util_layers.protostr +f937a5a6e7e8864b4d8cf56b0f7c7f44 util_layers.protostr diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh index 4b1d2d3d41..6a31ceabdf 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh @@ -9,7 +9,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers util_layers simple_rnn_layers unused_layers test_cost_layers test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight -test_maxout) +test_maxout test_bi_grumemory) for conf in ${configs[*]} diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py b/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py new file mode 100644 index 0000000000..ab9f7c4948 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py @@ -0,0 +1,10 @@ +from paddle.trainer_config_helpers import * + +settings( + batch_size=1000, + learning_rate=1e-4 +) + +din = data_layer(name='data', size=120) + +outputs(bidirectional_gru(input=din, size=40, return_seq=True)) From 7180b4246d085913390e4a338bc3f1decf77c44d Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 1 Nov 2016 23:55:46 -0500 Subject: [PATCH 11/11] Add some concepts documents to guide user for using paddle (#249) --- demo/quick_start/dataprovider_emb.py | 1 + demo/quick_start/trainer_config.lstm.py | 15 +- doc_cn/concepts/nn.rst | 3 + doc_cn/concepts/program_concepts.rst | 4 + doc_cn/concepts/pserver_topology.dot | 68 +++++++++ doc_cn/concepts/trainer_config.py | 23 +++ doc_cn/concepts/use_concepts.rst | 191 ++++++++++++++++++++++++ doc_cn/faq/index.rst | 10 ++ doc_cn/index.rst | 1 + 9 files changed, 305 insertions(+), 11 deletions(-) create mode 100644 doc_cn/concepts/nn.rst create mode 100644 doc_cn/concepts/program_concepts.rst create mode 100644 doc_cn/concepts/pserver_topology.dot create mode 100644 doc_cn/concepts/trainer_config.py create mode 100644 doc_cn/concepts/use_concepts.rst diff --git a/demo/quick_start/dataprovider_emb.py b/demo/quick_start/dataprovider_emb.py index ca940a89e5..f5632d5f3f 100755 --- a/demo/quick_start/dataprovider_emb.py +++ b/demo/quick_start/dataprovider_emb.py @@ -16,6 +16,7 @@ from paddle.trainer.PyDataProvider2 import * UNK_IDX = 0 + def initializer(settings, dictionary, **kwargs): settings.word_dict = dictionary settings.input_types = [ diff --git a/demo/quick_start/trainer_config.lstm.py b/demo/quick_start/trainer_config.lstm.py index ec8a2cb00a..b412a9cbd9 100644 --- a/demo/quick_start/trainer_config.lstm.py +++ b/demo/quick_start/trainer_config.lstm.py @@ -42,20 +42,13 @@ settings( gradient_clipping_threshold=25 ) -bias_attr = ParamAttr(initial_std=0.,l2_rate=0.) data = data_layer(name="word", size=len(word_dict)) emb = embedding_layer(input=data, size=128) -fc = fc_layer(input=emb, size=512, - act=LinearActivation(), - bias_attr=bias_attr, - layer_attr=ExtraAttr(drop_rate=0.1)) -lstm = lstmemory(input=fc, act=TanhActivation(), - bias_attr=bias_attr, - layer_attr=ExtraAttr(drop_rate=0.25)) -lstm_last = pooling_layer(input=lstm, pooling_type=MaxPooling()) -output = fc_layer(input=lstm_last, size=2, - bias_attr=bias_attr, +lstm = simple_lstm(input=emb, size=128, + lstm_cell_attr=ExtraAttr(drop_rate=0.25)) +lstm_max = pooling_layer(input=lstm, pooling_type=MaxPooling()) +output = fc_layer(input=lstm_max, size=2, act=SoftmaxActivation()) if is_predict: maxid = maxid_layer(output) diff --git a/doc_cn/concepts/nn.rst b/doc_cn/concepts/nn.rst new file mode 100644 index 0000000000..f4d2cf490d --- /dev/null +++ b/doc_cn/concepts/nn.rst @@ -0,0 +1,3 @@ +TBD + +目前正在书写中。敬请期待。 \ No newline at end of file diff --git a/doc_cn/concepts/program_concepts.rst b/doc_cn/concepts/program_concepts.rst new file mode 100644 index 0000000000..af5bbdac26 --- /dev/null +++ b/doc_cn/concepts/program_concepts.rst @@ -0,0 +1,4 @@ +TBD +### + +目前正在书写中。敬请期待。 \ No newline at end of file diff --git a/doc_cn/concepts/pserver_topology.dot b/doc_cn/concepts/pserver_topology.dot new file mode 100644 index 0000000000..9ff658b849 --- /dev/null +++ b/doc_cn/concepts/pserver_topology.dot @@ -0,0 +1,68 @@ +graph pp_topology { + rankdir=BT; + subgraph cluster_node0 { + style=filled; + color=lightgrey; + node [style=filled, color=white, shape=box]; + label = "机器0" + + pserver0 [label="Parameter \n Server 0"] + trainer0 [label="Trainer 0"] + } + subgraph cluster_node1 { + style=filled; + color=lightgrey; + node [style=filled, color=white, shape=box]; + label = "机器1" + + pserver1 [label="Parameter \n Server 1"] + trainer1 [label="Trainer 1"] + } + + subgraph cluster_node2 { + style=filled; + color=lightgrey; + node [style=filled, color=white, shape=box]; + label = "机器2" + + pserver2 [label="Parameter \n Server 2"] + trainer2 [label="Trainer 2"] + } + + subgraph cluster_node3 { + style=filled; + color=lightgrey; + node [style=filled, color=white, shape=box]; + label = "机器3" + + pserver3 [label="Parameter \n Server 3"] + trainer3 [label="Trainer 3"] + } + + data [label="数据", shape=hexagon] + + trainer0 -- pserver0 + trainer0 -- pserver1 + trainer0 -- pserver2 + trainer0 -- pserver3 + + trainer1 -- pserver0 + trainer1 -- pserver1 + trainer1 -- pserver2 + trainer1 -- pserver3 + + trainer2 -- pserver0 + trainer2 -- pserver1 + trainer2 -- pserver2 + trainer2 -- pserver3 + + trainer3 -- pserver0 + trainer3 -- pserver1 + trainer3 -- pserver2 + trainer3 -- pserver3 + + data -- trainer0 + data -- trainer1 + data -- trainer2 + data -- trainer3 +} diff --git a/doc_cn/concepts/trainer_config.py b/doc_cn/concepts/trainer_config.py new file mode 100644 index 0000000000..8d8c79fb39 --- /dev/null +++ b/doc_cn/concepts/trainer_config.py @@ -0,0 +1,23 @@ +from paddle.trainer_config_helpers import * + +define_py_data_sources2(train_list='train.list', + test_list='test.list', + module='provider', + obj='process') +settings( + batch_size=128, + learning_rate=1e-3, + learning_method=AdamOptimizer(), + regularization=L2Regularization(0.5) +) + +img = data_layer(name='pixel', size=28 * 28) + +hidden1 = simple_img_conv_pool(input=img, filter_size=3, num_filters=32, pool_size=3, + num_channel=1) + +hidden2 = fc_layer(input=hidden1, size=200, act=TanhActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) +predict = fc_layer(input=hidden2, size=10, act=SoftmaxActivation()) + +outputs(classification_cost(input=predict, label=data_layer(name='label', size=10))) diff --git a/doc_cn/concepts/use_concepts.rst b/doc_cn/concepts/use_concepts.rst new file mode 100644 index 0000000000..67e98edabc --- /dev/null +++ b/doc_cn/concepts/use_concepts.rst @@ -0,0 +1,191 @@ +######################### +PaddlePaddle 基本使用概念 +######################### + +PaddlePaddle是一个神经网络学习框架。其单机进程为 :code:`paddle train`。 单机的所有设备使用,均在单机进程内调度完成。 而多机辅助进程 :code:`paddle pserver` 负责联合多个单机进程进行通信,进而充分利用集群的计算资源。 PaddlePaddle同时以 :code:`swig api` 的形式,提供训练结果模型预测的方法和自定义训练流程。 + +下面我们会分别介绍主要进程 :code:`paddle train` 中的一些概念。这些概念会对如何使用PaddlePaddle有一定的帮助。 了解这些概念的前提是,读者已经了解 `基本的神经网络/机器学习原理和概念 `_ 。同时,如果想要了解PaddlePaddle实现中的一些概念,请参考 `PaddlePaddle 编程中的基本概念 `_ 。 + +.. contents:: + +PaddlePaddle 的进程模型 +======================= + +PaddlePaddle进程内嵌了一个 :code:`python` 解释器。 这个 :code:`python` 解释器负责解析用户定义的神经网络配置,和解析用户数据,并将用户数据传入给 PaddlePaddle。 + +.. graphviz:: + + digraph pp_process { + rankdir=LR; + config_file [label="用户神经网络配置"]; + subgraph cluster_pp { + style=filled; + color=lightgrey; + node [style=filled, color=white, shape=box]; + label = "PaddlePaddle C++"; + py [label="Python解释器"]; + } + data_provider [label="用户数据解析"]; + config_file -> py; + py -> data_provider [dir="back"]; + } + +所以,PaddlePaddle单机训练进程,:code:`paddle train` , 对于用户的主要接口语言为 python。 主要需要用户配置的两个文件为 :code:`DataProvider` 和训练文件 :code:`TrainerConfig` 。 + + +DataProvider +============ + +DataProvider是 :code:`paddle train` 的数据提供器。 它负责将用户的原始数据转换成 PaddlePaddle 可以识别的数据类型。每当 PaddlePaddle 需要新的数据训练时,都会调用 DataProvider 返回数据。 当所有数据读取完一轮后,DataProvider 便返回空数据通知 PaddlePaddle。PaddlePaddle负责在下一轮训练开始前,将DataProvider重置。 + +需要注意的是,DataProvider在PaddlePaddle中是被训练逻辑调用的关系, 而不是新的数据驱动训练。并且所有的 :code:`shuffle` , 和一些随机化的噪声添加,都应该在 DataProvider 阶段完成。 + +为了方便用户使用自己的数据格式, PaddlePaddle 提供了 `PyDataProvider`_ 来处理数据。 并且在这个Provider中,PaddlePaddle的 C++ 部分接管了如何shuffle,处理 batch,GPU/CPU通信,双缓冲,异步读取等问题。 用户可以参考 `PyDataProvider`_ 的相关文档,继续深入了解 DataProvider 的使用。 + + +训练文件 +======== + +训练文件是PaddlePaddle中配置神经网络结构、学习优化算法、数据传入方式的地方。 训练文件是一个python文件,使用命令行参数 :code:`--config` 传给 paddle 的主程序。 例如\: + +.. code-block:: bash + + paddle train --config=trainer_config.py + +一个典型简单的训练文件可能为 + +.. literalinclude:: trainer_config.py + :linenos: + +下面我们详细的介绍一下训练文件中各个模块的概念。 + + +trainer_config_helpers +---------------------- + +PaddlePaddle的配置文件与PaddlePaddle C++端通信的最基础协议是 :code:`protobuf` 。而为了避免用户直接写比较难写的 protobuf string,我们书写了一个helpers来生成这个protobuf包。所以在文件的开始,import这些helpers函数。 + +需要注意的是,这个 :code:`paddle.trainer_config_helpers` 包是标准的python包,这意味着用户可以选择自己喜欢的 :code:`ide` 或者编辑器来编写Paddle的配置文件,这个python包注释文档比较完善,并且考虑了IDE的代码提示与类型注释。 + +data_sources +------------ + +data_sources是配置神经网络的数据源。这里使用的函数是 :code:`define_py_data_sources2` ,这个函数是定义了使用 `PyDataProvider`_ 作为数据源。 而后缀 :code:`2` 是Paddle历史遗留问题,因为Paddle之前使用的 PyDataProvider 性能较差,所以完全重构了一个新的 `PyDataProvider`_ 。 + +data_sources里面的 train_list 和 test_list 指定的是训练文件列表和测试文件列表。 如果传入一个字符串的话,是指一个训练列表文件。这个训练列表文件中包含的是每一个训练或者测试文件的路径。如果传入一个list的话,则会默认生成一个 list 文件,再传入给 train.list 或者 test.list 。 + +而 :code:`module` 和 :code:`obj` 指定了 DataProvider 的模块名和函数名。 + +更具体的使用,请参考 `PyDataProvider`_ 。 + +settings +-------- + +`settings`_ 是神经网络训练算法相关的设置项。包括学习率,batch_size,优化算法,正则方法等等。具体的使用方法请参考 `settings`_ 文档。 + +网络配置 +-------- + +上述网络配置中余下的部分均是神经网络配置。第一行是定义一个名字叫 "pixel" 的 :code:`data_layer` 。每一个layer返回的都是一个 :code:`LayerOutput` 对象。 这里第一层的输出对象是 :code:`img` 。然后这个对象传输给了另一个 layer 函数, +:code:`simple_img_conv_pool` 。:code:`simple_img_conv_pool` 是一个组合层, +包括了图像的卷积 (convolution) 和池化(pooling), +并继续接了一个全连接层( :code:`fc_layer` ),然后再接了一个Softmax的全连接层。 + +最终,网络配置输出了 :code:`classification_cost` 。标记网络输出的函数为 +:code:`outputs` 。网络的输出是神经网络的优化目标,神经网络训练的时候,实际上就是 +要最小化这个输出。 + +在神经网络进行预测的时候,实际上网络的输出也是通过 :code:`outputs` 标记。 + + +Layer、Projection、Operator +=========================== + +PaddlePaddle的网络基本上是基于Layer来配置的。所谓的Layer即是神经网络的某一层, +而神经网络的某一层,一般是封装了许多复杂操作的操作集合。比如最简单的 +:code:`fc_layer` ,也包括矩阵乘法,多输入的求和,和activation。 + +.. code-block:: python + + data = data_layer(name='data', size=200) + out = fc_layer(input=data, size=200, act=TanhActivation()) + +而对于更灵活配置需求,可能这样基于Layer的配置是不灵活的。于是 PaddlePaddle 提供 +了基于 Projection 或者 Operator 的配置。使用Projection和Operator需要与 +:code:`mixed_layer` 配合使用。 :code:`mixed_layer` 是将layer中的元素累加求和, +并且做一个 :code:`activation` , 而这个layer具体如何计算,是交由内部的Projection +和 Operator 定义。Projection是指含有可学习参数的操作,而Operator不含有可学习的 +参数,输入全是其他Layer的输出。 + + +例如,和 :code:`fc_layer` 同样功能的 :code:`mixed_layer` 。 + +.. code-block:: python + + data = data_layer(name='data', size=200) + with mixed_layer(size=200) as out: + out += full_matrix_projection(input=data) + +PaddlePaddle可以使用的mixed layer 配置出非常复杂的网络,甚至可以直接配置一个完整的LSTM。 +用户可以参考 `mixed_layer`_ 的相关文档进行配置。 + +如何利用单机的所有GPU或所有CPU核心 +================================== + +PaddlePaddle的单机进程 :code:`paddle train` 可以充分利用一台计算机上所有的GPU资 +源或者CPU。 + +如果要使用机器上多块GPU,使用如下命令即可\: + +.. code-block:: bash + + paddle train --use_gpu=true --trainer_count=4 # use 4 gpu card, 0, 1, 2, 3 + +如果要使用机器上多块CPU, 使用如下命令即可\: + +.. code-block:: bash + + paddle train --trainer_config=4 # use 4 cpu cores. + +对于其他设置GPU的选择情况,例如选择第0、2号GPU显卡,则可以使用 :code:`CUDA_VISIBLE_DEVICES` 环境变量来选择部分的显卡。 具体可以参考连接`masking-gpus`_ 。 可以使用的命令为 + +.. code-block:: bash + + env CUDA_VISIBLE_DEVICES=0,2 paddle train --use_gpu=true --trainer_config=2 + +如何利用多台机器的计算资源训练神经网络 +====================================== + +PaddlePaddle多机使用的经典方法是通过 :code:`Parameter Server` 来对多机的 :code:`paddle train` 进行同步。 而多机训练神经网络,首先要讲数据切分到不同的机器上。 切分数据文件的方式在PaddlePaddle的开源实现中并没有提供工具包。 但是切分数据并不是一件非常复杂的事情,也不是神经网络实现的重点。 + +多机训练过程中,经典的拓扑结构如下\: + +.. graphviz:: pserver_topology.dot + +图中每个灰色方块是一台机器,在每个机器中,先去启动一个 :code:`paddle pserver` 进程,并确定整体的端口号。可能的参数是\: + +.. code-block:: bash + + paddle pserver --port=5000 --num_gradient_servers=4 --nics='eth0' + +这里说明系统的 :code:`paddle pserver` 的起始端口是 :code:`5000` ,并且有四个训练进程(:code:`gradient_servers`,Paddle同时将 :code:`paddle train` 进程称作 :code:`GradientServer` 。因为其为负责提供Gradient的进程)。 而对于训练进程的话,则需要在 :code:`paddle pserver` 启动之后,再在各个节点上运行如下命令\: + +.. code-block:: bash + + paddle train --port=5000 --pservers=192.168.100.101,192.168.100.102,192.168.100.103,192.168.100.104 --config=... + +对于简单的多机协同使用上述方式即可。同时,pserver/train 通常在高级情况下,还有两个参数需要设置,他们是 + +* --ports_num\: 一个 pserver进程共绑定多少个端口用来做稠密更新。默认是1 +* --ports_num_for_sparse\: 一个pserver进程共绑定多少端口用来做稀疏更新,默认是0 + +使用手工指定端口数量,是因为Paddle的网络通信中,使用了 :code:`int32` 作为消息长度,比较容易在大模型下溢出。所以,在 :code:`paddle pserver` 进程中可以启动多个子线程去接受 trainer 的数据,这样单个子线程的长度就不会溢出了。但是这个值不可以调的过大,因为增加这个值,还是对性能,尤其是内存占用有一定的开销的,另外稀疏更新的端口如果太大的话,很容易某一个参数服务器没有分配到任何参数。 + +详细的说明可以参考,使用 `集群训练Paddle`_ 。 + + +.. _PyDataProvider: ../ui/data_provider/pydataprovider2.html +.. _settings: ../../doc/ui/api/trainer_config_helpers/optimizers.html#settings +.. _mixed_layer: ../../doc/ui/api/trainer_config_helpers/layers.html#mixed-layer +.. _masking-gpu: http://www.acceleware.com/blog/cudavisibledevices-masking-gpus +.. _集群训练Paddle: ../cluster/index.html diff --git a/doc_cn/faq/index.rst b/doc_cn/faq/index.rst index 283607957c..db28b4436f 100644 --- a/doc_cn/faq/index.rst +++ b/doc_cn/faq/index.rst @@ -166,4 +166,14 @@ PaddlePaddle的参数使用名字 :code:`name` 作为参数的ID,相同名字 这里 :code:`hidden_a` 和 :code:`hidden_b` 使用了同样的parameter和bias。并且softmax层的两个输入也使用了同样的参数 :code:`softmax_param`。 +7. *-cp27mu-linux_x86_64.whl is not a supported wheel on this platform. +----------------------------------------------------------------------- + +出现这个问题的主要原因是,系统编译wheel包的时候,使用的 :code:`wheel` 包是最新的, +而系统中的 :code:`pip` 包比较老。具体的解决方法是,更新 :code:`pip` 包并重新编译PaddlePaddle。 +更新 :code:`pip` 包的方法是\: + +.. code-block:: bash + + pip install --upgrade pip diff --git a/doc_cn/index.rst b/doc_cn/index.rst index 715da44fb4..e74b0942a6 100644 --- a/doc_cn/index.rst +++ b/doc_cn/index.rst @@ -5,6 +5,7 @@ PaddlePaddle文档 -------- * `介绍 `_ * `快速入门 `_ +* `基本使用概念 `_ * `编译与安装 `_ * `用户接口 `_ * `使用示例 `_