diff --git a/demo/image_classification/prediction.py b/demo/image_classification/prediction.py index 9a86aafcb2..49c0ff600c 100755 --- a/demo/image_classification/prediction.py +++ b/demo/image_classification/prediction.py @@ -126,7 +126,7 @@ class ImageClassifier(): # For oversampling, average predictions across crops. # If not, the shape of output[name]: (1, class_number), # the mean is also applicable. - return output[output_layer].mean(0) + return output[output_layer]['value'].mean(0) def predict(self, image=None, output_layer=None): assert isinstance(image, basestring) diff --git a/demo/model_zoo/resnet/classify.py b/demo/model_zoo/resnet/classify.py index 4631816c43..6074cc1d3a 100755 --- a/demo/model_zoo/resnet/classify.py +++ b/demo/model_zoo/resnet/classify.py @@ -156,7 +156,7 @@ class ImageClassifier(): # For oversampling, average predictions across crops. # If not, the shape of output[name]: (1, class_number), # the mean is also applicable. - res[name] = output[name].mean(0) + res[name] = output[name]['value'].mean(0) return res diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index 41beed38a8..a3f4bfffc9 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -38,6 +38,13 @@ Arguments* Arguments::createByPaddleArgumentVector(void* ptr) { return args; } +Arguments* Arguments::createByPaddleArgument(const void* ptr) { + auto p = (paddle::Argument*)(ptr); + auto args = new Arguments(); + args->m->outputs.push_back(*p); + return args; +} + Matrix* Arguments::getSlotValue(size_t idx) const throw(RangeError) { auto& a = m->getArg(idx); return Matrix::createByPaddleMatrixPtr(&a.value); diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp index 66115f8293..538ca2999f 100644 --- a/paddle/api/GradientMachine.cpp +++ b/paddle/api/GradientMachine.cpp @@ -144,12 +144,12 @@ Parameter* GradientMachine::getParameter(size_t i) throw(RangeError) { void GradientMachine::randParameters() { m->machine->randParameters(); } -Matrix* GradientMachine::getLayerOutput(const std::string& layerName) const +Arguments* GradientMachine::getLayerOutput(const std::string& layerName) const throw(UnsupportError) { - auto nn = std::dynamic_pointer_cast(m->machine); + auto nn = m->machine; if (nn) { - auto mat = nn->getLayerOutput(layerName); - return Matrix::createByPaddleMatrixPtr(&mat); + auto arg = nn->getLayerOutput(layerName); + return Arguments::createByPaddleArgument(&arg); } else { throw UnsupportError(); } diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 80c50cdb08..1831b8e170 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -454,6 +454,7 @@ public: private: static Arguments* createByPaddleArgumentVector(void* ptr); + static Arguments* createByPaddleArgument(const void* ptr); void* getInternalArgumentsPtr() const; private: @@ -769,7 +770,7 @@ public: void randParameters(); - Matrix* getLayerOutput(const std::string& layerName) const + Arguments* getLayerOutput(const std::string& layerName) const throw(UnsupportError); /** @@ -956,7 +957,7 @@ public: Arguments* getForwardOutput(); - Matrix* getLayerOutput(const std::string& layerName); + Arguments* getLayerOutput(const std::string& layerName) const; }; /// the N-Best results generated from one input sequence. 
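With this change GradientMachine::getLayerOutput hands back an Arguments object instead of a bare Matrix, and the monkey-patched getLayerOutputs() helper in paddle/py_paddle/util.py (further down in this patch) converts each layer's Arguments through __arguments_to_numpy__, so Python callers now receive a small dict per layer rather than a raw numpy matrix — which is why the two demo scripts above switch to output[name]['value']. A minimal sketch of the new access pattern, assuming a swig_paddle GradientMachine that has already run a forward pass; the layer name and variable names are illustrative, not part of the patch:

    # 'network' is a swig_paddle GradientMachine after forwardTest().
    outputs = network.getLayerOutputs(["__fc_layer_0__"])   # {layer_name: {...}}
    probs = outputs["__fc_layer_0__"]["value"]              # numpy array, shape (num_crops, class_number)
    mean_probs = probs.mean(0)                              # average over oversampled crops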
diff --git a/paddle/api/Trainer.cpp b/paddle/api/Trainer.cpp index d83dc380be..84e4ca054a 100644 --- a/paddle/api/Trainer.cpp +++ b/paddle/api/Trainer.cpp @@ -131,12 +131,11 @@ void Trainer::testOneDataBatch(size_t batchSize, const Arguments& args) { void TrainerPrivate::finishTestPeriod() { tester_->finishTestPeriod(); } void Trainer::finishTestPeriod() { m->finishTestPeriod(); } -Matrix* Trainer::getLayerOutput(const std::string& layerName) { - auto nn = std::dynamic_pointer_cast( - this->m->getGradientMachine()); +Arguments* Trainer::getLayerOutput(const std::string& layerName) const { + auto nn = this->m->getGradientMachine(); CHECK(nn) << "trainerInternal_.getGradientMachine() is not NeuralNetwork"; - auto m = nn->getLayerOutput(layerName); - return Matrix::createByPaddleMatrixPtr(&m); + auto arg = nn->getLayerOutput(layerName); + return Arguments::createByPaddleArgument(&arg); } void Trainer::forwardOneBatch(size_t batchSize) { diff --git a/paddle/gserver/gradientmachines/GradientMachine.h b/paddle/gserver/gradientmachines/GradientMachine.h index 0829968d87..bc2f2f8563 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.h +++ b/paddle/gserver/gradientmachines/GradientMachine.h @@ -134,6 +134,10 @@ public: backward(callback); } + virtual Argument getLayerOutput(const std::string& layerName) { + return *((Argument*)nullptr); + } + // see comment in Layer.h for the function with the same name virtual void resetState() {} diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 80f223824d..123273f916 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -282,6 +282,18 @@ void MultiGradientMachine::forwardBackward(const std::vector& inArgs, backwardImp(callback); } +Argument MultiGradientMachine::getLayerOutput(const std::string& layerName) { + std::vector args; + args.reserve(threads_.size()); + + for (auto& thread : threads_) { + args.push_back(thread->getGradientMachine()->getLayerOutput(layerName)); + } + outLayerArgs_.concat(args, false /* use_gpu */, outArgStream_, passType_); + + return outLayerArgs_; +} + void MultiGradientMachine::backwardImp(const UpdateCallback& callback) { for (size_t i = 0; i < parameters_.size(); i++) { if (!parameters_[i]->useGpu() || parameters_[i]->isStatic()) continue; diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h index 9be15ef4bc..838a52b515 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.h +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h @@ -189,6 +189,8 @@ public: PassType passType, const UpdateCallback& callback); + virtual Argument getLayerOutput(const std::string& layerName); + virtual void onPassEnd(); virtual void finish(); @@ -314,6 +316,8 @@ protected: std::vector outArgs_; hl_stream_t outArgStream_; + Argument outLayerArgs_; + /// ParameterType which needs to be merged from each GPU std::vector mergeTypes_; int numDevices_; /* number of gpu devices */ diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index 273a9111c3..4512aacc81 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -293,11 +293,10 @@ void NeuralNetwork::backward(const UpdateCallback& callback) { } } -MatrixPtr NeuralNetwork::getLayerOutput(const std::string& layerName) { - 
auto it = layerMap_.find(layerName); - CHECK(it != layerMap_.end()) << "Cannot find layer: " << layerName; - return it->second->getOutputValue(); +Argument NeuralNetwork::getLayerOutput(const std::string& layerName) { + return getLayer(layerName)->getOutput(); } + void NeuralNetwork::onPassEnd() { for (auto& layer : layers_) { layer->onPassEnd(); diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.h b/paddle/gserver/gradientmachines/NeuralNetwork.h index 25af4abcf8..e7b6c43840 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.h +++ b/paddle/gserver/gradientmachines/NeuralNetwork.h @@ -87,7 +87,8 @@ public: virtual void backward(const UpdateCallback& callback = nullptr); - MatrixPtr getLayerOutput(const std::string& layerName); + virtual Argument getLayerOutput(const std::string& layerName); + const LayerPtr& getLayer(const std::string& layerName) const { auto it = layerMap_.find(layerName); CHECK(it != layerMap_.end()) << "Unknown layer " << layerName; diff --git a/paddle/gserver/layers/CosSimLayer.cpp b/paddle/gserver/layers/CosSimLayer.cpp index a6c0300acf..57ba124e40 100644 --- a/paddle/gserver/layers/CosSimLayer.cpp +++ b/paddle/gserver/layers/CosSimLayer.cpp @@ -42,7 +42,7 @@ void CosSimLayer::forward(PassType passType) { /* malloc memory for the output_ if necessary */ int batchSize = getInputValue(0)->getHeight(); int size = getSize(); - CHECK_EQ(forward_.size(), 1) << "Only one forward function needed"; + CHECK_EQ(forward_.size(), 1UL) << "Only one forward function needed"; { REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str()); @@ -68,7 +68,7 @@ void CosSimLayer::forward(PassType passType) { void CosSimLayer::backward(const UpdateCallback& callback) { /* activation */ { REGISTER_TIMER_INFO("CosBpAtvTimer", getName().c_str()); - CHECK_EQ(backward_.size(), 1) << "Only one backward function needed"; + CHECK_EQ(backward_.size(), 1UL) << "Only one backward function needed"; const auto outG = this->getOutputGrad(); const auto outV = this->getOutputValue(); diff --git a/paddle/gserver/layers/CosSimVecMatLayer.cpp b/paddle/gserver/layers/CosSimVecMatLayer.cpp index aabafd473a..0f887d8adf 100644 --- a/paddle/gserver/layers/CosSimVecMatLayer.cpp +++ b/paddle/gserver/layers/CosSimVecMatLayer.cpp @@ -112,7 +112,7 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap, void CosSimVecMatLayer::forward(PassType passType) { Layer::forward(passType); - CHECK_EQ(forward_.size(), 1) << "Only one forward function needed"; + CHECK_EQ(forward_.size(), 1UL) << "Only one forward function needed"; MatrixPtr inV0 = getInputValue(0); MatrixPtr inV1 = getInputValue(1); @@ -145,7 +145,7 @@ void CosSimVecMatLayer::forward(PassType passType) { } void CosSimVecMatLayer::backward(const UpdateCallback& callback) { - CHECK_EQ(backward_.size(), 1) << "Only one forward function needed"; + CHECK_EQ(backward_.size(), 1UL) << "Only one forward function needed"; MatrixPtr inV0 = getInputValue(0); MatrixPtr inV1 = getInputValue(1); diff --git a/paddle/math/tests/test_RowBuffer.cpp b/paddle/math/tests/test_RowBuffer.cpp index 5f66f22ef7..8cc4c69a1a 100644 --- a/paddle/math/tests/test_RowBuffer.cpp +++ b/paddle/math/tests/test_RowBuffer.cpp @@ -17,10 +17,10 @@ limitations under the License. 
*/ TEST(RowBuffer, testAutoGrow) { paddle::RowBuffer buf(128); - ASSERT_EQ(128, buf.getWidth()); + ASSERT_EQ(128UL, buf.getWidth()); ASSERT_TRUE(buf.isAutoGrowth()); buf.resize(2); - ASSERT_EQ(2, buf.getRowCount()); + ASSERT_EQ(2UL, buf.getRowCount()); for (size_t i = 0; i < buf.getWidth() * 2; ++i) { buf.data()[i] = i; } @@ -35,7 +35,7 @@ TEST(RowBuffer, testAutoGrow) { data[i] = i; } - ASSERT_EQ(3, buf.getRowCount()); + ASSERT_EQ(3UL, buf.getRowCount()); for (size_t i = 0; i < buf.getRowCount() - 1; ++i) { for (size_t j = 0; j < buf.getWidth(); ++j) { ASSERT_NEAR(i * buf.getWidth() + j, buf.get(i)[j], 1e-5); @@ -51,7 +51,7 @@ TEST(RowBuffer, testWithMemBuf) { std::make_shared(128 * 2 * sizeof(real)); paddle::RowBuffer buf(mem, 128); ASSERT_TRUE(!buf.isAutoGrowth()); - ASSERT_EQ(2, buf.getRowCount()); + ASSERT_EQ(2UL, buf.getRowCount()); for (size_t i = 0; i < buf.getWidth() * 2; ++i) { buf.data()[i] = i; } diff --git a/paddle/py_paddle/util.py b/paddle/py_paddle/util.py index ce105d249a..a708def1d2 100644 --- a/paddle/py_paddle/util.py +++ b/paddle/py_paddle/util.py @@ -208,7 +208,7 @@ def __monkeypatch_gradient_machine__(): output = dict() for name in layerNames: - output[name] = __matrix_to_numpy__(self.getLayerOutput(name)) + output[name] = __arguments_to_numpy__(0, self.getLayerOutput(name)) return output swig_paddle.GradientMachine.getLayerOutputs = getLayerOutputs diff --git a/python/paddle/reader/__init__.py b/python/paddle/reader/__init__.py index 493b410e82..7373dc461b 100644 --- a/python/paddle/reader/__init__.py +++ b/python/paddle/reader/__init__.py @@ -21,3 +21,5 @@ # # r = paddle.reader.buffered(paddle.reader.creator.text("hello.txt")) from decorator import * + +import creator diff --git a/python/paddle/reader/creator.py b/python/paddle/reader/creator.py new file mode 100644 index 0000000000..5a91bb0b8e --- /dev/null +++ b/python/paddle/reader/creator.py @@ -0,0 +1,53 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['np_array', 'text_file'] + + +def np_array(x): + """ + Creates a reader that yields elements of x, if it is a + numpy vector. Or rows of x, if it is a numpy matrix. + Or any sub-hyperplane indexed by the highest dimension. + + :param x: the numpy array to create reader from. + :returns: data reader created from x. + """ + + def reader(): + if x.ndim < 1: + yield x + + for e in x: + yield e + + return reader + + +def text_file(path): + """ + Creates a data reader that outputs text line by line from given text file. + Trailing new line ('\n') of each line will be removed. + + :path: path of the text file. 
+ :returns: data reader of text file + """ + + def reader(): + f = open(path, "r") + for l in f: + yield l.rstrip('\n') + f.close() + + return reader diff --git a/python/paddle/reader/tests/CMakeLists.txt b/python/paddle/reader/tests/CMakeLists.txt index 502c897d89..da072fb3db 100644 --- a/python/paddle/reader/tests/CMakeLists.txt +++ b/python/paddle/reader/tests/CMakeLists.txt @@ -2,3 +2,8 @@ add_test(NAME reader_decorator_test COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/reader/tests/decorator_test.py WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) + +add_test(NAME reader_creator_test + COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ + ${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/reader/tests/creator_test.py + WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) diff --git a/python/paddle/reader/tests/creator_test.py b/python/paddle/reader/tests/creator_test.py new file mode 100644 index 0000000000..eda8ab6715 --- /dev/null +++ b/python/paddle/reader/tests/creator_test.py @@ -0,0 +1,38 @@ +# Copyright PaddlePaddle contributors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import paddle.reader.creator +import numpy as np +import os + + +class TestNumpyArray(unittest.TestCase): + def test_numpy_array(self): + l = [[1, 2, 3], [4, 5, 6]] + x = np.array(l, np.int32) + reader = paddle.reader.creator.np_array(x) + for idx, e in enumerate(reader()): + self.assertItemsEqual(e, l[idx]) + + +class TestTextFile(unittest.TestCase): + def test_text_file(self): + path = os.path.join(os.path.dirname(__file__), "test_data_creator.txt") + reader = paddle.reader.creator.text_file(path) + for idx, e in enumerate(reader()): + self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/reader/tests/test_data_creator.txt b/python/paddle/reader/tests/test_data_creator.txt new file mode 100644 index 0000000000..a2a8d47d43 --- /dev/null +++ b/python/paddle/reader/tests/test_data_creator.txt @@ -0,0 +1,3 @@ +0 1 +2 3 +4 5 diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/paddle/v2/dataset/config.py b/python/paddle/v2/dataset/config.py new file mode 100644 index 0000000000..69e96d65ef --- /dev/null +++ b/python/paddle/v2/dataset/config.py @@ -0,0 +1,8 @@ +import os + +__all__ = ['DATA_HOME'] + +DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set') + +if not os.path.exists(DATA_HOME): + os.makedirs(DATA_HOME) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py new file mode 100644 index 0000000000..db84f37aa4 --- /dev/null +++ b/python/paddle/v2/dataset/mnist.py @@ -0,0 +1,39 @@ +import sklearn.datasets.mldata +import sklearn.model_selection +import numpy +from config import DATA_HOME + +__all__ = ['train_creator', 'test_creator'] + + +def __mnist_reader_creator__(data, 
target): + def reader(): + n_samples = data.shape[0] + for i in xrange(n_samples): + yield (data[i] / 255.0).astype(numpy.float32), int(target[i]) + + return reader + + +TEST_SIZE = 10000 + +data = sklearn.datasets.mldata.fetch_mldata( + "MNIST original", data_home=DATA_HOME) +X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + data.data, data.target, test_size=TEST_SIZE, random_state=0) + + +def train_creator(): + return __mnist_reader_creator__(X_train, y_train) + + +def test_creator(): + return __mnist_reader_creator__(X_test, y_test) + + +def unittest(): + assert len(list(test_creator()())) == TEST_SIZE + + +if __name__ == '__main__': + unittest() diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py index aa2942bc9f..10e255dc94 100644 --- a/python/paddle/v2/optimizer.py +++ b/python/paddle/v2/optimizer.py @@ -3,7 +3,10 @@ import paddle.trainer_config_helpers.optimizers as v1_optimizers import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils import paddle.v2 -__all__ = ['Adam', 'Adamax'] +__all__ = [ + 'Momentum', 'Adam', 'Adamax', 'AdaGrad', 'DecayedAdaGrad', 'AdaDelta', + 'RMSProp', 'ModelAverage', 'L2Regularization' +] class Optimizer(object): @@ -38,6 +41,14 @@ class Optimizer(object): pass_num) +class Momentum(Optimizer): + def __init__(self, momentum=None, sparse=False, **kwargs): + learning_method = v1_optimizers.MomentumOptimizer( + momentum=None, sparse=False) + super(Momentum, self).__init__( + learning_method=learning_method, **kwargs) + + class Adam(Optimizer): def __init__(self, beta1=0.9, beta2=0.999, epsilon=1e-8, **kwargs): learning_method = v1_optimizers.AdamOptimizer( @@ -52,7 +63,45 @@ class Adamax(Optimizer): super(Adamax, self).__init__(learning_method=learning_method, **kwargs) +class AdaGrad(Optimizer): + def __init__(self, **kwargs): + learning_method = v1_optimizers.AdaGradOptimizer() + super(AdaGrad, self).__init__(learning_method=learning_method, **kwargs) + + +class DecayedAdaGrad(Optimizer): + def __init__(self, rho=0.95, epsilon=1e-06, **kwargs): + learning_method = v1_optimizers.DecayedAdaGradOptimizer( + rho=rho, epsilon=epsilon) + super(DecayedAdaGrad, self).__init__( + learning_method=learning_method, **kwargs) + + +class AdaDelta(Optimizer): + def __init__(self, rho=0.95, epsilon=1e-06, **kwargs): + learning_method = v1_optimizers.AdaDeltaOptimizer( + rho=rho, epsilon=epsilon) + super(AdaDelta, self).__init__( + learning_method=learning_method, **kwargs) + + +class RMSProp(Optimizer): + def __init__(self, rho=0.95, epsilon=1e-6, **kwargs): + learning_method = v1_optimizers.RMSPropOptimizer( + rho=rho, epsilon=epsilon) + super(RMSProp, self).__init__(learning_method=learning_method, **kwargs) + + +ModelAverage = v1_optimizers.ModelAverage +L2Regularization = v1_optimizers.L2Regularization + if __name__ == '__main__': swig_api.initPaddle('--use_gpu=false') - opt = paddle.v2.optimizer.Adam() - print opt.enable_types() + for opt in [ + Momentum(), Adam(), Adamax(), AdaGrad(), DecayedAdaGrad(), + AdaDelta(), RMSProp(), Adam( + model_average=ModelAverage(average_window=0.5), + regularization=L2Regularization(rate=0.5), + gradient_clipping_threshold=25) + ]: + print opt, opt.enable_types()
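The rewritten paddle.v2.optimizer module now wraps the full set of v1 learning methods (Momentum, Adam, Adamax, AdaGrad, DecayedAdaGrad, AdaDelta, RMSProp) and re-exports ModelAverage and L2Regularization. A brief usage sketch mirroring the __main__ block above; the hyper-parameter values are illustrative only:

    from py_paddle import swig_paddle
    import paddle.v2.optimizer as optimizer

    swig_paddle.initPaddle("--use_gpu=false")  # required before enable_types(), as in __main__ above

    # Adam combined with L2 weight decay, model averaging and gradient clipping,
    # all built from the wrappers added in this patch.
    opt = optimizer.Adam(
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        model_average=optimizer.ModelAverage(average_window=0.5),
        regularization=optimizer.L2Regularization(rate=0.5),
        gradient_clipping_threshold=25)

    # enable_types() reports which extra parameter slots the resulting updater needs.
    print opt.enable_types()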
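Further up, the patch also introduces paddle.reader.creator (np_array and text_file) and a first paddle.v2.dataset package. Each creator returns a zero-argument function that yields samples one at a time, as the sketch below assumes; the file path is illustrative, and importing the mnist module fetches the data into DATA_HOME (~/.cache/paddle_data_set) via sklearn:

    import numpy as np
    import paddle.reader.creator as creator

    # np_array(): yields the rows of a 2-D numpy array, one per iteration.
    mat = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
    for row in creator.np_array(mat)():
        print row                                  # [1 2 3], then [4 5 6]

    # text_file(): yields each line of the file with the trailing '\n' removed.
    for line in creator.text_file("tests/test_data_creator.txt")():
        print line                                 # "0 1", then "2 3", then "4 5"

    # The new MNIST readers follow the same protocol: each sample is a
    # (784-dim float32 image scaled to [0, 1], integer label) pair.
    import paddle.v2.dataset.mnist as mnist
    image, label = next(mnist.train_creator()())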