From 529f24c262850974dd8ba4c5b7ad1a4e3e0230fc Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 12 Dec 2016 18:17:27 +0800 Subject: [PATCH 01/26] cpu cmrnorm --- paddle/cuda/src/hl_cuda_cnn.cu | 192 +++++++++-------------- paddle/gserver/tests/test_LayerGrad.cpp | 3 +- paddle/math/Matrix.cpp | 137 ++++++++++------ paddle/math/tests/test_matrixCompare.cpp | 115 ++++++++++++++ 4 files changed, 279 insertions(+), 168 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 0992286f36..1516accaae 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -381,57 +381,45 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, CHECK_SYNC("hl_avgpool_backward failed"); } -__global__ void KeCMRNormFillScale(size_t nthreads, const real* in, +__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, real* scale, size_t channels, size_t height, size_t width, size_t size, real alpha) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + in += offset; scale += offset; - size_t head = 0; - size_t pre_pad = (size - 1) / 2; - size_t post_pad = size - pre_pad - 1; - real accum_scale = 0; - // fill the scale at [n, :, h, w] - // accumulate values - while (head < post_pad) { - accum_scale += in[head * step] * in[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_scale += in[head * step] * in[head * step]; - scale[(head - post_pad) 
* step] = 1. + accum_scale * alpha; - ++head; - } - // both add and subtract - while (head < channels) { - accum_scale += in[head * step] * in[head * step]; - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; + const int step = height * width; + const int pre_pad = (size - 1) / 2; + const int post_pad = size - pre_pad - 1; + + real accum = 0; + int index = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += in[index * step] * in[index * step]; + } + if (index >= size) { + accum -= in[(index - size) * step] * in[(index - size) * step]; + } + if (index >= post_pad) { + scale[(index - post_pad) * step] = 1. + accum * alpha; + } + ++index; } } } - __global__ void KeCMRNormOutput(size_t nthreads, const real* in, - const real* scale, real negative_beta, - real* out) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { +__global__ void KeCMRNormOutput(size_t inputSize, const real* in, + const real* scale, real negative_beta, + real* out) { + const int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < inputSize) { out[index] = in[index] * pow(scale[index], negative_beta); } } @@ -440,84 +428,60 @@ void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale, real* out, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - - KeCMRNormFillScale<<>> - (threadsNum, in, scale, channels, height, width, sizeX, alpha); - - threadsNum = frameCnt * height * width *channels; - blocksX = (threadsNum + 1024 -1) / 
1024; - dim3 threads2(1024, 1); - dim3 grid2(blocksX, blocksY); - KeCMRNormOutput<<>> - (threadsNum, in, scale, beta, out); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormFillScale<<>> + (imageSize, in, scale, channels, height, width, sizeX, alpha); + + size_t inputSize = frameCnt * height * width *channels; + blockSize = 1024; + gridSize = (inputSize + 1024 - 1) / 1024; + KeCMRNormOutput<<>> + (inputSize, in, scale, beta, out); CHECK_SYNC("hl_CMRNorm_forward"); } -__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data, +__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, const real* top_data, const real* scale, const real* top_diff, size_t channels, size_t height, size_t width, size_t size, real negative_beta, real cache_ratio, real* bottom_diff ) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; bottom_data += offset; top_data += offset; scale += offset; top_diff += offset; bottom_diff += offset; - int head = 0; - int pre_pad = size - (size + 1) / 2; - int post_pad = size - pre_pad - 1; - real accum_ratio = 0; - // accumulate values - while (head < post_pad) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - bottom_diff[(head - 
post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // both add and subtract - while (head < channels) { - accum_ratio += top_diff[head * step] * top_data[head * step] / - scale[head * step]; - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; + + const int step = height * width; + const int pre_pad = size - (size + 1) / 2; + const int post_pad = size - pre_pad - 1; + + int index = 0; + real accum = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += top_diff[index * step] * top_data[index * step] / + scale[index * step]; + } + if (index >= size) { + accum -= top_diff[(index - size) * step] * + top_data[(index - size) * step] / scale[(index - size) * step]; + } + if (index >= post_pad) { + bottom_diff[(index - post_pad) * step] += + top_diff[(index - post_pad) * step] * + pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(index - post_pad) * step] * accum; + } + ++index; } } } @@ -528,14 +492,12 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, real *inDiff, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * 
width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - KeCMRNormDiff <<>> - (threadsNum, inV, outV, scale, outDiff, channels, - height, width, sizeX, alpha, beta, inDiff); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormDiff <<>> + (imageSize, inV, outV, scale, outDiff, channels, + height, width, sizeX, alpha, beta, inDiff); CHECK_SYNC("hl_CMRNorm_backward"); } diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 7983d9fe64..8ade15daac 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) { testLayerGrad(config, "norm", 100, trans, useGpu); } -#ifndef PADDLE_ONLY_CPU TEST(Layer, NormLayer) { testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true); + testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false); } -#endif void setPoolConfig(TestConfig* config, PoolConfig* pool, diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c69e074a76..2cde11dd47 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2227,52 +2227,43 @@ void CpuMatrix::crossMapNormalFwd(Matrix& input, size_t sizeX, float scale, float pow) { - size_t num = input.getHeight(); + CHECK(isContiguous()); + CHECK(input.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(getHeight(), input.getHeight()); + CHECK_EQ(getWidth(), input.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + + size_t numSample = input.getHeight(); + size_t numCols = input.getWidth(); size_t height = imgSizeH; size_t width = imgSizeW; - size_t numCols = input.getWidth(); - CHECK(height * width * channels == input.getWidth()); - CHECK(denoms.getHeight() == 
input.getHeight() && - denoms.getWidth() == input.getWidth() && input.getHeight() == height_ && - input.getWidth() == width_); - real* imgData = input.getData(); - real* diffData = input.getData(); - real* targetData = getData(); - size_t halfSize = sizeX / 2; - size_t imgPixels = height * width; - - // use integral vector to implement the sum in local window - real* integralData = - (real*)malloc((channels + sizeX + 1) * sizeof(real)); // NOLINT // TODO: - for (size_t i = 0; i <= halfSize; i++) { - integralData[i] = 0; - } - for (size_t i = 0; i < num; i++) { - real* targetPtr = targetData + i * numCols; - real* imgPtr = imgData + i * numCols; - real* diffPtr = diffData + i * numCols; - for (size_t m = 0; m < height; m++) { - for (size_t n = 0; n < width; n++) { - for (size_t c = 0; c < channels; c++) { - integralData[c + halfSize + 1] = - integralData[c + halfSize] + _square(*(diffPtr + c * imgPixels)); - } - for (size_t k = channels + halfSize + 1; k <= channels + sizeX; k++) { - integralData[k] = integralData[channels + halfSize]; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. 
+ CpuMatrix tmp(channels, height * width); + CpuMatrix tmp2(tmp.getData(), 1, channels * height * width); + denoms.zero(); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + input.subMatrix(i, 1)->square2(tmp2); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + for (int c = 0; c < (int)channels; c++) { + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1)); } - for (size_t k = 0; k < channels; k += 1) { - real a = integralData[k + sizeX] - integralData[k]; - a = scale * a + 1; - targetPtr[k * imgPixels] = imgPtr[k * imgPixels] * _pow(a, -pow); - } - diffPtr++; - targetPtr++; - imgPtr++; } } } - free(integralData); - integralData = NULL; + + denoms.add(scale, (real)1); + this->pow2(denoms, -pow); + this->dotMul(input); } void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, @@ -2282,19 +2273,63 @@ void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, size_t channels, size_t imgSizeH, size_t imgSizeW, - size_t size, + size_t sizeX, float scale, float pow) { - LOG(FATAL) << "Not implemented"; - - CHECK(imgSizeH * imgSizeW * channels == preOutV.getWidth()); - CHECK(denoms.getHeight() == preOutV.getHeight() && - denoms.getWidth() == preOutV.getWidth() && - preOutV.getHeight() == height_ && preOutV.getWidth() == width_); - CHECK(denoms.getHeight() == localGrad.getHeight() && - denoms.getWidth() == localGrad.getWidth()); - - // NOLINT // TODO: + CHECK(isContiguous()); + CHECK(localGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(preOutV.isContiguous()); + CHECK(localOutV.isContiguous()); + CHECK_EQ(getHeight(), localGrad.getHeight()); + CHECK_EQ(getWidth(), localGrad.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + CHECK_EQ(getHeight(), preOutV.getHeight()); + CHECK_EQ(getWidth(), preOutV.getWidth()); + 
CHECK_EQ(getHeight(), localOutV.getHeight()); + CHECK_EQ(getWidth(), localOutV.getWidth()); + + size_t numSample = getHeight(); + size_t numCols = getWidth(); + size_t height = imgSizeH; + size_t width = imgSizeW; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. + CpuMatrix tmp(1, height * width); + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + CpuMatrix inputDiff( + this->subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix outDiff( + localGrad.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix input( + preOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix output( + localOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + + for (int c = 0; c < (int)channels; c++) { + tmp.pow2(*subDen.subMatrix(c, 1), -pow); + inputDiff.subMatrix(c, 1) + ->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1); + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1)); + tmp.mulScalar(ratio); + tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1)); + tmp.dotMul(*input.subMatrix(c, 1)); + inputDiff.subMatrix(c, 1)->add(tmp); + } + } + } + } } /** diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 713792d82b..5233a9af40 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1261,6 +1261,121 @@ TEST(Matrix, MaxOutFwdBwd) { } } } +void testCrossMapNormalFwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + int width = imgSizeH * imgSizeW * channels; + 
MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); + + input->randomizeUniform(); + target->randomizeUniform(); + inputGpu->copyFrom(*input); + targetGpu->copyFrom(*target); + + target->crossMapNormalFwd( + *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); + targetGpu->crossMapNormalFwd( + *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(*target, *targetGpu); + TensorCheckErr(*denorms, *denormsGpu); +} + +TEST(Matrix, crossMapNormalFwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalFwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} + +void testCrossMapNormalBwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + size_t width = imgSizeH * imgSizeW * channels; + MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); + + localGrad->randomizeUniform(); + denoms->randomizeUniform(); + preOutV->randomizeUniform(); 
+ localOutV->randomizeUniform(); + output->randomizeUniform(); + denoms->add(0.01); + + MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); + + localGradGpu->copyFrom(*localGrad); + denomsGpu->copyFrom(*denoms); + preOutVGpu->copyFrom(*preOutV); + localOutVGpu->copyFrom(*localOutV); + outputGpu->copyFrom(*output); + + output->crossMapNormalBwd(*localGrad, + *denoms, + *preOutV, + *localOutV, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + outputGpu->crossMapNormalBwd(*localGradGpu, + *denomsGpu, + *preOutVGpu, + *localOutVGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(*output, *outputGpu); +} + +TEST(Matrix, crossMapNormalBwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalBwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); From 95035908b4f47e61bad12d0ed49bf62a1734b2cf Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Dec 2016 14:27:42 +0800 Subject: [PATCH 02/26] add CrossMapNormal --- paddle/math/cross_map_normal_op.cpp | 129 +++++++++++++++++++++ paddle/math/cross_map_normal_op.h | 47 ++++++++ paddle/math/tests/test_matrixCompare.cpp | 137 ++++++++++++----------- 3 files changed, 248 insertions(+), 65 deletions(-) create mode 100644 paddle/math/cross_map_normal_op.cpp 
create mode 100644 paddle/math/cross_map_normal_op.h diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp new file mode 100644 index 0000000000..3eb51b5998 --- /dev/null +++ b/paddle/math/cross_map_normal_op.cpp @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "cross_map_normal_op.h" + +namespace paddle { + +// NCHW +void CrossMapNormal::operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(outputs.isContiguous()); + CHECK(inputs.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(outputs.getHeight(), inputs.getHeight()); + CHECK_EQ(outputs.getWidth(), inputs.getWidth()); + CHECK_EQ(outputs.getHeight(), denoms.getHeight()); + CHECK_EQ(outputs.getWidth(), denoms.getWidth()); + + size_t numSample = inputs.getHeight(); + size_t numCols = inputs.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + denoms = denoms.constant(1.0); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + real* denomsData = denoms.getData() + i * numCols; + real* inputData = inputs.getData() + i * numCols; + for (int c = 0; c < (int)channels; c++) { + CpuVector denom(imageSize, denomsData + c * imageSize); + for (int s = start; s < end; 
s++) { + if (c + s >= 0 && c + s < (int)channels) { + CpuVector input(imageSize, inputData + (c + s) * imageSize); + denom += input.square() * scale; + } + } + } + } + outputs = inputs * denoms.pow(-pow); +} + +void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(inputsGrad.isContiguous()); + CHECK(outputsGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(inputsValue.isContiguous()); + CHECK(outputsValue.isContiguous()); + CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); + + size_t numSample = inputsGrad.getHeight(); + size_t numCols = inputsGrad.getWidth(); + size_t imageSize = imgSizeH * imgSizeW; + CHECK(imageSize * channels == numCols); + + std::function oneImage = [=](real* data, + size_t offset) { + return CpuVector(imageSize, data + offset); + }; + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + size_t sOffset = i * numCols; + real* inputGradData = inputsGrad.getData() + sOffset; + real* inputData = inputsValue.getData() + sOffset; + real* denomData = denoms.getData() + sOffset; + real* outputGradData = outputsGrad.getData() + sOffset; + real* outputData = outputsValue.getData() + sOffset; + + for (int c = 0; c < (int)channels; c++) { + size_t cOffset = c * imageSize; + CpuVector inputGrad = 
oneImage(inputGradData, cOffset); + CpuVector inputValue = oneImage(inputData, cOffset); + CpuVector denom = oneImage(denomData, cOffset); + CpuVector outputGrad = oneImage(outputGradData, cOffset); + + inputGrad = inputGrad + denom.pow(-pow) * outputGrad; + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + size_t offset = (c + s) * imageSize; + CpuVector output = oneImage(outputData, offset); + CpuVector outputGrad = oneImage(outputGradData, offset); + CpuVector denom = oneImage(denomData, offset); + + inputGrad += ((outputGrad * output * ratio) / denom) * inputValue; + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h new file mode 100644 index 0000000000..2f99607252 --- /dev/null +++ b/paddle/math/cross_map_normal_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/math/Matrix.h" + +namespace paddle { + +struct CrossMapNormal { + void operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +struct CrossMapNormalGrad { + void operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow); +}; + +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 5233a9af40..9bb1fdbdab 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/gserver/tests/TestUtil.h" #include "paddle/utils/Stat.h" #include "TensorCheck.h" +#include "paddle/math/cross_map_normal_op.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -1261,30 +1262,32 @@ TEST(Matrix, MaxOutFwdBwd) { } } } + void testCrossMapNormalFwd( int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { float scale = 1.5; float pow = 0.5; int width = imgSizeH * imgSizeW * channels; - MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); - - input->randomizeUniform(); - target->randomizeUniform(); - inputGpu->copyFrom(*input); - targetGpu->copyFrom(*target); - - target->crossMapNormalFwd( - *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); - 
targetGpu->crossMapNormalFwd( - *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); - - TensorCheckErr(*target, *targetGpu); - TensorCheckErr(*denorms, *denormsGpu); + CpuMatrix inputs(numSamples, width); + CpuMatrix denoms(numSamples, width); + CpuMatrix outputs(numSamples, width); + GpuMatrix inputsGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + GpuMatrix outputsGpu(numSamples, width); + + inputs.randomizeUniform(); + outputs.randomizeUniform(); + inputsGpu.copyFrom(inputs); + outputsGpu.copyFrom(outputs); + + CrossMapNormal cross; + cross( + outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); + outputsGpu.crossMapNormalFwd( + inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(outputs, outputsGpu); + TensorCheckErr(denoms, denomsGpu); } TEST(Matrix, crossMapNormalFwd) { @@ -1310,53 +1313,57 @@ void testCrossMapNormalBwd( float scale = 1.5; float pow = 0.5; size_t width = imgSizeH * imgSizeW * channels; - MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); - MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); - - localGrad->randomizeUniform(); - denoms->randomizeUniform(); - preOutV->randomizeUniform(); - localOutV->randomizeUniform(); - output->randomizeUniform(); - denoms->add(0.01); - - MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); - MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); - - localGradGpu->copyFrom(*localGrad); 
- denomsGpu->copyFrom(*denoms); - preOutVGpu->copyFrom(*preOutV); - localOutVGpu->copyFrom(*localOutV); - outputGpu->copyFrom(*output); - output->crossMapNormalBwd(*localGrad, - *denoms, - *preOutV, - *localOutV, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - outputGpu->crossMapNormalBwd(*localGradGpu, - *denomsGpu, - *preOutVGpu, - *localOutVGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - TensorCheckErr(*output, *outputGpu); + CpuMatrix inputsGrad(numSamples, width); + CpuMatrix inputsValue(numSamples, width); + CpuMatrix outputsGrad(numSamples, width); + CpuMatrix outputsValue(numSamples, width); + CpuMatrix denoms(numSamples, width); + + outputsGrad.randomizeUniform(); + denoms.randomizeUniform(); + inputsValue.randomizeUniform(); + outputsValue.randomizeUniform(); + inputsGrad.randomizeUniform(); + denoms.add(0.01); + + GpuMatrix inputsGradGpu(numSamples, width); + GpuMatrix inputsValueGpu(numSamples, width); + GpuMatrix outputsGradGpu(numSamples, width); + GpuMatrix outputsValueGpu(numSamples, width); + GpuMatrix denomsGpu(numSamples, width); + + outputsGradGpu.copyFrom(outputsGrad); + denomsGpu.copyFrom(denoms); + inputsValueGpu.copyFrom(inputsValue); + outputsValueGpu.copyFrom(outputsValue); + inputsGradGpu.copyFrom(inputsGrad); + + CrossMapNormalGrad cross; + cross(inputsGrad, + inputsValue, + outputsGrad, + outputsValue, + denoms, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + inputsGradGpu.crossMapNormalBwd(outputsGradGpu, + denomsGpu, + inputsValueGpu, + outputsValueGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(inputsGrad, inputsGradGpu); } TEST(Matrix, crossMapNormalBwd) { From e357f2715843cd531ce0b0143647ed5561d2fceb Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Dec 2016 17:55:31 +0800 Subject: [PATCH 03/26] add GPU CrossMapNormal --- paddle/math/cross_map_normal_op.cpp | 42 ++--- paddle/math/cross_map_normal_op.h | 37 ++++- 
paddle/math/cross_map_normal_op_gpu.cu | 194 +++++++++++++++++++++++ paddle/math/tests/test_matrixCompare.cpp | 66 +++++--- 4 files changed, 286 insertions(+), 53 deletions(-) create mode 100644 paddle/math/cross_map_normal_op_gpu.cu diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 3eb51b5998..be242926af 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -17,15 +17,16 @@ limitations under the License. */ namespace paddle { // NCHW -void CrossMapNormal::operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { +template <> +void CrossMapNormal::operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { CHECK(outputs.isContiguous()); CHECK(inputs.isContiguous()); CHECK(denoms.isContiguous()); @@ -58,17 +59,18 @@ void CrossMapNormal::operator()(CpuMatrix& outputs, outputs = inputs * denoms.pow(-pow); } -void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { +template <> +void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { CHECK(inputsGrad.isContiguous()); CHECK(outputsGrad.isContiguous()); CHECK(denoms.isContiguous()); diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index 2f99607252..c2bb95f6b1 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -18,10 +18,30 @@ limitations under 
the License. */ namespace paddle { +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0, + DEVICE_TYPE_CPU = 1, + DEVICE_TYPE_GPU = 2, +}; + +template +struct MatrixT; + +template <> +struct MatrixT { + using type = CpuMatrix; +}; + +template <> +struct MatrixT { + using type = GpuMatrix; +}; + +template struct CrossMapNormal { - void operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, + void operator()(typename MatrixT::type& outputs, + typename MatrixT::type& denoms, + typename MatrixT::type& inputs, size_t channels, size_t imgSizeH, size_t imgSizeW, @@ -30,12 +50,13 @@ struct CrossMapNormal { real pow); }; +template struct CrossMapNormalGrad { - void operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, + void operator()(typename MatrixT::type& inputsGrad, + typename MatrixT::type& inputsValue, + typename MatrixT::type& outputsGrad, + typename MatrixT::type& outputsValue, + typename MatrixT::type& denoms, size_t channels, size_t imgSizeH, size_t imgSizeW, diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu new file mode 100644 index 0000000000..0a154d97ac --- /dev/null +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -0,0 +1,194 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "hl_base.h" +#include "cross_map_normal_op.h" + +namespace paddle { + +__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, + real* scale, size_t channels, + size_t height, size_t width, size_t size, + real alpha) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + + in += offset; + scale += offset; + const int step = height * width; + const int pre_pad = (size - 1) / 2; + const int post_pad = size - pre_pad - 1; + + real accum = 0; + int index = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += in[index * step] * in[index * step]; + } + if (index >= size) { + accum -= in[(index - size) * step] * in[(index - size) * step]; + } + if (index >= post_pad) { + scale[(index - post_pad) * step] = 1. + accum * alpha; + } + ++index; + } + } +} + +__global__ void KeCMRNormOutput(size_t inputSize, const real* in, + const real* scale, real negative_beta, + real* out) { + const int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < inputSize) { + out[index] = in[index] * pow(scale[index], negative_beta); + } +} + +template <> +void CrossMapNormal::operator()(GpuMatrix& outputs, + GpuMatrix& denoms, + GpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(outputs.isContiguous()); + CHECK(inputs.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(outputs.getHeight(), inputs.getHeight()); + CHECK_EQ(outputs.getWidth(), inputs.getWidth()); + CHECK_EQ(outputs.getHeight(), denoms.getHeight()); + CHECK_EQ(outputs.getWidth(), denoms.getWidth()); + + size_t numSample = inputs.getHeight(); + size_t numCols = inputs.getWidth(); + CHECK(imgSizeH * imgSizeW * channels == numCols); + + real* inputsData = inputs.getData(); + real* denomsData = 
denoms.getData(); + real* outputsData = outputs.getData(); + + size_t imageSize = numSample * imgSizeH * imgSizeW; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormFillScale<<>> + (imageSize, inputsData, denomsData, + channels, imgSizeH, imgSizeW, sizeX, scale); + + size_t inputSize = numSample * imgSizeH * imgSizeW *channels; + blockSize = 1024; + gridSize = (inputSize + 1024 - 1) / 1024; + KeCMRNormOutput<<>> + (inputSize, inputsData, denomsData, -pow, outputsData); + + CHECK_SYNC("CrossMapNormalFwd"); +} + +__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, + const real* top_data, const real* scale, + const real* top_diff, size_t channels, + size_t height, size_t width, size_t size, + real negative_beta, real cache_ratio, + real* bottom_diff ) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + bottom_data += offset; + top_data += offset; + scale += offset; + top_diff += offset; + bottom_diff += offset; + + const int step = height * width; + const int pre_pad = size - (size + 1) / 2; + const int post_pad = size - pre_pad - 1; + + int index = 0; + real accum = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += top_diff[index * step] * top_data[index * step] / + scale[index * step]; + } + if (index >= size) { + accum -= top_diff[(index - size) * step] * + top_data[(index - size) * step] / scale[(index - size) * step]; + } + if (index >= post_pad) { + bottom_diff[(index - post_pad) * step] += + top_diff[(index - post_pad) * step] * + pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(index - post_pad) * step] * accum; + } + ++index; + } + } +} + +template <> +void CrossMapNormalGrad::operator()(GpuMatrix& inputsGrad, + GpuMatrix& inputsValue, + 
GpuMatrix& outputsGrad, + GpuMatrix& outputsValue, + GpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(inputsGrad.isContiguous()); + CHECK(outputsGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(inputsValue.isContiguous()); + CHECK(outputsValue.isContiguous()); + CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); + + size_t numSample = inputsGrad.getHeight(); + size_t numCols = inputsGrad.getWidth(); + CHECK(imgSizeH * imgSizeW * channels == numCols); + + size_t imageSize = numSample * imgSizeH * imgSizeW; + real* inputsGradData = inputsGrad.getData(); + real* inputsData = inputsValue.getData(); + real* denomsData = denoms.getData(); + real* outputsGradData = outputsGrad.getData(); + real* outputsData = outputsValue.getData(); + + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormDiff <<>> + (imageSize, inputsData, outputsData, denomsData, outputsGradData, channels, + imgSizeH, imgSizeW, sizeX, -pow, 2.0f * pow * scale, inputsGradData); + CHECK_SYNC("KeCMRNormDiff"); +} + +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 9bb1fdbdab..8d7a4fb94d 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1280,11 +1280,25 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); - CrossMapNormal cross; - cross( + CrossMapNormal cpuCross; + cpuCross( outputs, denoms, inputs, channels, 
imgSizeH, imgSizeW, sizeX, scale, pow); + + CrossMapNormal gpuCross; + gpuCross(outputsGpu, + denomsGpu, + inputsGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + +#if 0 outputsGpu.crossMapNormalFwd( inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); +#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); @@ -1339,29 +1353,31 @@ void testCrossMapNormalBwd( outputsValueGpu.copyFrom(outputsValue); inputsGradGpu.copyFrom(inputsGrad); - CrossMapNormalGrad cross; - cross(inputsGrad, - inputsValue, - outputsGrad, - outputsValue, - denoms, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - inputsGradGpu.crossMapNormalBwd(outputsGradGpu, - denomsGpu, - inputsValueGpu, - outputsValueGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); + CrossMapNormalGrad cpuCross; + cpuCross(inputsGrad, + inputsValue, + outputsGrad, + outputsValue, + denoms, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + CrossMapNormalGrad gpuCross; + gpuCross(inputsGradGpu, + inputsValueGpu, + outputsGradGpu, + outputsValueGpu, + denomsGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); TensorCheckErr(inputsGrad, inputsGradGpu); } From a1d2abc16d9c7b42af6dcb41902423ae2904ee9a Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 14 Dec 2016 18:46:40 +0800 Subject: [PATCH 04/26] add Function --- paddle/math/Function.cpp | 47 +++++++++++++ paddle/math/Function.h | 84 ++++++++++++++++++++++++ paddle/math/cross_map_normal_op.cpp | 46 +++++++++++++ paddle/math/cross_map_normal_op.h | 20 +----- paddle/math/tests/test_matrixCompare.cpp | 15 +++-- 5 files changed, 188 insertions(+), 24 deletions(-) create mode 100644 paddle/math/Function.cpp create mode 100644 paddle/math/Function.h diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp new file mode 100644 index 0000000000..21d2719172 --- /dev/null +++ b/paddle/math/Function.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 
PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" + +namespace paddle { + +template <> +size_t FuncConfig::get(const std::string& key) const { + auto it = valueMap_.find(key); + CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'"; + return it->second.s; +} + +template <> +real FuncConfig::get(const std::string& key) const { + auto it = valueMap_.find(key); + CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'"; + return it->second.r; +} + +template <> +void FuncConfig::set(const std::string& key, size_t v) { + CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; + valueMap_[key].s = v; +} + +template <> +void FuncConfig::set(const std::string& key, real v) { + CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; + valueMap_[key].r = v; +} + +ClassRegistrar FunctionBase::funcRegistrar_; + +} // namespace paddle diff --git a/paddle/math/Function.h b/paddle/math/Function.h new file mode 100644 index 0000000000..b41ba2a13d --- /dev/null +++ b/paddle/math/Function.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/utils/ClassRegistrar.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0, + DEVICE_TYPE_CPU = 1, + DEVICE_TYPE_GPU = 2, +}; + +template +struct MatrixT; + +template <> +struct MatrixT { + using type = CpuMatrix; +}; + +template <> +struct MatrixT { + using type = GpuMatrix; +}; + +typedef std::vector Arguments; + +class FuncConfig { +public: + union value { + size_t s; + real r; + }; + + template + T get(const std::string& key) const; + + template + void set(const std::string& key, T v); + +protected: + std::map valueMap_; +}; + +class FunctionBase { +public: + virtual ~FunctionBase() {} + + virtual void init(const FuncConfig& config) {} + + virtual void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) {} + + static ClassRegistrar funcRegistrar_; +}; + +#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName + +#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ + static InitFunction __reg_type_##typeName([]() { \ + FunctionBase::funcRegistrar_ \ + .registerClass>( \ + FUNC_NAME(typeName, deviceName)); \ + }) + +} // namespace paddle diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index be242926af..0b72732063 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -128,4 +128,50 @@ void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, } } +template +class CrossMapNormalFunc : public FunctionBase { +public: + void 
init(const FuncConfig& config) override { + size_ = config.get("size"); + scale_ = config.get("scale"); + pow_ = config.get("pow"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(1, inputs.size()); + CHECK_EQ(2, outputs.size()); + CHECK_EQ(0, inouts.size()); + + auto input = dynamic_cast::type&>(inputs[0]); + auto output = + dynamic_cast::type&>(outputs[0]); + auto denom = + dynamic_cast::type&>(outputs[1]); + + CHECK(input.isContiguous()); + CHECK(output.isContiguous()); + CHECK(denom.isContiguous()); + CHECK_EQ(output.getHeight(), input.getHeight()); + CHECK_EQ(output.getWidth(), input.getWidth()); + CHECK_EQ(output.getHeight(), denom.getHeight()); + CHECK_EQ(output.getWidth(), denom.getWidth()); + + // CrossMapNormal cross; + // need: + // size_t channels, + // size_t imgSizeH, + // size_t imgSizeW, + // cross(output, denom, input, ); + } + +private: + size_t size_; + real scale_; + real pow_; +}; + +REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); + } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index c2bb95f6b1..86f54abde1 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -14,29 +14,11 @@ limitations under the License. 
*/ #pragma once +#include "Function.h" #include "paddle/math/Matrix.h" namespace paddle { -enum DeviceType { - DEVICE_TYPE_UNSPECIFIED = 0, - DEVICE_TYPE_CPU = 1, - DEVICE_TYPE_GPU = 2, -}; - -template -struct MatrixT; - -template <> -struct MatrixT { - using type = CpuMatrix; -}; - -template <> -struct MatrixT { - using type = GpuMatrix; -}; - template struct CrossMapNormal { void operator()(typename MatrixT::type& outputs, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 8d7a4fb94d..0b75785528 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/utils/Stat.h" #include "TensorCheck.h" #include "paddle/math/cross_map_normal_op.h" +#include "paddle/math/Function.h" using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -1280,6 +1281,15 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); + FuncConfig config; + config.set("size", (size_t)sizeX); + config.set("scale", scale); + config.set("pow", pow); + FunctionBase* cpu = + FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); + cpu->init(config); + // cpu->calc(); + CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); @@ -1295,11 +1305,6 @@ void testCrossMapNormalFwd( scale, pow); -#if 0 - outputsGpu.crossMapNormalFwd( - inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); -#endif - TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); } From ce1d98e083017afadac9fcd9f94f5c59aceaf6c0 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 10:31:45 +0800 Subject: [PATCH 05/26] Add a Tensor to use as a Function argument --- paddle/math/Function.h | 12 +++++++- paddle/math/cross_map_normal_op.cpp | 37 +++++++++++------------- paddle/math/tests/test_matrixCompare.cpp | 9 ++++-- 3 
files changed, 35 insertions(+), 23 deletions(-) diff --git a/paddle/math/Function.h b/paddle/math/Function.h index b41ba2a13d..539759782b 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -40,7 +40,17 @@ struct MatrixT { using type = GpuMatrix; }; -typedef std::vector Arguments; +typedef std::vector Dims; + +class Tensor { +public: + Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} + + real* buf_; + Dims dims_; +}; + +typedef std::vector Arguments; class FuncConfig { public: diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 0b72732063..d55bd78c62 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -144,26 +144,23 @@ public: CHECK_EQ(2, outputs.size()); CHECK_EQ(0, inouts.size()); - auto input = dynamic_cast::type&>(inputs[0]); - auto output = - dynamic_cast::type&>(outputs[0]); - auto denom = - dynamic_cast::type&>(outputs[1]); - - CHECK(input.isContiguous()); - CHECK(output.isContiguous()); - CHECK(denom.isContiguous()); - CHECK_EQ(output.getHeight(), input.getHeight()); - CHECK_EQ(output.getWidth(), input.getWidth()); - CHECK_EQ(output.getHeight(), denom.getHeight()); - CHECK_EQ(output.getWidth(), denom.getWidth()); - - // CrossMapNormal cross; - // need: - // size_t channels, - // size_t imgSizeH, - // size_t imgSizeW, - // cross(output, denom, input, ); + CHECK_EQ(inputs[0].dims_.size(), 4); + for (size_t i = 0; i < inputs[0].dims_.size(); i++) { + CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]); + } + + size_t samples = inputs[0].dims_[0]; + size_t channels = inputs[0].dims_[1]; + size_t height = inputs[0].dims_[2]; + size_t width = inputs[0].dims_[3]; + size_t imageSize = channels * height * width; + CpuMatrix input(inputs[0].buf_, samples, imageSize); + CpuMatrix output(outputs[0].buf_, samples, imageSize); + CpuMatrix denom(outputs[1].buf_, samples, imageSize); + + CrossMapNormal cross; 
+ cross(output, denom, input, channels, height, width, size_, scale_, pow_); } private: diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0b75785528..cd34ea18a7 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1288,12 +1288,17 @@ void testCrossMapNormalFwd( FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); cpu->init(config); - // cpu->calc(); + Dims dims{ + (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; + cpu->calc({Tensor(inputs.getData(), dims)}, + {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, + {}); +#if 0 CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); - +#endif CrossMapNormal gpuCross; gpuCross(outputsGpu, denomsGpu, From 4ebb3eb759903bf95968b578eec99b1364d3bd10 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 11:55:35 +0800 Subject: [PATCH 06/26] imporve Function --- paddle/gserver/layers/NormProjectionLayer.cpp | 60 +++++++++++---- paddle/gserver/layers/NormProjectionLayer.h | 4 + paddle/math/Function.cpp | 6 +- paddle/math/Function.h | 14 ++-- paddle/math/cross_map_normal_op.cpp | 75 ++++++++++--------- paddle/math/cross_map_normal_op.h | 13 ++++ paddle/math/cross_map_normal_op_gpu.cu | 46 ++++-------- paddle/math/tests/test_matrixCompare.cpp | 21 +++++- 8 files changed, 147 insertions(+), 92 deletions(-) diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index ea301292e0..5dda7ee205 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" +#include "paddle/math/cross_map_normal_op.h" #include "NormProjectionLayer.h" namespace paddle { @@ -45,6 +46,16 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, /* the size of inputs for norm-layer is 1 */ CHECK_EQ(config_.inputs_size(), 1); + if (useGpu_) { + normal_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormal, GPU)); + } else { + normal_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormal, CPU)); + } + normal_->init( + FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); + return true; } @@ -62,10 +73,14 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - denoms_->zeroMem(); - - outV->crossMapNormalFwd( - *input, imgSizeH_, imgSizeW_, *denoms_, channels_, size_, scale_, pow_); + Dims dims{(size_t)batchSize, + (size_t)channels_, + (size_t)imgSizeH_, + (size_t)imgSizeW_}; + normal_->calc( + {Tensor(input->getData(), dims)}, + {Tensor(outV->getData(), dims), Tensor(denoms_->getData(), dims)}, + {}); } void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { @@ -80,15 +95,32 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - preOutGrad->crossMapNormalBwd(*localGrad, - *denoms_, - *preOutV, - *localOutV, - channels_, - imgSizeH_, - imgSizeW_, - size_, - scale_, - pow_); + if (useGpu_) { + CrossMapNormalGrad crossGrad; + crossGrad(dynamic_cast(*preOutGrad), + dynamic_cast(*preOutV), + dynamic_cast(*localGrad), + dynamic_cast(*localOutV), + dynamic_cast(*denoms_), + channels_, + imgSizeH_, + imgSizeW_, + size_, + scale_, + pow_); + } else { + CrossMapNormalGrad crossGrad; + crossGrad(dynamic_cast(*preOutGrad), + dynamic_cast(*preOutV), + dynamic_cast(*localGrad), + dynamic_cast(*localOutV), + 
dynamic_cast(*denoms_), + channels_, + imgSizeH_, + imgSizeW_, + size_, + scale_, + pow_); + } } } // namespace paddle diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 0db8e2551f..ea44669be3 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "NormLayer.h" #include "paddle/math/Matrix.h" +#include "paddle/math/Function.h" #include namespace paddle { @@ -39,5 +40,8 @@ public: bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void backward(const UpdateCallback& callback = nullptr); + +protected: + FunctionBase* normal_; }; } // namespace paddle diff --git a/paddle/math/Function.cpp b/paddle/math/Function.cpp index 21d2719172..02880e5ea1 100644 --- a/paddle/math/Function.cpp +++ b/paddle/math/Function.cpp @@ -31,15 +31,17 @@ real FuncConfig::get(const std::string& key) const { } template <> -void FuncConfig::set(const std::string& key, size_t v) { +FuncConfig& FuncConfig::set(const std::string& key, size_t v) { CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; valueMap_[key].s = v; + return *this; } template <> -void FuncConfig::set(const std::string& key, real v) { +FuncConfig& FuncConfig::set(const std::string& key, real v) { CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key; valueMap_[key].r = v; + return *this; } ClassRegistrar FunctionBase::funcRegistrar_; diff --git a/paddle/math/Function.h b/paddle/math/Function.h index 539759782b..f8fab972a6 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -46,6 +46,8 @@ class Tensor { public: Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} + real* getData() const { return buf_; } + real* buf_; Dims dims_; }; @@ -63,7 +65,7 @@ public: T get(const std::string& key) const; template - void set(const std::string& key, T v); + FuncConfig& 
set(const std::string& key, T v); protected: std::map valueMap_; @@ -84,11 +86,11 @@ public: #define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName -#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ - static InitFunction __reg_type_##typeName([]() { \ - FunctionBase::funcRegistrar_ \ - .registerClass>( \ - FUNC_NAME(typeName, deviceName)); \ +#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \ + static InitFunction __reg_type_##typeName##deviceName([]() { \ + FunctionBase::funcRegistrar_ \ + .registerClass>( \ + FUNC_NAME(typeName, deviceName)); \ }) } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index d55bd78c62..e520351d2e 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -18,45 +18,41 @@ namespace paddle { // NCHW template <> -void CrossMapNormal::operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(outputs.isContiguous()); - CHECK(inputs.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(outputs.getHeight(), inputs.getHeight()); - CHECK_EQ(outputs.getWidth(), inputs.getWidth()); - CHECK_EQ(outputs.getHeight(), denoms.getHeight()); - CHECK_EQ(outputs.getWidth(), denoms.getWidth()); - - size_t numSample = inputs.getHeight(); - size_t numCols = inputs.getWidth(); - size_t imageSize = imgSizeH * imgSizeW; - CHECK(imageSize * channels == numCols); - - denoms = denoms.constant(1.0); - const int start = -((int)sizeX - 1) / 2; - const int end = (int)sizeX + start; - for (size_t i = 0; i < numSample; i++) { - real* denomsData = denoms.getData() + i * numCols; - real* inputData = inputs.getData() + i * numCols; +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + 
size_t oneImage = height * width; + size_t oneSample = channels * oneImage; + + CpuVector outputsV(numSamples * oneSample, outputs); + CpuVector inputsV(numSamples * oneSample, inputs); + CpuVector denomsV(numSamples * oneSample, denoms); + + denomsV = denomsV.constant(1.0); + const int start = -((int)size - 1) / 2; + const int end = (int)size + start; + for (size_t i = 0; i < numSamples; i++) { + real* oneDenom = denoms + i * oneSample; + real* oneInput = inputs + i * oneSample; for (int c = 0; c < (int)channels; c++) { - CpuVector denom(imageSize, denomsData + c * imageSize); + CpuVector denom(oneImage, oneDenom + c * oneImage); for (int s = start; s < end; s++) { if (c + s >= 0 && c + s < (int)channels) { - CpuVector input(imageSize, inputData + (c + s) * imageSize); + CpuVector input(oneImage, oneInput + (c + s) * oneImage); denom += input.square() * scale; } } } } - outputs = inputs * denoms.pow(-pow); + + outputsV = inputsV * denomsV.pow(-pow); } template <> @@ -154,13 +150,17 @@ public: size_t channels = inputs[0].dims_[1]; size_t height = inputs[0].dims_[2]; size_t width = inputs[0].dims_[3]; - size_t imageSize = channels * height * width; - CpuMatrix input(inputs[0].buf_, samples, imageSize); - CpuMatrix output(outputs[0].buf_, samples, imageSize); - CpuMatrix denom(outputs[1].buf_, samples, imageSize); - CrossMapNormal cross; - cross(output, denom, input, channels, height, width, size_, scale_, pow_); + CrossMapNormal(outputs[0].getData(), + outputs[1].getData(), + inputs[0].getData(), + samples, + channels, + height, + width, + size_, + scale_, + pow_); } private: @@ -170,5 +170,6 @@ private: }; REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); +REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index 86f54abde1..ef9533485e 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -19,6 +19,18 @@ 
limitations under the License. */ namespace paddle { +template +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow); +#if 0 template struct CrossMapNormal { void operator()(typename MatrixT::type& outputs, @@ -31,6 +43,7 @@ struct CrossMapNormal { real scale, real pow); }; +#endif template struct CrossMapNormalGrad { diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu index 0a154d97ac..9b92974344 100644 --- a/paddle/math/cross_map_normal_op_gpu.cu +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -61,45 +61,29 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in, } template <> -void CrossMapNormal::operator()(GpuMatrix& outputs, - GpuMatrix& denoms, - GpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(outputs.isContiguous()); - CHECK(inputs.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(outputs.getHeight(), inputs.getHeight()); - CHECK_EQ(outputs.getWidth(), inputs.getWidth()); - CHECK_EQ(outputs.getHeight(), denoms.getHeight()); - CHECK_EQ(outputs.getWidth(), denoms.getWidth()); - - size_t numSample = inputs.getHeight(); - size_t numCols = inputs.getWidth(); - CHECK(imgSizeH * imgSizeW * channels == numCols); - - real* inputsData = inputs.getData(); - real* denomsData = denoms.getData(); - real* outputsData = outputs.getData(); - - size_t imageSize = numSample * imgSizeH * imgSizeW; +void CrossMapNormal(real* outputs, + real* denoms, + real* inputs, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t imageSize = numSamples * height * width; int blockSize = 1024; int gridSize = (imageSize + 1024 - 1) / 1024; KeCMRNormFillScale<<>> - (imageSize, inputsData, denomsData, - channels, imgSizeH, imgSizeW, sizeX, scale); + 
(imageSize, inputs, denoms, channels, height, width, size, scale); - size_t inputSize = numSample * imgSizeH * imgSizeW *channels; + size_t inputSize = numSamples * height * width *channels; blockSize = 1024; gridSize = (inputSize + 1024 - 1) / 1024; KeCMRNormOutput<<>> - (inputSize, inputsData, denomsData, -pow, outputsData); + (inputSize, inputs, denoms, -pow, outputs); - CHECK_SYNC("CrossMapNormalFwd"); + CHECK_SYNC("CrossMapNormal"); } __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index cd34ea18a7..aac3f75799 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1281,24 +1281,40 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); +#if 0 FuncConfig config; config.set("size", (size_t)sizeX); config.set("scale", scale); config.set("pow", pow); +#endif FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); - cpu->init(config); + FunctionBase* gpu = + FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU)); + cpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + gpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); Dims dims{ (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; cpu->calc({Tensor(inputs.getData(), dims)}, {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, {}); + + gpu->calc( + {Tensor(inputsGpu.getData(), dims)}, + {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, + {}); #if 0 CrossMapNormal cpuCross; cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); -#endif + CrossMapNormal gpuCross; gpuCross(outputsGpu, denomsGpu, @@ -1309,6 +1325,7 @@ void testCrossMapNormalFwd( sizeX, scale, pow); +#endif 
TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); From d2d0010609b6ba621360973b6c6972b836607de3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 16:19:10 +0800 Subject: [PATCH 07/26] add CrossMapNormalGradFunc --- paddle/gserver/layers/NormProjectionLayer.cpp | 41 +++-- paddle/gserver/layers/NormProjectionLayer.h | 7 +- paddle/math/Function.h | 2 +- paddle/math/cross_map_normal_op.cpp | 145 ++++++++++++------ paddle/math/cross_map_normal_op.h | 40 ++--- paddle/math/cross_map_normal_op_gpu.cu | 54 ++----- paddle/math/tests/test_matrixCompare.cpp | 57 ++++--- 7 files changed, 190 insertions(+), 156 deletions(-) diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 03c6952c30..d6923c2192 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "NormProjectionLayer.h" +#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" -#include "paddle/math/cross_map_normal_op.h" -#include "NormProjectionLayer.h" namespace paddle { size_t CMRProjectionNormLayer::getSize() { @@ -48,13 +47,23 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, CHECK_EQ(config_.inputs_size(), 1); if (useGpu_) { - normal_ = FunctionBase::funcRegistrar_.createByType( + forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, GPU)); } else { - normal_ = FunctionBase::funcRegistrar_.createByType( + forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, CPU)); } - normal_->init( + forward_->init( + FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); + + if (useGpu_) { + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, GPU)); + } else { + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, CPU)); + } + backward_->init( FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); return true; @@ -74,13 +83,13 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - Dims dims{(size_t)batchSize, - (size_t)channels_, - (size_t)imgSizeH_, - (size_t)imgSizeW_}; - normal_->calc( - {Tensor(input->getData(), dims)}, - {Tensor(outV->getData(), dims), Tensor(denoms_->getData(), dims)}, + dims_ = {(size_t)batchSize, + (size_t)channels_, + (size_t)imgSizeH_, + (size_t)imgSizeW_}; + forward_->calc( + {Tensor(input->getData(), dims_)}, + {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, {}); } @@ -96,6 +105,13 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); + backward_->calc({Tensor(preOutV->getData(), dims_), 
+ Tensor(localOutV->getData(), dims_), + Tensor(localGrad->getData(), dims_), + Tensor(denoms_->getData(), dims_)}, + {Tensor(preOutGrad->getData(), dims_)}, + {}); +#if 0 if (useGpu_) { CrossMapNormalGrad crossGrad; crossGrad(dynamic_cast(*preOutGrad), @@ -123,5 +139,6 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { scale_, pow_); } +#endif } } // namespace paddle diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 1dc3921283..82aa427f8d 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,9 +16,8 @@ limitations under the License. */ #include #include "NormLayer.h" -#include "paddle/math/Matrix.h" #include "paddle/math/Function.h" -#include +#include "paddle/math/Matrix.h" namespace paddle { @@ -43,6 +42,8 @@ public: void backward(const UpdateCallback& callback = nullptr); protected: - FunctionBase* normal_; + Dims dims_; + FunctionBase* forward_; + FunctionBase* backward_; }; } // namespace paddle diff --git a/paddle/math/Function.h b/paddle/math/Function.h index f8fab972a6..095584c0b1 100644 --- a/paddle/math/Function.h +++ b/paddle/math/Function.h @@ -16,8 +16,8 @@ limitations under the License. */ #include #include -#include "paddle/utils/ClassRegistrar.h" #include "paddle/math/Matrix.h" +#include "paddle/utils/ClassRegistrar.h" namespace paddle { diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index e520351d2e..8547978c99 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "cross_map_normal_op.h" +#include "paddle/math/Vector.h" namespace paddle { @@ -56,66 +57,49 @@ void CrossMapNormal(real* outputs, } template <> -void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(inputsGrad.isContiguous()); - CHECK(outputsGrad.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK(inputsValue.isContiguous()); - CHECK(outputsValue.isContiguous()); - CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); - - size_t numSample = inputsGrad.getHeight(); - size_t numCols = inputsGrad.getWidth(); - size_t imageSize = imgSizeH * imgSizeW; - CHECK(imageSize * channels == numCols); - +void CrossMapNormalGrad(real* inputsGrad, + real* inputsValue, + real* outputsValue, + real* outputsGrad, + real* denoms, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t oneSample = channels * height * width; std::function oneImage = [=](real* data, size_t offset) { - return CpuVector(imageSize, data + offset); + return CpuVector(height * width, data + offset); }; - const int start = -((int)sizeX) / 2; - const int end = (int)sizeX + start; + const int start = -((int)size) / 2; + const int end = (int)size + start; const real ratio = -(real)2 * scale * pow; - for (size_t i = 0; i < numSample; i++) { - size_t sOffset = i * numCols; - real* inputGradData = 
inputsGrad.getData() + sOffset; - real* inputData = inputsValue.getData() + sOffset; - real* denomData = denoms.getData() + sOffset; - real* outputGradData = outputsGrad.getData() + sOffset; - real* outputData = outputsValue.getData() + sOffset; + for (size_t i = 0; i < numSamples; i++) { + size_t sOffset = i * oneSample; + real* oneInputGrad = inputsGrad + sOffset; + real* oneInputValue = inputsValue + sOffset; + real* oneDenom = denoms + sOffset; + real* oneOutputGrad = outputsGrad + sOffset; + real* oneOutputValue = outputsValue + sOffset; for (int c = 0; c < (int)channels; c++) { - size_t cOffset = c * imageSize; - CpuVector inputGrad = oneImage(inputGradData, cOffset); - CpuVector inputValue = oneImage(inputData, cOffset); - CpuVector denom = oneImage(denomData, cOffset); - CpuVector outputGrad = oneImage(outputGradData, cOffset); + size_t cOffset = c * height * width; + CpuVector inputGrad = oneImage(oneInputGrad, cOffset); + CpuVector inputValue = oneImage(oneInputValue, cOffset); + CpuVector denom = oneImage(oneDenom, cOffset); + CpuVector outputGrad = oneImage(oneOutputGrad, cOffset); inputGrad = inputGrad + denom.pow(-pow) * outputGrad; for (int s = start; s < end; s++) { if (c + s >= 0 && c + s < (int)channels) { - size_t offset = (c + s) * imageSize; - CpuVector output = oneImage(outputData, offset); - CpuVector outputGrad = oneImage(outputGradData, offset); - CpuVector denom = oneImage(denomData, offset); + size_t offset = (c + s) * height * width; + CpuVector output = oneImage(oneOutputValue, offset); + CpuVector outputGrad = oneImage(oneOutputGrad, offset); + CpuVector denom = oneImage(oneDenom, offset); inputGrad += ((outputGrad * output * ratio) / denom) * inputValue; } @@ -124,6 +108,11 @@ void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, } } +/** + * \param inputs[0] input value. + * \param outputs[0] output value. + * \param outputs[1] denoms. 
+ */ template class CrossMapNormalFunc : public FunctionBase { public: @@ -169,7 +158,65 @@ private: real pow_; }; +/** + * \param inputs[0] input value. + * \param inputs[1] output value. + * \param inputs[2] output grad. + * \param inputs[3] denoms. + * \param outputs[0] input grad. + */ +template +class CrossMapNormalGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { + size_ = config.get("size"); + scale_ = config.get("scale"); + pow_ = config.get("pow"); + } + + void calc(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) override { + CHECK_EQ(4, inputs.size()); + CHECK_EQ(1, outputs.size()); + CHECK_EQ(0, inouts.size()); + + CHECK_EQ(inputs[0].dims_.size(), 4); + for (size_t i = 0; i < inputs[0].dims_.size(); i++) { + CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]); + CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); + } + + size_t samples = inputs[0].dims_[0]; + size_t channels = inputs[0].dims_[1]; + size_t height = inputs[0].dims_[2]; + size_t width = inputs[0].dims_[3]; + + CrossMapNormalGrad(outputs[0].getData(), + inputs[0].getData(), + inputs[1].getData(), + inputs[2].getData(), + inputs[3].getData(), + samples, + channels, + height, + width, + size_, + scale_, + pow_); + } + +private: + size_t size_; + real scale_; + real pow_; +}; + REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); +REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc); +REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc); } // namespace paddle diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index ef9533485e..f065208084 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -15,7 +15,6 @@ limitations under the License. 
*/ #pragma once #include "Function.h" -#include "paddle/math/Matrix.h" namespace paddle { @@ -30,34 +29,19 @@ void CrossMapNormal(real* outputs, size_t size, real scale, real pow); -#if 0 -template -struct CrossMapNormal { - void operator()(typename MatrixT::type& outputs, - typename MatrixT::type& denoms, - typename MatrixT::type& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow); -}; -#endif template -struct CrossMapNormalGrad { - void operator()(typename MatrixT::type& inputsGrad, - typename MatrixT::type& inputsValue, - typename MatrixT::type& outputsGrad, - typename MatrixT::type& outputsValue, - typename MatrixT::type& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow); -}; +void CrossMapNormalGrad(real* inputsGrad, + real* inputsValue, + real* outputsValue, + real* outputsGrad, + real* denoms, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow); } // namespace paddle diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu index 9b92974344..6339c04194 100644 --- a/paddle/math/cross_map_normal_op_gpu.cu +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -131,48 +131,26 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, } template <> -void CrossMapNormalGrad::operator()(GpuMatrix& inputsGrad, - GpuMatrix& inputsValue, - GpuMatrix& outputsGrad, - GpuMatrix& outputsValue, - GpuMatrix& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { - CHECK(inputsGrad.isContiguous()); - CHECK(outputsGrad.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK(inputsValue.isContiguous()); - CHECK(outputsValue.isContiguous()); - CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), 
denoms.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); - CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); - CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); - - size_t numSample = inputsGrad.getHeight(); - size_t numCols = inputsGrad.getWidth(); - CHECK(imgSizeH * imgSizeW * channels == numCols); - - size_t imageSize = numSample * imgSizeH * imgSizeW; - real* inputsGradData = inputsGrad.getData(); - real* inputsData = inputsValue.getData(); - real* denomsData = denoms.getData(); - real* outputsGradData = outputsGrad.getData(); - real* outputsData = outputsValue.getData(); +void CrossMapNormalGrad(real* inputsGrad, + real* inputsValue, + real* outputsValue, + real* outputsGrad, + real* denoms, + size_t numSamples, + size_t channels, + size_t height, + size_t width, + size_t size, + real scale, + real pow) { + size_t imageSize = numSamples * height * width; int blockSize = 1024; int gridSize = (imageSize + 1024 - 1) / 1024; KeCMRNormDiff <<>> - (imageSize, inputsData, outputsData, denomsData, outputsGradData, channels, - imgSizeH, imgSizeW, sizeX, -pow, 2.0f * pow * scale, inputsGradData); - CHECK_SYNC("KeCMRNormDiff"); + (imageSize, inputsValue, outputsValue, denoms, outputsGrad, channels, + height, width, size, -pow, 2.0f * pow * scale, inputsGrad); + CHECK_SYNC("CrossMapNormalGrad"); } } // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0341d757f3..bc14651457 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -19,12 +19,11 @@ limitations under the License. 
*/ #include #include "TensorCheck.h" #include "paddle/gserver/tests/TestUtil.h" +#include "paddle/math/Function.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" -#include "paddle/utils/Stat.h" -#include "TensorCheck.h" #include "paddle/math/cross_map_normal_op.h" -#include "paddle/math/Function.h" +#include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" using namespace paddle; // NOLINT @@ -1282,12 +1281,6 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); -#if 0 - FuncConfig config; - config.set("size", (size_t)sizeX); - config.set("scale", scale); - config.set("pow", pow); -#endif FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); FunctionBase* gpu = @@ -1311,22 +1304,6 @@ void testCrossMapNormalFwd( {Tensor(inputsGpu.getData(), dims)}, {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, {}); -#if 0 - CrossMapNormal cpuCross; - cpuCross( - outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); - - CrossMapNormal gpuCross; - gpuCross(outputsGpu, - denomsGpu, - inputsGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); -#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); @@ -1381,6 +1358,35 @@ void testCrossMapNormalBwd( outputsValueGpu.copyFrom(outputsValue); inputsGradGpu.copyFrom(inputsGrad); + FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, CPU)); + FunctionBase* gpu = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, GPU)); + cpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + gpu->init(FuncConfig() + .set("size", (size_t)sizeX) + .set("scale", scale) + .set("pow", pow)); + + Dims dims{ + (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; + cpu->calc({Tensor(inputsValue.getData(), dims), + Tensor(outputsValue.getData(), dims), + 
Tensor(outputsGrad.getData(), dims), + Tensor(denoms.getData(), dims)}, + {Tensor(inputsGrad.getData(), dims)}, + {}); + + gpu->calc({Tensor(inputsValueGpu.getData(), dims), + Tensor(outputsValueGpu.getData(), dims), + Tensor(outputsGradGpu.getData(), dims), + Tensor(denomsGpu.getData(), dims)}, + {Tensor(inputsGradGpu.getData(), dims)}, + {}); +#if 0 CrossMapNormalGrad cpuCross; cpuCross(inputsGrad, inputsValue, @@ -1406,6 +1412,7 @@ void testCrossMapNormalBwd( sizeX, scale, pow); +#endif TensorCheckErr(inputsGrad, inputsGradGpu); } From 22a5e478f3b6ecc0e43d31abce39a686b6331165 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 16:36:51 +0800 Subject: [PATCH 08/26] move Function to function dir --- paddle/{math => function}/Function.cpp | 0 paddle/{math => function}/Function.h | 0 paddle/{math => function}/cross_map_normal_op.cpp | 0 paddle/{math => function}/cross_map_normal_op.h | 0 paddle/{math => function}/cross_map_normal_op_gpu.cu | 0 paddle/gserver/layers/NormProjectionLayer.cpp | 1 - paddle/gserver/layers/NormProjectionLayer.h | 2 +- paddle/math/tests/test_matrixCompare.cpp | 3 +-- 8 files changed, 2 insertions(+), 4 deletions(-) rename paddle/{math => function}/Function.cpp (100%) rename paddle/{math => function}/Function.h (100%) rename paddle/{math => function}/cross_map_normal_op.cpp (100%) rename paddle/{math => function}/cross_map_normal_op.h (100%) rename paddle/{math => function}/cross_map_normal_op_gpu.cu (100%) diff --git a/paddle/math/Function.cpp b/paddle/function/Function.cpp similarity index 100% rename from paddle/math/Function.cpp rename to paddle/function/Function.cpp diff --git a/paddle/math/Function.h b/paddle/function/Function.h similarity index 100% rename from paddle/math/Function.h rename to paddle/function/Function.h diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp similarity index 100% rename from paddle/math/cross_map_normal_op.cpp rename to 
paddle/function/cross_map_normal_op.cpp diff --git a/paddle/math/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h similarity index 100% rename from paddle/math/cross_map_normal_op.h rename to paddle/function/cross_map_normal_op.h diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/function/cross_map_normal_op_gpu.cu similarity index 100% rename from paddle/math/cross_map_normal_op_gpu.cu rename to paddle/function/cross_map_normal_op_gpu.cu diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index d6923c2192..e69c406993 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "NormProjectionLayer.h" -#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 82aa427f8d..3c4876ece6 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "NormLayer.h" -#include "paddle/math/Function.h" +#include "paddle/function/Function.h" #include "paddle/math/Matrix.h" namespace paddle { diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index bc14651457..da7a585484 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,11 +18,10 @@ limitations under the License. 
*/ #include #include "TensorCheck.h" +#include "paddle/function/Function.h" #include "paddle/gserver/tests/TestUtil.h" -#include "paddle/math/Function.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" -#include "paddle/math/cross_map_normal_op.h" #include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" From 558e86927caa2bbe0bc97b287f9d1abe73cfaaa3 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 17:12:22 +0800 Subject: [PATCH 09/26] add CMakeLists --- cmake/util.cmake | 1 + paddle/CMakeLists.txt | 1 + paddle/function/CMakeLists.txt | 12 ++++++++++++ paddle/function/cross_map_normal_op.cpp | 4 +++- paddle/gserver/CMakeLists.txt | 8 ++------ 5 files changed, 19 insertions(+), 7 deletions(-) create mode 100644 paddle/function/CMakeLists.txt diff --git a/cmake/util.cmake b/cmake/util.cmake index 38366373c6..03734e7839 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME) target_circle_link_libraries(${TARGET_NAME} ARCHIVE_START paddle_gserver + paddle_function ${METRIC_LIBS} ARCHIVE_END paddle_pserver diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index fb3af8ea92..2daea052b0 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(cuda) +add_subdirectory(function) add_subdirectory(utils) add_subdirectory(math) add_subdirectory(parameter) diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt new file mode 100644 index 0000000000..8fad0e3ebd --- /dev/null +++ b/paddle/function/CMakeLists.txt @@ -0,0 +1,12 @@ +file(GLOB FUNCTION_HEADERS . *.h) + +if(NOT WITH_GPU) + file(GLOB FUNCTION_SOURCES . *.cpp) + add_library(paddle_function STATIC ${FUNCTION_SOURCES}) +else() + file(GLOB FUNCTION_SOURCES . 
*.cpp *.cu) + cuda_add_library(paddle_function ${FUNCTION_SOURCES}) +endif() + +add_style_check_target(paddle_function ${FUNCTION_SOURCES}) +add_style_check_target(paddle_function ${FUNCTION_HEADERS}) diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp index 8547978c99..0391a58d89 100644 --- a/paddle/function/cross_map_normal_op.cpp +++ b/paddle/function/cross_map_normal_op.cpp @@ -215,8 +215,10 @@ private: }; REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); -REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc); +#endif } // namespace paddle diff --git a/paddle/gserver/CMakeLists.txt b/paddle/gserver/CMakeLists.txt index a066f80c22..4f92150ec8 100644 --- a/paddle/gserver/CMakeLists.txt +++ b/paddle/gserver/CMakeLists.txt @@ -27,16 +27,12 @@ if(NOT WITH_GPU) list(REMOVE_ITEM GSERVER_HEADER layers/CudnnConvLayer.h layers/CudnnPoolLayer.h - layers/CudnnBatchNormLayer.h - layers/NormProjectionLayer.h - layers/NormLayer.h) + layers/CudnnBatchNormLayer.h) list(REMOVE_ITEM GSERVER_SOURCES layers/CudnnConvLayer.cpp layers/CudnnPoolLayer.cpp - layers/CudnnBatchNormLayer.cpp - layers/NormProjectionLayer.cpp - layers/NormLayer.cpp) + layers/CudnnBatchNormLayer.cpp) compile_cu_as_cpp(layers/LstmCompute.cu) compile_cu_as_cpp(layers/GruCompute.cu) endif() From d11e2b401348c147b20507863a43b8952f17d6a1 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 17:33:01 +0800 Subject: [PATCH 10/26] Remove some useless code --- paddle/cuda/include/hl_cnn.h | 56 ------ paddle/cuda/include/stub/hl_cnn_stub.h | 24 --- paddle/cuda/src/hl_cuda_cnn.cu | 120 ------------ paddle/gserver/layers/NormProjectionLayer.cpp | 29 --- paddle/math/Matrix.cpp | 176 ------------------ paddle/math/Matrix.h | 65 
------- paddle/math/tests/test_matrixCompare.cpp | 27 --- 7 files changed, 497 deletions(-) diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 06ee3b3654..c5787630ab 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride); -/** - * @brief Cross-map-respose normalize forward. - * - * @param[in] frameCnt batch size of input image. - * @param[in] in input data. - * @param[in] scale buffer. - * @param[out] out output data. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] sizeX size. - * @param[in] alpha scale. - * @param[in] beta scale. - * - */ -extern void hl_CMRNorm_forward(size_t frameCnt, - const real* in, - real* scale, - real* out, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta); - -/** - * @brief Cross-map-respose normalize backward. - * - * @param[in] frameCnt batch size of input image. - * @param[in] inV input data. - * @param[in] scale buffer. - * @param[out] outV output value. - * @param[out] outDiff output grad. - * @param[out] inDiff input grad. - * @param[in] channels number of channel. - * @param[in] height image height. - * @param[in] width image width. - * @param[in] sizeX size. - * @param[in] alpha scale. - * @param[in] beta scale. - * - */ -extern void hl_CMRNorm_backward(size_t frameCnt, - const real* inV, - const real* scale, - const real* outV, - const real* outDiff, - real* inDiff, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta); - /** * @brief Bilinear interpolation forward. 
* diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index 52c9787352..039551c6cc 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt, real* backGrad, const int outStride) {} -inline void hl_CMRNorm_forward(size_t frameCnt, - const real* in, - real* scale, - real* out, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta) {} - -inline void hl_CMRNorm_backward(size_t frameCnt, - const real* inV, - const real* scale, - const real* outV, - const real* outDiff, - real* inDiff, - size_t channels, - size_t height, - size_t width, - size_t sizeX, - real alpha, - real beta) {} - inline void hl_bilinear_forward(const real* inData, const size_t inImgH, const size_t inImgW, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 1516accaae..b94f4d8fe4 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -381,126 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, CHECK_SYNC("hl_avgpool_backward failed"); } -__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, - real* scale, size_t channels, - size_t height, size_t width, size_t size, - real alpha) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < imageSize) { - const int w = idx % width; - const int h = (idx / width) % height; - const int n = idx / width / height; - const int offset = (n * channels * height + h) * width + w; - - in += offset; - scale += offset; - const int step = height * width; - const int pre_pad = (size - 1) / 2; - const int post_pad = size - pre_pad - 1; - - real accum = 0; - int index = 0; - while (index < channels + post_pad) { - if (index < channels) { - accum += in[index * step] * in[index * step]; - } - if (index >= size) { - accum -= in[(index - size) * step] * in[(index - size) * 
step]; - } - if (index >= post_pad) { - scale[(index - post_pad) * step] = 1. + accum * alpha; - } - ++index; - } - } -} - -__global__ void KeCMRNormOutput(size_t inputSize, const real* in, - const real* scale, real negative_beta, - real* out) { - const int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < inputSize) { - out[index] = in[index] * pow(scale[index], negative_beta); - } -} - -void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale, - real* out, size_t channels, - size_t height, size_t width, size_t sizeX, - real alpha, real beta) { - size_t imageSize = frameCnt * height * width; - int blockSize = 1024; - int gridSize = (imageSize + 1024 - 1) / 1024; - KeCMRNormFillScale<<>> - (imageSize, in, scale, channels, height, width, sizeX, alpha); - - size_t inputSize = frameCnt * height * width *channels; - blockSize = 1024; - gridSize = (inputSize + 1024 - 1) / 1024; - KeCMRNormOutput<<>> - (inputSize, in, scale, beta, out); - CHECK_SYNC("hl_CMRNorm_forward"); -} - -__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, - const real* top_data, const real* scale, - const real* top_diff, size_t channels, - size_t height, size_t width, size_t size, - real negative_beta, real cache_ratio, - real* bottom_diff ) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < imageSize) { - const int w = idx % width; - const int h = (idx / width) % height; - const int n = idx / width / height; - const int offset = (n * channels * height + h) * width + w; - bottom_data += offset; - top_data += offset; - scale += offset; - top_diff += offset; - bottom_diff += offset; - - const int step = height * width; - const int pre_pad = size - (size + 1) / 2; - const int post_pad = size - pre_pad - 1; - - int index = 0; - real accum = 0; - while (index < channels + post_pad) { - if (index < channels) { - accum += top_diff[index * step] * top_data[index * step] / - scale[index * step]; - } - if (index >= size) { - accum -= 
top_diff[(index - size) * step] * - top_data[(index - size) * step] / scale[(index - size) * step]; - } - if (index >= post_pad) { - bottom_diff[(index - post_pad) * step] += - top_diff[(index - post_pad) * step] * - pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(index - post_pad) * step] * accum; - } - ++index; - } - } -} - -void hl_CMRNorm_backward(size_t frameCnt, const real* inV, - const real* scale, - const real* outV, const real* outDiff, - real *inDiff, size_t channels, - size_t height, size_t width, size_t sizeX, - real alpha, real beta) { - size_t imageSize = frameCnt * height * width; - int blockSize = 1024; - int gridSize = (imageSize + 1024 - 1) / 1024; - KeCMRNormDiff <<>> - (imageSize, inV, outV, scale, outDiff, channels, - height, width, sizeX, alpha, beta, inDiff); - CHECK_SYNC("hl_CMRNorm_backward"); -} - __global__ void KeBilinearInterpFw(const real* in, const size_t inImgH, const size_t inImgW, diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index e69c406993..4ff3b805fb 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -110,34 +110,5 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { Tensor(denoms_->getData(), dims_)}, {Tensor(preOutGrad->getData(), dims_)}, {}); -#if 0 - if (useGpu_) { - CrossMapNormalGrad crossGrad; - crossGrad(dynamic_cast(*preOutGrad), - dynamic_cast(*preOutV), - dynamic_cast(*localGrad), - dynamic_cast(*localOutV), - dynamic_cast(*denoms_), - channels_, - imgSizeH_, - imgSizeW_, - size_, - scale_, - pow_); - } else { - CrossMapNormalGrad crossGrad; - crossGrad(dynamic_cast(*preOutGrad), - dynamic_cast(*preOutV), - dynamic_cast(*localGrad), - dynamic_cast(*localOutV), - dynamic_cast(*denoms_), - channels_, - imgSizeH_, - imgSizeW_, - size_, - scale_, - pow_); - } -#endif } } // namespace paddle diff --git a/paddle/math/Matrix.cpp 
b/paddle/math/Matrix.cpp index 2cde11dd47..a36c31d32b 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1265,69 +1265,6 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, outGrad.getStride()); } -void GpuMatrix::crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - size_t num = input.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - - CHECK(height * width * channels == input.getWidth()); - CHECK(denoms.getHeight() == input.getHeight() && - denoms.getWidth() == input.getWidth() && input.getHeight() == height_ && - input.getWidth() == width_); - hl_CMRNorm_forward(num, - input.getData(), - denoms.getData(), - data_, - channels, - height, - width, - sizeX, - scale, - -pow); -} - -void GpuMatrix::crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow) { - size_t num = preOutV.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - - CHECK(width * height * channels == preOutV.getWidth()); - CHECK(denoms.getHeight() == preOutV.getHeight() && - denoms.getWidth() == preOutV.getWidth() && - preOutV.getHeight() == height_ && preOutV.getWidth() == width_); - CHECK(denoms.getHeight() == localGrad.getHeight() && - denoms.getWidth() == localGrad.getWidth()); - - hl_CMRNorm_backward(num, - preOutV.getData(), - denoms.getData(), - localOutV.getData(), - localGrad.getData(), - data_, - channels, - height, - width, - sizeX, - -pow, - 2.0f * pow * scale); -} - void GpuMatrix::maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index) { @@ -2219,119 +2156,6 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } } -void CpuMatrix::crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { 
- CHECK(isContiguous()); - CHECK(input.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK_EQ(getHeight(), input.getHeight()); - CHECK_EQ(getWidth(), input.getWidth()); - CHECK_EQ(getHeight(), denoms.getHeight()); - CHECK_EQ(getWidth(), denoms.getWidth()); - - size_t numSample = input.getHeight(); - size_t numCols = input.getWidth(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == numCols); - - // TODO(hedaoyuan) After commit TensorExpress code, - // Reconstruction this code to remove the temporary memory. - CpuMatrix tmp(channels, height * width); - CpuMatrix tmp2(tmp.getData(), 1, channels * height * width); - denoms.zero(); - const int start = -((int)sizeX - 1) / 2; - const int end = (int)sizeX + start; - for (size_t i = 0; i < numSample; i++) { - input.subMatrix(i, 1)->square2(tmp2); - CpuMatrix subDen( - denoms.subMatrix(i, 1)->getData(), channels, height * width); - for (int c = 0; c < (int)channels; c++) { - for (int s = start; s < end; s++) { - if (c + s >= 0 && c + s < (int)channels) { - subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1)); - } - } - } - } - - denoms.add(scale, (real)1); - this->pow2(denoms, -pow); - this->dotMul(input); -} - -void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow) { - CHECK(isContiguous()); - CHECK(localGrad.isContiguous()); - CHECK(denoms.isContiguous()); - CHECK(preOutV.isContiguous()); - CHECK(localOutV.isContiguous()); - CHECK_EQ(getHeight(), localGrad.getHeight()); - CHECK_EQ(getWidth(), localGrad.getWidth()); - CHECK_EQ(getHeight(), denoms.getHeight()); - CHECK_EQ(getWidth(), denoms.getWidth()); - CHECK_EQ(getHeight(), preOutV.getHeight()); - CHECK_EQ(getWidth(), preOutV.getWidth()); - CHECK_EQ(getHeight(), localOutV.getHeight()); - CHECK_EQ(getWidth(), localOutV.getWidth()); - - size_t numSample = getHeight(); - 
size_t numCols = getWidth(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == numCols); - - // TODO(hedaoyuan) After commit TensorExpress code, - // Reconstruction this code to remove the temporary memory. - CpuMatrix tmp(1, height * width); - - const int start = -((int)sizeX) / 2; - const int end = (int)sizeX + start; - const real ratio = -(real)2 * scale * pow; - for (size_t i = 0; i < numSample; i++) { - CpuMatrix inputDiff( - this->subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix outDiff( - localGrad.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix input( - preOutV.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix output( - localOutV.subMatrix(i, 1)->getData(), channels, height * width); - CpuMatrix subDen( - denoms.subMatrix(i, 1)->getData(), channels, height * width); - - for (int c = 0; c < (int)channels; c++) { - tmp.pow2(*subDen.subMatrix(c, 1), -pow); - inputDiff.subMatrix(c, 1) - ->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1); - for (int s = start; s < end; s++) { - if (c + s >= 0 && c + s < (int)channels) { - tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1)); - tmp.mulScalar(ratio); - tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1)); - tmp.dotMul(*input.subMatrix(c, 1)); - inputDiff.subMatrix(c, 1)->add(tmp); - } - } - } - } -} - /** * Input: one or more sequences. Each sequence contains some instances. * Output: output size is the number of input sequences (NOT input instances). diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 5685cb7bcb..62bc1b16fc 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -952,31 +952,6 @@ public: LOG(FATAL) << "Not implemeted"; } - /// normalize-operation. 
- virtual void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow) { - LOG(FATAL) << "Not implemeted"; - } - - virtual void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t size, - float scale, - float pow) { - LOG(FATAL) << "Not implemeted"; - } - /** * Input: one or more sequences. Each sequence contains some instances. * @@ -1459,26 +1434,6 @@ public: size_t paddingH, size_t paddingW); - void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow); - - void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow); - void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); @@ -1685,26 +1640,6 @@ public: size_t paddingH, size_t paddingW); - void crossMapNormalFwd(Matrix& input, - size_t imgSizeH, - size_t imgSizeW, - Matrix& denoms, - size_t channels, - size_t sizeX, - float scale, - float pow); - - void crossMapNormalBwd(Matrix& localGrad, - Matrix& denoms, - Matrix& preOutV, - Matrix& localOutV, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - float scale, - float pow); - void maxSequenceForward(Matrix& input, const IVector& sequence, IVector& index); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index da7a585484..c89b7ff490 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1385,33 +1385,6 @@ void testCrossMapNormalBwd( Tensor(denomsGpu.getData(), dims)}, {Tensor(inputsGradGpu.getData(), dims)}, {}); -#if 0 - CrossMapNormalGrad cpuCross; - cpuCross(inputsGrad, - 
inputsValue, - outputsGrad, - outputsValue, - denoms, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - CrossMapNormalGrad gpuCross; - gpuCross(inputsGradGpu, - inputsValueGpu, - outputsGradGpu, - outputsValueGpu, - denomsGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); -#endif TensorCheckErr(inputsGrad, inputsGradGpu); } From f13aeb52e9fc666ac1e24acf5315cbdccf108402 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 20:12:53 +0800 Subject: [PATCH 11/26] fix swig_api --- paddle/api/CMakeLists.txt | 1 + paddle/api/paddle_ld_flags.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index 6ad1d79e59..ed69bd764f 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp WORKING_DIRECTORY ${PROJ_ROOT}/paddle DEPENDS python_swig_sources paddle_parameter + paddle_function paddle_math paddle_utils paddle_gserver diff --git a/paddle/api/paddle_ld_flags.py b/paddle/api/paddle_ld_flags.py index 51d7dfee58..7c8206e3fe 100644 --- a/paddle/api/paddle_ld_flags.py +++ b/paddle/api/paddle_ld_flags.py @@ -30,8 +30,8 @@ try: whole_end = "" LIB_DIRS = [ - "math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver", - "trainer" + "math", 'function', 'utils', 'parameter', "gserver", "api", "cuda", + "pserver", "trainer" ] PARENT_LIB_DIRS = ['proto'] @@ -75,6 +75,7 @@ try: libs = [ whole_start, "-lpaddle_gserver", + "-lpaddle_function", whole_end, "-lpaddle_pserver", "-lpaddle_trainer_lib", From 22b9b6662b215b663ce2cebdf7624ea1212bb9c1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 15 Dec 2016 21:01:47 +0800 Subject: [PATCH 12/26] Add unittest to coverage SgdThreadUpdater's enableBufType --- paddle/trainer/ThreadParameterUpdater.cpp | 3 +++ paddle/trainer/tests/fake_file_list.list | 1 + .../tests/simple_sparse_neural_network.py | 23 +++++++++++++++++++ 
.../tests/simple_sparse_neural_network_dp.py | 21 +++++++++++++++++ paddle/trainer/tests/test_TrainerOnePass.cpp | 9 +++++++- 5 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 paddle/trainer/tests/fake_file_list.list create mode 100644 paddle/trainer/tests/simple_sparse_neural_network.py create mode 100644 paddle/trainer/tests/simple_sparse_neural_network_dp.py diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp index 9caa92a4d7..049022b1f1 100644 --- a/paddle/trainer/ThreadParameterUpdater.cpp +++ b/paddle/trainer/ThreadParameterUpdater.cpp @@ -55,6 +55,9 @@ void SgdThreadUpdater::init(std::vector& parameters) { // not create parameter buf for PARAMETER_GRADIENT for sparse update in // Parameter::enableType(). But gradient parameter buf is still used // in SgdThreadUpdater. We need to explicitly create it. + // + // The AverageOptimizer::restore/apply method will use PARAMETER_GRADIENT + // as a temp buffer. para->enableBufType(PARAMETER_GRADIENT); } } diff --git a/paddle/trainer/tests/fake_file_list.list b/paddle/trainer/tests/fake_file_list.list new file mode 100644 index 0000000000..f27ceed277 --- /dev/null +++ b/paddle/trainer/tests/fake_file_list.list @@ -0,0 +1 @@ +do_not_matter.txt diff --git a/paddle/trainer/tests/simple_sparse_neural_network.py b/paddle/trainer/tests/simple_sparse_neural_network.py new file mode 100644 index 0000000000..9604e1b9b4 --- /dev/null +++ b/paddle/trainer/tests/simple_sparse_neural_network.py @@ -0,0 +1,23 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=128, learning_method=AdaGradOptimizer(), learning_rate=1e-4) + +file_list = 'trainer/tests/fake_file_list.list' + +define_py_data_sources2( + train_list=file_list, + test_list=file_list, + module="simple_sparse_neural_network_dp", + obj="process") + +embedding = embedding_layer( + input=data_layer( + name="word_ids", size=65536), + size=128, + param_attr=ParamAttr(sparse_update=True)) 
+prediction = fc_layer(input=embedding, size=10, act=SoftmaxActivation()) + +outputs( + classification_cost( + input=prediction, label=data_layer( + name='label', size=10))) diff --git a/paddle/trainer/tests/simple_sparse_neural_network_dp.py b/paddle/trainer/tests/simple_sparse_neural_network_dp.py new file mode 100644 index 0000000000..8bfd1f37e7 --- /dev/null +++ b/paddle/trainer/tests/simple_sparse_neural_network_dp.py @@ -0,0 +1,21 @@ +from paddle.trainer.PyDataProvider2 import provider, integer_sequence, integer_value +import random + + +def init_hook(settings, is_train, **kwargs): + settings.is_train = is_train + + +@provider( + input_types={'word_ids': integer_value(65536), + 'label': integer_value(10)}, + min_pool_size=0, + init_hook=init_hook) +def process(settings, filename): + if settings.is_train: + data_size = 2**20 + else: + data_size = 2**10 + + for _ in xrange(data_size): + yield random.randint(0, 65535), random.randint(0, 9) diff --git a/paddle/trainer/tests/test_TrainerOnePass.cpp b/paddle/trainer/tests/test_TrainerOnePass.cpp index ee21008aec..4d0174f784 100644 --- a/paddle/trainer/tests/test_TrainerOnePass.cpp +++ b/paddle/trainer/tests/test_TrainerOnePass.cpp @@ -27,6 +27,9 @@ static const string& configFile1 = "trainer/tests/sample_trainer_config.conf"; static const string& configFile2 = "trainer/tests/sample_trainer_config_parallel.conf"; +static const string& configFileSimpleSparse = + "trainer/tests/simple_sparse_neural_network.py"; + DECLARE_bool(use_gpu); DECLARE_string(config); DECLARE_int32(gpu_id); @@ -298,11 +301,15 @@ TEST(checkRemoteUpdater, cpuDeltaTrainerOldUpdater) { checkRemoteParameterUpdaterTest(configFile1, false, false, 1, true, 10); } +TEST(SgdThreadUpdater, simpleSparseNN) { + trainerOnePassTest(configFileSimpleSparse, false, false, 1, 0.5, true); +} + int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); initMain(argc, argv); initPython(argc, argv); gNumDevices = hl_get_device_count(); - 
testing::InitGoogleTest(&argc, argv); FLAGS_num_passes = 1; // train one pass FLAGS_saving_period = 100000; // do not save parameteres From cee934680467c50d4084dbaf7273a39a40cc832d Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 15 Dec 2016 21:23:05 +0800 Subject: [PATCH 13/26] add some comments --- paddle/function/cross_map_normal_op.cpp | 5 ++- paddle/function/cross_map_normal_op.h | 34 +++++++++++++++++++ paddle/gserver/layers/Layer.h | 6 ++++ paddle/gserver/layers/NormProjectionLayer.cpp | 18 ++++------ paddle/gserver/layers/NormProjectionLayer.h | 3 -- 5 files changed, 50 insertions(+), 16 deletions(-) diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp index 0391a58d89..a18c0bb750 100644 --- a/paddle/function/cross_map_normal_op.cpp +++ b/paddle/function/cross_map_normal_op.cpp @@ -17,7 +17,6 @@ limitations under the License. */ namespace paddle { -// NCHW template <> void CrossMapNormal(real* outputs, real* denoms, @@ -36,6 +35,10 @@ void CrossMapNormal(real* outputs, CpuVector inputsV(numSamples * oneSample, inputs); CpuVector denomsV(numSamples * oneSample, denoms); + // f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow) + // x represents inputs + // f(x) represents outputs + // denoms save the intermediate result for backward denomsV = denomsV.constant(1.0); const int start = -((int)size - 1) / 2; const int end = (int)size + start; diff --git a/paddle/function/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h index f065208084..e935b26e12 100644 --- a/paddle/function/cross_map_normal_op.h +++ b/paddle/function/cross_map_normal_op.h @@ -18,6 +18,22 @@ limitations under the License. */ namespace paddle { +/** + * \brief Cross map respose normalize forward. + * The data structure of image data is NCHW. + * + * \param[out] outputs output data. + * \param[in] denoms denoms buffer. + * \param[in] inputs input data. + * \param[in] numSamples batch size of input image. + * \param[in] channels number of channel. 
+ * \param[in] height image height. + * \param[in] width image width. + * \param[in] size size. + * \param[in] scale scale. + * \param[in] pow scale. + * + */ template void CrossMapNormal(real* outputs, real* denoms, @@ -30,6 +46,24 @@ void CrossMapNormal(real* outputs, real scale, real pow); +/** + * \brief Cross map respose normalize backward. + * The data structure of image data is NCHW. + * + * \param[out] inputsGrad input grad. + * \param[in] inputsValue input value. + * \param[out] outputsValue output value. + * \param[out] outputsGrad output grad. + * \param[in] denoms denoms buffer. + * \param[in] numSamples batch size of input image. + * \param[in] channels number of channel. + * \param[in] height image height. + * \param[in] width image width. + * \param[in] size size. + * \param[in] scale scale. + * \param[in] pow scale. + * + */ template void CrossMapNormalGrad(real* inputsGrad, real* inputsValue, diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index 172e558b82..16f66a2205 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include "ModelConfig.pb.h" +#include "paddle/function/Function.h" #include "paddle/math/CpuSparseMatrix.h" #include "paddle/parameter/Parameter.h" #include "paddle/utils/ClassRegistrar.h" @@ -100,6 +101,11 @@ protected: /// Mark input grad in(true) or out(false) of backward function. std::vector markInBackward_; + /// Layer forward function + FunctionBase* forward_; + /// Layer backward function + FunctionBase* backward_; + public: /** * Wait until all input value ready. 
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 4ff3b805fb..0f6f9b91d0 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -48,20 +48,17 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, if (useGpu_) { forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, GPU)); + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, GPU)); } else { forward_ = FunctionBase::funcRegistrar_.createByType( FUNC_NAME(CrossMapNormal, CPU)); + backward_ = FunctionBase::funcRegistrar_.createByType( + FUNC_NAME(CrossMapNormalGrad, CPU)); } forward_->init( FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); - if (useGpu_) { - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - } else { - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - } backward_->init( FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); @@ -74,7 +71,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { /* malloc memory for the output_ if necessary */ /* note: one sample correspond to one row */ MatrixPtr input = inputLayers_[0]->getOutputValue(); - int batchSize = input->getHeight(); + size_t batchSize = input->getHeight(); int size = getSize(); resetOutput(batchSize, size); @@ -82,10 +79,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - dims_ = {(size_t)batchSize, - (size_t)channels_, - (size_t)imgSizeH_, - (size_t)imgSizeW_}; + dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; forward_->calc( {Tensor(input->getData(), dims_)}, {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, diff --git a/paddle/gserver/layers/NormProjectionLayer.h 
b/paddle/gserver/layers/NormProjectionLayer.h index 3c4876ece6..6b2c5dde0d 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "NormLayer.h" -#include "paddle/function/Function.h" #include "paddle/math/Matrix.h" namespace paddle { @@ -43,7 +42,5 @@ public: protected: Dims dims_; - FunctionBase* forward_; - FunctionBase* backward_; }; } // namespace paddle From 96eab5046a14ff901b8685d6d7adf3d9ea4a8c5c Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 16 Dec 2016 10:15:36 +0800 Subject: [PATCH 14/26] Complete unittest setting in CMake --- paddle/trainer/tests/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/trainer/tests/CMakeLists.txt b/paddle/trainer/tests/CMakeLists.txt index 60c129f4e2..28c3d6f263 100644 --- a/paddle/trainer/tests/CMakeLists.txt +++ b/paddle/trainer/tests/CMakeLists.txt @@ -27,7 +27,8 @@ add_test(NAME test_Trainer add_unittest_without_exec(test_TrainerOnePass test_TrainerOnePass.cpp) add_test(NAME test_TrainerOnePass - COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ + COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d + ${PROJ_ROOT}/python/:${PROJ_ROOT}/paddle/trainer/tests ${PROJ_ROOT}/paddle/.set_port.sh -p port ${CMAKE_CURRENT_BINARY_DIR}/test_TrainerOnePass WORKING_DIRECTORY ${PROJ_ROOT}/paddle/) From 5222b586e2db3a4dc46cacf884afae9e4d6e51f2 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Fri, 16 Dec 2016 15:43:40 +0800 Subject: [PATCH 15/26] support UBUNTU MIRROR and modify doc --- .../build_and_install/docker_install_en.rst | 16 ++++++++++++++++ paddle/scripts/docker/Dockerfile | 2 ++ paddle/scripts/docker/Dockerfile.gpu | 2 ++ 3 files changed, 20 insertions(+) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 7633bf4d57..1252ff3974 100644 --- 
a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -142,6 +142,22 @@ to install CUDA driver and let Docker knows about it: export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:gpu-latest + +UBUNTU MIRROR +------------- + +Building Paddle Docker image hits some wrong with apt-get update, you +can use other UBUNTU MIRROR instead of the default + +.. code-block:: bash + + cd ~ + git clone https://github.com/PaddlePaddle/Paddle.git + cd Paddle + git submodule update --init --recursive + docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:cpu-avx -f paddle/scripts/docker/Dockerfile . + docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:gpu-avx -f paddle/scripts/docker/Dockerfile.gpu . + Non-AVX Images -------------- diff --git a/paddle/scripts/docker/Dockerfile b/paddle/scripts/docker/Dockerfile index 207f97c4a6..f26055d0d4 100644 --- a/paddle/scripts/docker/Dockerfile +++ b/paddle/scripts/docker/Dockerfile @@ -2,6 +2,8 @@ FROM ubuntu:14.04 MAINTAINER PaddlePaddle Authors ARG DEBIAN_FRONTEND=noninteractive +ARG UBUNTU_MIRROR +RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi' RUN apt-get update \ && apt-get install -y cmake libprotobuf-dev protobuf-compiler git \ libgoogle-glog-dev libgflags-dev libgtest-dev \ diff --git a/paddle/scripts/docker/Dockerfile.gpu b/paddle/scripts/docker/Dockerfile.gpu index 33f6adfea2..d13b977147 100644 --- a/paddle/scripts/docker/Dockerfile.gpu +++ b/paddle/scripts/docker/Dockerfile.gpu @@ -2,6 +2,8 @@ FROM nvidia/cuda:7.5-cudnn5-devel-ubuntu14.04 MAINTAINER PaddlePaddle Authors ARG DEBIAN_FRONTEND=noninteractive +ARG UBUNTU_MIRROR +RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi' RUN 
apt-get update \ && apt-get install -y cmake libprotobuf-dev protobuf-compiler git \ libgoogle-glog-dev libgflags-dev libgtest-dev \ From 5b746fb183572bc04a0697f3ef9d043849506862 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Fri, 16 Dec 2016 17:23:24 +0800 Subject: [PATCH 16/26] modify doc doc/getstarted/build_and_install/docker_install_en.rst --- .../build_and_install/docker_install_en.rst | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 1252ff3974..ffda796470 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -39,12 +39,18 @@ The general development workflow with Docker and Bazel is as follows: code. This image contains all the development tools and dependencies of PaddlePaddle. - .. code-block:: bash cd paddle docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile . + Apt-get source errors may occur when building paddle docker image. + **You can specify the UBUNTU MIRROR with :code:`--build-arg UBUNTU_MIRROR` like the example below.** + + .. code-block:: bash + + docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:dev -f paddle/scripts/docker/Dockerfile . + 3. Run the image as a container and mounting local source code directory into the container. This allows us to change the code on @@ -142,22 +148,6 @@ to install CUDA driver and let Docker knows about it: export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:gpu-latest - -UBUNTU MIRROR -------------- - -Building Paddle Docker image hits some wrong with apt-get update, you -can use other UBUNTU MIRROR instead of the default - -.. 
code-block:: bash - - cd ~ - git clone https://github.com/PaddlePaddle/Paddle.git - cd Paddle - git submodule update --init --recursive - docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:cpu-avx -f paddle/scripts/docker/Dockerfile . - docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:gpu-avx -f paddle/scripts/docker/Dockerfile.gpu . - Non-AVX Images -------------- From 36af605a2d13f7be0a8d326144b88d7d2ed5d242 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Fri, 16 Dec 2016 17:33:14 +0800 Subject: [PATCH 17/26] modify doc --- doc/getstarted/build_and_install/docker_install_en.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index ffda796470..1cc23ac3aa 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -45,11 +45,14 @@ The general development workflow with Docker and Bazel is as follows: docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile . Apt-get source errors may occur when building paddle docker image. - **You can specify the UBUNTU MIRROR with :code:`--build-arg UBUNTU_MIRROR` like the example below.** + **You can specify the UBUNTU MIRROR with** :code:`--build-arg UBUNTU_MIRROR` **like the example below.** .. code-block:: bash - docker build --build-arg UBUNTU_MIRROR="http://mirrors.163.com" -t paddle:dev -f paddle/scripts/docker/Dockerfile . + docker build \ + --build-arg UBUNTU_MIRROR="http://mirrors.163.com" \ + -t paddle:dev \ + -f paddle/scripts/docker/Dockerfile . 3. 
Run the image as a container and mounting local source code From 148bd4d0b3240d31c1c96ddac89ffd4935f71b03 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 19 Dec 2016 15:04:48 +0800 Subject: [PATCH 18/26] add Layer::createFunction --- paddle/gserver/layers/Layer.h | 24 +++++++++++-- paddle/gserver/layers/NormProjectionLayer.cpp | 34 +++++++------------ 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index 16f66a2205..6dfd48fb96 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -102,9 +102,9 @@ protected: std::vector markInBackward_; /// Layer forward function - FunctionBase* forward_; + std::vector> forward_; /// Layer backward function - FunctionBase* backward_; + std::vector> backward_; public: /** @@ -132,6 +132,26 @@ public: virtual void markAllInputGrad(); protected: + /** + * Create layer function. Function is called in forward or backward. + * \param function, Layer::forward_ or Layer::backward_ + * \param name, function name + * \param config, initialization configuration for the function + */ + void createFunction(std::vector>& function, + const std::string& name, + const FuncConfig& config) { + if (useGpu_) { + function.emplace_back( + FunctionBase::funcRegistrar_.createByType(name + "-GPU")); + } else { + function.emplace_back( + FunctionBase::funcRegistrar_.createByType(name + "-CPU")); + } + auto& func = function.back(); + func->init(config); + } + /** * Notify specified layer the output grad ready. * Called in the backward function. 
diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 0f6f9b91d0..262d757c67 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -45,21 +45,13 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, /* the size of inputs for norm-layer is 1 */ CHECK_EQ(config_.inputs_size(), 1); - if (useGpu_) { - forward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormal, GPU)); - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - } else { - forward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormal, CPU)); - backward_ = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - } - forward_->init( + createFunction( + forward_, + "CrossMapNormal", FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); - - backward_->init( + createFunction( + backward_, + "CrossMapNormalGrad", FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_)); return true; @@ -80,7 +72,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; - forward_->calc( + forward_[0]->calc( {Tensor(input->getData(), dims_)}, {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, {}); @@ -98,11 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { MatrixPtr localOutV = getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - backward_->calc({Tensor(preOutV->getData(), dims_), - Tensor(localOutV->getData(), dims_), - Tensor(localGrad->getData(), dims_), - Tensor(denoms_->getData(), dims_)}, - {Tensor(preOutGrad->getData(), dims_)}, - {}); + backward_[0]->calc({Tensor(preOutV->getData(), dims_), + Tensor(localOutV->getData(), dims_), + 
Tensor(localGrad->getData(), dims_), + Tensor(denoms_->getData(), dims_)}, + {Tensor(preOutGrad->getData(), dims_)}, + {}); } } // namespace paddle From 706c572424b6f273fd948d60675c25c378e7021a Mon Sep 17 00:00:00 2001 From: xutianbing Date: Fri, 16 Dec 2016 15:14:02 -0800 Subject: [PATCH 19/26] Matrix API refactor, when passing parameters, convert shared_ptr (MatrixPtr) to reference or raw matrix (Matrix & or Matrix *) contextProjectionForward contextProjectionBackward contextProjectionBackwardData contextProjectionBackwardWeight classificationError The mul functions would be updated later. --- paddle/gserver/evaluators/Evaluator.cpp | 2 +- paddle/gserver/layers/ContextProjection.cpp | 12 +- paddle/math/Matrix.cpp | 171 ++++++++------------ paddle/math/Matrix.h | 34 ++-- paddle/math/tests/test_matrixCompare.cpp | 20 +-- 5 files changed, 103 insertions(+), 136 deletions(-) diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 2f99281911..ae7508e2bb 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -78,7 +78,7 @@ public: useGpu(arguments[0].deviceId)); errorMat->zeroMem(); if (label != nullptr) { - errorMat->classificationError(output, label); + errorMat->classificationError(*output, *label); } else if (dynamic_cast(multiBinaryLabel.get()) || dynamic_cast(multiBinaryLabel.get())) { errorMat->classificationErrorMulti( diff --git a/paddle/gserver/layers/ContextProjection.cpp b/paddle/gserver/layers/ContextProjection.cpp index 7ac56e3a2a..51c0ae5cc9 100644 --- a/paddle/gserver/layers/ContextProjection.cpp +++ b/paddle/gserver/layers/ContextProjection.cpp @@ -90,8 +90,8 @@ void ContextProjection::forward() { REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str()); bool isPadding = config_.trainable_padding(); out_->value->contextProjectionForward( - in_->value, - state_ ? state_ : isPadding ? weight_->getW() : nullptr, + *(in_->value), + state_ ? 
state_.get() : isPadding ? weight_->getW().get() : nullptr, *startPositions, config_.context_length(), config_.context_start(), @@ -128,8 +128,8 @@ void ContextProjection::backward(const UpdateCallback& callback) { bool isPadding = config_.trainable_padding(); if (!out_->grad->useGpu()) { out_->grad->contextProjectionBackward( - in_->grad, - isPadding ? weight_->getWGrad() : nullptr, + in_->grad.get(), + isPadding ? weight_->getWGrad().get() : nullptr, *startPositions, config_.context_length(), config_.context_start(), @@ -137,7 +137,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { isPadding); } else { if (in_->grad) { - out_->grad->contextProjectionBackwardData(in_->grad, + out_->grad->contextProjectionBackwardData(*(in_->grad), *startPositions, config_.context_length(), config_.context_start()); @@ -145,7 +145,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { if (isPadding && weight_->getWGrad()) { out_->grad->contextProjectionBackwardWeight( - weight_->getWGrad(), + *(weight_->getWGrad()), *startPositions, config_.context_length(), config_.context_start(), diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c69e074a76..3b3c1d7d48 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -766,20 +766,19 @@ void GpuMatrix::maxoutBackward(Matrix& a, } /*calulate the error of classification */ -void GpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { - GpuMatrixPtr output_ptr = std::dynamic_pointer_cast(output); - GpuIVectorPtr label_ptr = std::dynamic_pointer_cast(label); - +void GpuMatrix::classificationError(Matrix& output, IVector& label) { + auto output_ptr = dynamic_cast(&output); + auto label_ptr = dynamic_cast(&label); CHECK(output_ptr && label_ptr) << "Invalid argument pointer"; CHECK(height_ == output_ptr->height_ && width_ == 1) << "Matrix dimensions are not equal"; - real* output_d = output_ptr->data_; - real* recResult_d = data_; - int* label_d = label_ptr->getData(); 
- hl_matrix_classification_error( - output_d, label_d, recResult_d, height_, output_ptr->width_); + hl_matrix_classification_error((real*)output_ptr->data_, + (int*)label_ptr->getData(), + data_, + height_, + output_ptr->width_); } /* copy -log(output[i * width + label]) to this->data[i] */ @@ -1370,86 +1369,62 @@ void GpuMatrix::maxSequenceBackward(Matrix& outputGrad, hl_max_sequence_backward(outGrad, maxIndex, inputGrad, numSequences, dim); } -void GpuMatrix::contextProjectionForward(MatrixPtr input, - MatrixPtr weight, +void GpuMatrix::contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding) { - CHECK(dynamic_cast(input.get())); + CHECK(dynamic_cast(&input)); CHECK(dynamic_cast(&sequence)); - if (weight) CHECK(dynamic_cast(weight.get())); - - size_t numSequences = sequence.getSize() - 1; - int64_t inputDim = input->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, inputDim * contextLength); - - real* outData = getData(); - real* inputData = input->getData(); - const int* starts = sequence.getData(); + if (weight) CHECK(dynamic_cast(weight)); + CHECK_EQ(getWidth(), input.getWidth() * contextLength); - hl_context_projection_forward(inputData, - starts, + hl_context_projection_forward(input.getData(), + sequence.getData(), isPadding ? 
weight->getData() : NULL, - outData, - numSequences, - inputDim, + getData(), + sequence.getSize() - 1, + input.getWidth(), contextLength, contextStart, beginPad, isPadding); } -void GpuMatrix::contextProjectionBackwardData(MatrixPtr inputGrad, +void GpuMatrix::contextProjectionBackwardData(Matrix& inputGrad, const IVector& sequence, int contextLength, int contextStart) { - CHECK(dynamic_cast(inputGrad.get())); + CHECK(dynamic_cast(&inputGrad)); CHECK(dynamic_cast(&sequence)); + CHECK_EQ(getWidth(), inputGrad.getWidth() * contextLength); - size_t numSequences = sequence.getSize() - 1; - int64_t inputDim = inputGrad->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, inputDim * contextLength); - - real* outGrad = getData(); - real* inGrad = inputGrad->getData(); - const int* starts = sequence.getData(); - - hl_context_projection_backward_data(outGrad, - starts, - inGrad, - numSequences, - inputDim, + hl_context_projection_backward_data(getData(), + sequence.getData(), + inputGrad.getData(), + sequence.getSize() - 1, + inputGrad.getWidth(), contextLength, contextStart); } -void GpuMatrix::contextProjectionBackwardWeight(MatrixPtr weightGrad, +void GpuMatrix::contextProjectionBackwardWeight(Matrix& weightGrad, const IVector& sequence, int contextLength, int contextStart, int totalPad, size_t beginPad) { - CHECK(dynamic_cast(weightGrad.get())); + CHECK(dynamic_cast(&weightGrad)); CHECK(dynamic_cast(&sequence)); + CHECK_EQ(getWidth(), weightGrad.getWidth() * contextLength); - size_t numSequences = sequence.getSize() - 1; - int64_t weightDim = weightGrad->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, weightDim * contextLength); - - real* outGrad = getData(); - real* wtGrad = weightGrad->getData(); - const int* starts = sequence.getData(); - - hl_context_projection_backward_weight(outGrad, - starts, - wtGrad, - numSequences, - weightDim, + hl_context_projection_backward_weight(getData(), + sequence.getData(), + weightGrad.getData(), + sequence.getSize() - 
1, + weightGrad.getWidth(), totalPad, contextLength, contextStart, @@ -2371,23 +2346,21 @@ void CpuMatrix::maxSequenceBackward(Matrix& outputGrad, } } -void CpuMatrix::contextProjectionForward(MatrixPtr input, - MatrixPtr weight, +void CpuMatrix::contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding) { - CHECK(dynamic_cast(input.get())); - CHECK(dynamic_cast(&sequence)); - if (weight) CHECK(dynamic_cast(weight.get())); - - size_t numSequences = sequence.getSize() - 1; - int64_t inputDim = input->getWidth(); - int64_t dim = getWidth(); - CHECK_EQ(dim, inputDim * contextLength); - const int* starts = sequence.getData(); - + auto input_ptr = dynamic_cast(&input); + auto seq_ptr = dynamic_cast(&sequence); + CHECK(input_ptr && seq_ptr); + if (weight) CHECK(dynamic_cast(weight)); + CHECK_EQ(getWidth(), input_ptr->getWidth() * contextLength); + + const int* starts = seq_ptr->getData(); + size_t numSequences = seq_ptr->getSize() - 1; for (size_t i = 0; i < numSequences; ++i) { for (int j = 0; j < contextLength; ++j) { int begin = starts[i] + contextStart + j; @@ -2400,7 +2373,7 @@ void CpuMatrix::contextProjectionForward(MatrixPtr input, MatrixPtr mat = this->subMatrix(starts[i], padSize); if (isPadding) { MatrixPtr sub = weight->subMatrix(j, padSize); - mat->addAtOffset(*sub, j * inputDim); + mat->addAtOffset(*sub, j * input_ptr->getWidth()); } dstBegin = starts[i] + padSize; begin = starts[i]; @@ -2412,41 +2385,36 @@ void CpuMatrix::contextProjectionForward(MatrixPtr input, if (isPadding) { MatrixPtr sub = weight->subMatrix(beginPad + contextStart + j - padSize, padSize); - mat->addAtOffset(*sub, j * inputDim); + mat->addAtOffset(*sub, j * input_ptr->getWidth()); } dstEnd = starts[i + 1] - padSize; end = starts[i + 1]; } if (end <= begin) continue; - MatrixPtr src = input->subMatrix(begin, end - begin); + MatrixPtr src = input_ptr->subMatrix(begin, end - begin); MatrixPtr 
dst = this->subMatrix(dstBegin, dstEnd - dstBegin); - dst->addAtOffset(*src, j * inputDim); + dst->addAtOffset(*src, j * input_ptr->getWidth()); } } } -void CpuMatrix::contextProjectionBackward(MatrixPtr inputGrad, - MatrixPtr weightGrad, +void CpuMatrix::contextProjectionBackward(Matrix* inputGrad, + Matrix* weightGrad, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding) { - if (inputGrad) CHECK(dynamic_cast(inputGrad.get())); - if (weightGrad) CHECK(dynamic_cast(weightGrad.get())); + if (inputGrad) CHECK(dynamic_cast(inputGrad)); + if (weightGrad) CHECK(dynamic_cast(weightGrad)); CHECK(dynamic_cast(&sequence)); - int64_t inputDim = 0; - int64_t dim = getWidth(); - size_t numSequences = sequence.getSize() - 1; - const int* starts = sequence.getData(); - if (inputGrad) { - inputDim = inputGrad->getWidth(); - } else { - inputDim = weightGrad->getWidth(); - } - CHECK_EQ(dim, inputDim * contextLength); + int64_t inputDim = inputGrad ? inputGrad->getWidth() + : weightGrad ? 
weightGrad->getWidth() : 0; + CHECK_EQ(getWidth(), inputDim * contextLength); + const int* starts = sequence.getData(); + size_t numSequences = sequence.getSize() - 1; for (size_t i = 0; i < numSequences; ++i) { for (int j = 0; j < contextLength; ++j) { int begin = starts[i] + contextStart + j; @@ -3544,21 +3512,20 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) { } /* calulate classification error */ -void CpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { - CHECK(dynamic_cast(output.get())); - CHECK(dynamic_cast(label.get())); +void CpuMatrix::classificationError(Matrix& output, IVector& label) { + CHECK(dynamic_cast(&output)); + CHECK(dynamic_cast(&label)); - size_t numSamples = getHeight(); - size_t dim = output->getWidth(); - CHECK_EQ(label->getSize(), numSamples); - CHECK_EQ(output->getHeight(), numSamples); CHECK_EQ(getWidth(), (size_t)1); + size_t numSamples = getHeight(); + CHECK_EQ(label.getSize(), numSamples); + CHECK_EQ(output.getHeight(), numSamples); - real* out = output->getData(); - real* result = getData(); - int* lbl = label->getData(); - real maxData; - int maxIndex; + size_t dim = output.getWidth(); + real* out = output.getData(); + int* lbl = label.getData(); + real maxData = 0.0; + int maxIndex = -1; for (size_t i = 0; i < numSamples; ++i) { CHECK_GE(lbl[i], 0); CHECK_LT((size_t)lbl[i], dim); @@ -3570,7 +3537,7 @@ void CpuMatrix::classificationError(MatrixPtr output, IVectorPtr label) { maxData = out[i * dim + j]; } } - result[i] = (maxIndex != lbl[i]); + getData()[i] = (maxIndex != lbl[i]); } } diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 1cfb90a9db..b8c7adf948 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -835,7 +835,7 @@ public: * * output[i] = 0 if row i is correct. 
*/ - virtual void classificationError(MatrixPtr output, IVectorPtr label) { + virtual void classificationError(Matrix& output, IVector& label) { LOG(FATAL) << "Not implemented"; } @@ -997,8 +997,8 @@ public: LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionForward(MatrixPtr input, - MatrixPtr weight, + virtual void contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, @@ -1007,8 +1007,8 @@ public: LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionBackward(MatrixPtr inputGrad, - MatrixPtr weightGrad, + virtual void contextProjectionBackward(Matrix* inputGrad, + Matrix* weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1017,14 +1017,14 @@ public: LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionBackwardData(MatrixPtr inputGrad, + virtual void contextProjectionBackwardData(Matrix& inputGrad, const IVector& sequence, int contextLength, int contextStart) { LOG(FATAL) << "Not implemeted"; } - virtual void contextProjectionBackwardWeight(MatrixPtr weightGrad, + virtual void contextProjectionBackwardWeight(Matrix& weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1373,7 +1373,7 @@ public: void check(std::ostream& os, Matrix& refMat, bool printDiff = true); void randomizeUniform(); - void classificationError(MatrixPtr output, IVectorPtr label); + void classificationError(Matrix& output, IVector& label); void convExpand(Matrix& feature, int feaImgHeight, @@ -1487,20 +1487,20 @@ public: const IVector& sequence, IVector& index); - void contextProjectionForward(MatrixPtr input, - MatrixPtr weight, + void contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding); - void contextProjectionBackwardData(MatrixPtr inputGrad, + void contextProjectionBackwardData(Matrix& inputGrad, const IVector& sequence, int 
contextLength, int contextStart); - void contextProjectionBackwardWeight(MatrixPtr weightGrad, + void contextProjectionBackwardWeight(Matrix& weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1713,16 +1713,16 @@ public: const IVector& sequence, IVector& index); - void contextProjectionForward(MatrixPtr input, - MatrixPtr weight, + void contextProjectionForward(Matrix& input, + Matrix* weight, const IVector& sequence, int contextLength, int contextStart, size_t beginPad, bool isPadding); - void contextProjectionBackward(MatrixPtr inputGrad, - MatrixPtr weightGrad, + void contextProjectionBackward(Matrix* inputGrad, + Matrix* weightGrad, const IVector& sequence, int contextLength, int contextStart, @@ -1881,7 +1881,7 @@ public: void randomizeUniform(); - void classificationError(MatrixPtr output, IVectorPtr label); + void classificationError(Matrix& output, IVector& label); void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 62de5b25e4..10289940a4 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -65,16 +65,16 @@ void testMatrixProjectionForward(int contextStart, // calculate int beginPad = std::max(0, -contextStart); - cpuOutput->contextProjectionForward(cpuInput, - cpuWeight, + cpuOutput->contextProjectionForward(*cpuInput, + cpuWeight.get(), *cpuSequence, contextLength, contextStart, beginPad, padding); - gpuOutput->contextProjectionForward(gpuInput, - gpuWeight, + gpuOutput->contextProjectionForward(*gpuInput, + gpuWeight.get(), *gpuSequence, contextLength, contextStart, @@ -120,17 +120,17 @@ void testMatrixProjectionBackward(int contextStart, // calculate int beginPad = std::max(0, -contextStart); - cpuOutputGrad->contextProjectionBackward(cpuInputGrad, - cpuWeightGrad, + cpuOutputGrad->contextProjectionBackward(cpuInputGrad.get(), + cpuWeightGrad.get(), 
*cpuSequence, contextLength, contextStart, beginPad, padding); gpuOutputGrad->contextProjectionBackwardData( - gpuInputGrad, *gpuSequence, contextLength, contextStart); + *gpuInputGrad, *gpuSequence, contextLength, contextStart); if (padding) { - gpuOutputGrad->contextProjectionBackwardWeight(gpuWeightGrad, + gpuOutputGrad->contextProjectionBackwardWeight(*gpuWeightGrad, *gpuSequence, contextLength, contextStart, @@ -939,8 +939,8 @@ void testClassificationError(int numSamples, int dim) { gpuOutput->copyFrom(*cpuOutput); gpuLabel->copyFrom(*cpuLabel); - cpuError->classificationError(cpuOutput, cpuLabel); - gpuError->classificationError(gpuOutput, gpuLabel); + cpuError->classificationError(*cpuOutput, *cpuLabel); + gpuError->classificationError(*gpuOutput, *gpuLabel); TensorCheckEqual(*cpuError, *gpuError); } From 4fbf94993b0699bb06c5347612a7b97d692a2625 Mon Sep 17 00:00:00 2001 From: xutianbing Date: Mon, 19 Dec 2016 17:21:06 -0800 Subject: [PATCH 20/26] Refactor MUL functions, pass object reference instead of shared_ptr. 
--- .../gserver/layers/ConvexCombinationLayer.cpp | 6 +- paddle/gserver/layers/ExpandConvBaseLayer.cpp | 6 +- .../gserver/layers/FullMatrixProjection.cpp | 7 ++- paddle/gserver/layers/FullyConnectedLayer.cpp | 8 +-- paddle/gserver/layers/LinearChainCRF.cpp | 2 +- paddle/gserver/layers/LstmLayer.cpp | 26 ++++----- paddle/gserver/layers/MDLstmLayer.cpp | 8 +-- paddle/gserver/layers/OuterProdLayer.cpp | 6 +- paddle/gserver/layers/RecurrentLayer.cpp | 32 +++++------ .../layers/SelectiveFullyConnectedLayer.cpp | 10 ++-- paddle/gserver/layers/TensorLayer.cpp | 8 +-- .../layers/TransposedFullMatrixProjection.cpp | 7 ++- paddle/math/CpuSparseMatrix.cpp | 15 ++--- paddle/math/CpuSparseMatrix.h | 2 +- paddle/math/Matrix.cpp | 49 +++++++---------- paddle/math/Matrix.h | 14 ++--- paddle/math/SparseMatrix.cpp | 55 +++++++++---------- paddle/math/SparseMatrix.h | 7 +-- paddle/math/tests/test_SparseMatrix.cpp | 14 ++--- paddle/math/tests/test_matrixCompare.cpp | 12 ++-- .../math/tests/test_sparseMatrixCompare.cpp | 4 +- 21 files changed, 144 insertions(+), 154 deletions(-) diff --git a/paddle/gserver/layers/ConvexCombinationLayer.cpp b/paddle/gserver/layers/ConvexCombinationLayer.cpp index 3f4d77a2fe..ed57f2af3c 100644 --- a/paddle/gserver/layers/ConvexCombinationLayer.cpp +++ b/paddle/gserver/layers/ConvexCombinationLayer.cpp @@ -113,7 +113,7 @@ void ConvexCombinationLayer::forward(PassType passType) { tmpRow0->setData(inV0->getData() + i * weightDim); tmpRow1->setData(outV->getData() + i * dataDim); - tmpRow1->mul(tmpRow0, tmpMtx0, 1, 0); + tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 0); } } @@ -136,7 +136,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) { tmpRow1->setData(outG->getData() + i * dataDim); tmpMtx0->setData(inV1->getData() + i * weightDim * dataDim); - tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1); + tmpRow0->mul(*tmpRow1, *(tmpMtx0->getTranspose()), 1, 1); } } @@ -146,7 +146,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& 
callback) { tmpRow1->setData(outG->getData() + i * dataDim); tmpMtx0->setData(inG1->getData() + i * weightDim * dataDim); - tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1, 1, 1); + tmpMtx0->mul(*(tmpRow0->getTranspose()), *tmpRow1, 1, 1); } } } diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp index 25948747fe..9ddccc2027 100644 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ b/paddle/gserver/layers/ExpandConvBaseLayer.cpp @@ -150,7 +150,7 @@ void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image, Matrix::create(wgtData, subM, subK, false, useGpu_); // mark transpose MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_); MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_); - C->mul(A, B, 1, 1); + C->mul(*A, *B, 1, 1); A->clear(); B->clear(); @@ -185,7 +185,7 @@ void ExpandConvBaseLayer::bpropActs(MatrixPtr out, MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_); MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_); MatrixPtr A = Matrix::create(wgtData, subM, subK, true, useGpu_); - C->mul(A, B); // mul + C->mul(*A, *B); // mul // clear the temporary matrix A->clear(); @@ -252,7 +252,7 @@ void ExpandConvBaseLayer::bpropWeights(MatrixPtr image, MatrixPtr A = Matrix::create(expandInData, subK, subN, true, useGpu_); MatrixPtr B = Matrix::create(gradData, subM, subN, false, useGpu_); MatrixPtr C = Matrix::create(wGradData, subM, subK, false, useGpu_); - C->mul(B, A, 1, 1); + C->mul(*B, *A, 1, 1); A->clear(); B->clear(); diff --git a/paddle/gserver/layers/FullMatrixProjection.cpp b/paddle/gserver/layers/FullMatrixProjection.cpp index 9e72a33a3c..b8b6f403d6 100644 --- a/paddle/gserver/layers/FullMatrixProjection.cpp +++ b/paddle/gserver/layers/FullMatrixProjection.cpp @@ -28,7 +28,7 @@ FullMatrixProjection::FullMatrixProjection(const ProjectionConfig& config, void FullMatrixProjection::forward() { REGISTER_TIMER_INFO("FwMulTimer", 
getName().c_str()); - out_->value->mul(in_->value, weight_->getW(), 1, 1); + out_->value->mul(*(in_->value), *(weight_->getW()), 1, 1); } void FullMatrixProjection::backward(const UpdateCallback& callback) { @@ -37,7 +37,8 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the W-gradient for the current layer */ if (weight_->getWGrad()) { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); - weight_->getWGrad()->mul(in_->value->getTranspose(), out_->grad, 1, 1); + weight_->getWGrad()->mul( + *(in_->value->getTranspose()), *(out_->grad), 1, 1); } // If callback does not change value, backward propagation error @@ -47,7 +48,7 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the input layers error */ if (in_->grad) { REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - in_->grad->mul(out_->grad, weight_->getW()->getTranspose(), 1, 1); + in_->grad->mul(*(out_->grad), *(weight_->getW()->getTranspose()), 1, 1); } hl_set_sync_flag(syncFlag); diff --git a/paddle/gserver/layers/FullyConnectedLayer.cpp b/paddle/gserver/layers/FullyConnectedLayer.cpp index 89afe33c36..d8a667ff8d 100644 --- a/paddle/gserver/layers/FullyConnectedLayer.cpp +++ b/paddle/gserver/layers/FullyConnectedLayer.cpp @@ -84,8 +84,8 @@ void FullyConnectedLayer::forward(PassType passType) { auto input = getInput(i); CHECK(input.value) << "The input of 'fc' layer must be matrix"; REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); - i == 0 ? outV->mul(input.value, weights_[i]->getW(), 1, 0) - : outV->mul(input.value, weights_[i]->getW(), 1, 1); + i == 0 ? 
outV->mul(*input.value, *weights_[i]->getW(), 1, 0) + : outV->mul(*input.value, *weights_[i]->getW(), 1, 1); } /* add the bias-vector */ @@ -123,7 +123,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) { MatrixPtr oGrad = getOutputGrad(); { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); - weights_[i]->getWGrad()->mul(input_T, oGrad, 1, 1); + weights_[i]->getWGrad()->mul(*input_T, *oGrad, 1, 1); } } @@ -136,7 +136,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) { if (NULL != preGrad) { MatrixPtr weights_T = weights_[i]->getW()->getTranspose(); REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - preGrad->mul(getOutputGrad(), weights_T, 1, 1); + preGrad->mul(*getOutputGrad(), *weights_T, 1, 1); } hl_set_sync_flag(syncFlag); diff --git a/paddle/gserver/layers/LinearChainCRF.cpp b/paddle/gserver/layers/LinearChainCRF.cpp index af550c7a01..b7f748f3bb 100644 --- a/paddle/gserver/layers/LinearChainCRF.cpp +++ b/paddle/gserver/layers/LinearChainCRF.cpp @@ -59,7 +59,7 @@ real LinearChainCRF::forward(real* x, int* s, int length) { matX->rowMax(*maxX_); expX_->assign(*matX); // subtract max to avoid overflow or underflow - expX_->mul(maxX_, ones_, (real)-1, (real)1); + expX_->mul(*maxX_, *ones_, (real)-1, (real)1); expX_->exp2(); real* a = a_->getData(); diff --git a/paddle/gserver/layers/LstmLayer.cpp b/paddle/gserver/layers/LstmLayer.cpp index 2543d1b49a..01cc5fec8b 100644 --- a/paddle/gserver/layers/LstmLayer.cpp +++ b/paddle/gserver/layers/LstmLayer.cpp @@ -316,7 +316,7 @@ void LstmLayer::forwardSequence(int batchSize, } if (prevOutput_) { frameGate->setData(lstmValue.gateValue); - frameGate->mul(prevOutput_, weight_->getW(), 1, 1); + frameGate->mul(*prevOutput_, *weight_->getW(), 1, 1); } } AsyncGpuBlock asyncGpuBlock; @@ -338,7 +338,7 @@ void LstmLayer::forwardSequence(int batchSize, frameOutput->setData(lstmValue.outputValue); nextFrame(reversed_, getSize()); frameGate->setData(lstmValue.gateValue); - 
frameGate->mul(frameOutput, weight_->getW(), 1, 1); + frameGate->mul(*frameOutput, *weight_->getW(), 1, 1); } } if (n != numSequences - 1) { @@ -348,7 +348,7 @@ void LstmLayer::forwardSequence(int batchSize, if (!reversed_) { if (!prevState_) lstmValue.prevStateValue = nullptr; if (prevOutput_) { - frameGate->mul(frameOutput, weight_->getW(), 1, 1); + frameGate->mul(*frameOutput, *weight_->getW(), 1, 1); } } else { lstmValue.prevStateValue = nullptr; @@ -470,7 +470,7 @@ void LstmLayer::backwardSequence(int batchSize, frameGate->setData(lstmGrad.gateGrad); nextFrame(reversed_, getSize()); frameOutput->setData(lstmGrad.outputGrad); - frameOutput->mul(frameGate, weightT, 1, 1); + frameOutput->mul(*frameGate, *weightT, 1, 1); } else { nextFrame(reversed_, getSize()); } @@ -479,14 +479,14 @@ void LstmLayer::backwardSequence(int batchSize, if (weight_->getWGrad()) { if (!reversed_) { weight_->getWGrad()->mul( - output_.value->subMatrix(start, length - 1)->getTranspose(), - gate_.grad->subMatrix(start + 1, length - 1), + *output_.value->subMatrix(start, length - 1)->getTranspose(), + *gate_.grad->subMatrix(start + 1, length - 1), 1, 1); } else { weight_->getWGrad()->mul( - output_.value->subMatrix(start + 1, length - 1)->getTranspose(), - gate_.grad->subMatrix(start, length - 1), + *output_.value->subMatrix(start + 1, length - 1)->getTranspose(), + *gate_.grad->subMatrix(start, length - 1), 1, 1); } @@ -541,7 +541,7 @@ void LstmLayer::forwardBatch(int batchSize, if (n != 0) { MatrixPtr batch1 = batchValue_->getBatchValue(n - 1, batchSize); - gateValue->mul(batch1, weight_->getW(), 1, 1); + gateValue->mul(*batch1, *weight_->getW(), 1, 1); } else if (prevOutput_) { Matrix::resizeOrCreate(prevBatchOutput2_, gateValue->getHeight(), @@ -549,7 +549,7 @@ void LstmLayer::forwardBatch(int batchSize, false, useGpu_); batchValue_->prevOutput2Batch(*prevOutput_, *prevBatchOutput2_); - gateValue->mul(prevBatchOutput2_, weight_->getW(), 1, 1); + gateValue->mul(*prevBatchOutput2_, 
*weight_->getW(), 1, 1); batchValue_->prevOutput2Batch(*prevState_, *totalState_->subMatrix(0, numSequences)); @@ -672,16 +672,16 @@ void LstmLayer::backwardBatch(int batchSize, if (n != 0) { MatrixPtr tmp = batchGrad_->getBatchValue(n - 1, batchSize); - tmp->mul(gateGrad, weightT, 1, 1); + tmp->mul(*gateGrad, *weightT, 1, 1); } if (n != 0 && weight_->getWGrad()) { /* backward weight */ MatrixPtr outputValue = batchValue_->getBatchValue(n - 1, batchSize); - weight_->getWGrad()->mul(outputValue->getTranspose(), gateGrad, 1, 1); + weight_->getWGrad()->mul(*outputValue->getTranspose(), *gateGrad, 1, 1); } else if (prevOutput_ && weight_->getWGrad()) { weight_->getWGrad()->mul( - prevBatchOutput2_->getTranspose(), gateGrad, 1, 1); + *prevBatchOutput2_->getTranspose(), *gateGrad, 1, 1); } } } diff --git a/paddle/gserver/layers/MDLstmLayer.cpp b/paddle/gserver/layers/MDLstmLayer.cpp index 1243c12889..fb41af5631 100644 --- a/paddle/gserver/layers/MDLstmLayer.cpp +++ b/paddle/gserver/layers/MDLstmLayer.cpp @@ -547,7 +547,7 @@ void MDLstmLayer::forwardOneSequence(int start, CoordIterator& coordIter) { if (coordIter.getPrePos(delays_, i, prePos)) { int preOffset = coordIter.offset(prePos); frameGate_[start + offset].value->mul( - frameOutput_[start + preOffset].value, weight_->getW(), 1.0, 1.0); + *frameOutput_[start + preOffset].value, *weight_->getW(), 1.0, 1.0); } } forwardGate2OutputSequence(start, coordIter); @@ -747,11 +747,11 @@ void MDLstmLayer::backwardOneSequence(int start, CoordIterator& coordIter) { if (coordIter.getPrePos(delays_, i, prePos)) { int preOffset = coordIter.offset(prePos); frameOutput_[start + preOffset].grad->mul( - frameGate_[start + offset].grad, weightT, 1.0, 1.0); + *frameGate_[start + offset].grad, *weightT, 1.0, 1.0); if (weight_->getWGrad()) { weight_->getWGrad()->mul( - frameOutput_[start + preOffset].value->getTranspose(), - frameGate_[start + offset].grad, + *frameOutput_[start + preOffset].value->getTranspose(), + *frameGate_[start + 
offset].grad, 1.0, 1.0); } diff --git a/paddle/gserver/layers/OuterProdLayer.cpp b/paddle/gserver/layers/OuterProdLayer.cpp index cf9a008318..b606e44365 100644 --- a/paddle/gserver/layers/OuterProdLayer.cpp +++ b/paddle/gserver/layers/OuterProdLayer.cpp @@ -96,7 +96,7 @@ void OuterProdLayer::forward(PassType passType) { tmpRow0->setData(inV0->getData() + i * dim0); tmpRow1->setData(inV1->getData() + i * dim1); - tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1); + tmpMtx0->mul(*tmpRow0->getTranspose(), *tmpRow1); } } } @@ -121,7 +121,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) { tmpRow0->setData(inG0->getData() + i * dim0); tmpRow1->setData(inV1->getData() + i * dim1); - tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1); + tmpRow0->mul(*tmpRow1, *tmpMtx0->getTranspose(), 1, 1); } } @@ -131,7 +131,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) { tmpRow0->setData(inV0->getData() + i * dim0); tmpRow1->setData(inG1->getData() + i * dim1); - tmpRow1->mul(tmpRow0, tmpMtx0, 1, 1); + tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 1); } } } diff --git a/paddle/gserver/layers/RecurrentLayer.cpp b/paddle/gserver/layers/RecurrentLayer.cpp index 85812c9d66..94b16996a8 100644 --- a/paddle/gserver/layers/RecurrentLayer.cpp +++ b/paddle/gserver/layers/RecurrentLayer.cpp @@ -215,12 +215,12 @@ void RecurrentLayer::forwardSequence(int batchSize, void RecurrentLayer::forwardOneSequence(int start, int length) { if (!reversed_) { if (prevOutput_) { - frameOutput_[start].value->mul(prevOutput_, weight_->getW(), 1, 1); + frameOutput_[start].value->mul(*prevOutput_, *weight_->getW(), 1, 1); } activation_->forward(frameOutput_[start]); for (int i = 1; i < length; ++i) { frameOutput_[start + i].value->mul( - frameOutput_[start + i - 1].value, weight_->getW(), 1, 1); + *frameOutput_[start + i - 1].value, *weight_->getW(), 1, 1); activation_->forward(frameOutput_[start + i]); } if (prevOutput_) { @@ -230,7 +230,7 @@ void RecurrentLayer::forwardOneSequence(int 
start, int length) { activation_->forward(frameOutput_[start + length - 1]); for (int i = length - 2; i >= 0; --i) { frameOutput_[start + i].value->mul( - frameOutput_[start + i + 1].value, weight_->getW(), 1, 1); + *frameOutput_[start + i + 1].value, *weight_->getW(), 1, 1); activation_->forward(frameOutput_[start + i]); } } @@ -282,13 +282,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) { for (int i = length - 1; i > 0; --i) { activation_->backward(frameOutput_[start + i]); frameOutput_[start + i - 1].grad->mul( - frameOutput_[start + i].grad, weightT, 1, 1); + *frameOutput_[start + i].grad, *weightT, 1, 1); } activation_->backward(frameOutput_[start]); if (weight_->getWGrad()) { weight_->getWGrad()->mul( - output_.value->subMatrix(start, length - 1)->getTranspose(), - output_.grad->subMatrix(start + 1, length - 1), + *output_.value->subMatrix(start, length - 1)->getTranspose(), + *output_.grad->subMatrix(start + 1, length - 1), 1, 1); } @@ -296,13 +296,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) { for (int i = 0; i < length - 1; ++i) { activation_->backward(frameOutput_[start + i]); frameOutput_[start + i + 1].grad->mul( - frameOutput_[start + i].grad, weightT, 1, 1); + *frameOutput_[start + i].grad, *weightT, 1, 1); } activation_->backward(frameOutput_[start + length - 1]); if (weight_->getWGrad()) { weight_->getWGrad()->mul( - output_.value->subMatrix(start + 1, length - 1)->getTranspose(), - output_.grad->subMatrix(start, length - 1), + *output_.value->subMatrix(start + 1, length - 1)->getTranspose(), + *output_.grad->subMatrix(start, length - 1), 1, 1); } @@ -329,7 +329,7 @@ void RecurrentLayer::forwardBatch(int batchSize, if (n != 0) { MatrixPtr batch1 = batchValue_->getBatchValue(n - 1, batch2->getHeight()); - batch2->mul(batch1, weight_->getW(), 1, 1); + batch2->mul(*batch1, *weight_->getW(), 1, 1); } Argument arg; arg.value = batch2; @@ -367,14 +367,14 @@ void RecurrentLayer::backwardBatch(int batchSize, if 
(n != 0) { batch1 = batchGrad_->getBatchValue(n - 1, batch2->getHeight()); - batch1->mul(batch2, weightT, 1, 1); + batch1->mul(*batch2, *weightT, 1, 1); } if (backwardByBatch && weight_->getWGrad()) { if (n != 0) { /* backward weight */ batch1 = batchValue_->getBatchValue(n - 1, batch2->getHeight()); - weight_->getWGrad()->mul(batch1->getTranspose(), batch2, 1, 1); + weight_->getWGrad()->mul(*batch1->getTranspose(), *batch2, 1, 1); } } } @@ -389,14 +389,14 @@ void RecurrentLayer::backwardBatch(int batchSize, int len = starts[seq + 1] - starts[seq]; if (!reversed_) { weight_->getWGrad()->mul( - output_.value->subMatrix(starts[seq], len - 1)->getTranspose(), - output_.grad->subMatrix(starts[seq] + 1, len - 1), + *output_.value->subMatrix(starts[seq], len - 1)->getTranspose(), + *output_.grad->subMatrix(starts[seq] + 1, len - 1), 1, 1); } else { weight_->getWGrad()->mul( - output_.value->subMatrix(starts[seq] + 1, len - 1)->getTranspose(), - output_.grad->subMatrix(starts[seq], len - 1), + *output_.value->subMatrix(starts[seq] + 1, len - 1)->getTranspose(), + *output_.grad->subMatrix(starts[seq], len - 1), 1, 1); } diff --git a/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp b/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp index 9200a01eee..5eacff6b71 100644 --- a/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp +++ b/paddle/gserver/layers/SelectiveFullyConnectedLayer.cpp @@ -155,20 +155,20 @@ void SelectiveFullyConnectedLayer::forward(PassType passType) { // manully compute the multiplication of // the input vector and the selected rows. 
REGISTER_TIMER("selective.plain"); - interOutput_->mul(input, weight->getTranspose(), 1, scaleT); + interOutput_->mul(*input, *weight->getTranspose(), 1, scaleT); } else { // if the indecies is not sparse enough, // use full mul instead REGISTER_TIMER("selective.mul"); if (fullOutput_) { - interOutput_->mul(input, weight->getTranspose(), 1, scaleT); + interOutput_->mul(*input, *weight->getTranspose(), 1, scaleT); } else { Matrix::resizeOrCreate(mmat_, hsize, wsize, /*trans=*/false, /*useGpu=*/useGpu_); - mmat_->mul(input, weight->getTranspose()); + mmat_->mul(*input, *weight->getTranspose()); interOutput_->add3(mmat_); } } @@ -242,14 +242,14 @@ void SelectiveFullyConnectedLayer::backward(const UpdateCallback& callback) { MatrixPtr preGrad = getInputGrad(i); if (preGrad) { REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - preGrad->mul(interOutGrad_, weights_[i]->getW(), 1, 1); + preGrad->mul(*interOutGrad_, *weights_[i]->getW(), 1, 1); } MatrixPtr wGrad = weights_[i]->getWGrad(); if (wGrad) { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); MatrixPtr input = getInputValue(i); - wGrad->mul(interOutGrad_->getTranspose(), input, 1, 1); + wGrad->mul(*interOutGrad_->getTranspose(), *input, 1, 1); } { diff --git a/paddle/gserver/layers/TensorLayer.cpp b/paddle/gserver/layers/TensorLayer.cpp index 642eb1bdd3..5be88d7c05 100644 --- a/paddle/gserver/layers/TensorLayer.cpp +++ b/paddle/gserver/layers/TensorLayer.cpp @@ -77,7 +77,7 @@ void TensorLayer::forward(PassType passType) { REGISTER_TIMER_INFO("TensorFwMulTimer", getName().c_str()); for (size_t i = 0; i < getSize(); ++i) { MatrixPtr weights = weights_[i]->getW(); - tmpMat->mul(input1, weights, 1, 0); + tmpMat->mul(*input1, *weights, 1, 0); outV->rowDotMul(i, *tmpMat, *input2); } } @@ -112,7 +112,7 @@ void TensorLayer::backward(const UpdateCallback& callback) { if (weights_[i]->getWGrad()) { tmpMat->rowScale(i, *input1, *oGrad); MatrixPtr input1_T = tmpMat->getTranspose(); - 
weights_[i]->getWGrad()->mul(input1_T, input2, 1, 1); + weights_[i]->getWGrad()->mul(*input1_T, *input2, 1, 1); } } } @@ -130,11 +130,11 @@ void TensorLayer::backward(const UpdateCallback& callback) { if (NULL != preGrad1) { /* (grad * e2) * trans(W) */ tmpMat->rowScale(i, *input2, *oGrad); MatrixPtr weights_T = weights->getTranspose(); - preGrad1->mul(tmpMat, weights_T, 1, 1); + preGrad1->mul(*tmpMat, *weights_T, 1, 1); } if (NULL != preGrad2) { /* (grad * e1) * W */ tmpMat->rowScale(i, *input1, *oGrad); - preGrad2->mul(tmpMat, weights, 1, 1); + preGrad2->mul(*tmpMat, *weights, 1, 1); } } } diff --git a/paddle/gserver/layers/TransposedFullMatrixProjection.cpp b/paddle/gserver/layers/TransposedFullMatrixProjection.cpp index 3f7ff04882..2a12499e5b 100644 --- a/paddle/gserver/layers/TransposedFullMatrixProjection.cpp +++ b/paddle/gserver/layers/TransposedFullMatrixProjection.cpp @@ -46,7 +46,7 @@ TransposedFullMatrixProjection::TransposedFullMatrixProjection( void TransposedFullMatrixProjection::forward() { REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); - out_->value->mul(in_->value, weight_->getW()->getTranspose(), 1, 1); + out_->value->mul(*(in_->value), *(weight_->getW()->getTranspose()), 1, 1); } void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { @@ -55,7 +55,8 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the W-gradient for the current layer */ if (weight_->getWGrad()) { REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); - weight_->getWGrad()->mul(out_->grad->getTranspose(), in_->value, 1, 1); + weight_->getWGrad()->mul( + *(out_->grad->getTranspose()), *(in_->value), 1, 1); } // If callback does not change value, backprop error asynchronously so that @@ -69,7 +70,7 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { /* Calculate the input layers error */ if (in_->grad) { REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); - 
in_->grad->mul(out_->grad, weight_->getW(), 1, 1); + in_->grad->mul(*(out_->grad), *(weight_->getW()), 1, 1); } hl_set_sync_flag(syncFlag); diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp index b5d5b6ef61..82a482f701 100644 --- a/paddle/math/CpuSparseMatrix.cpp +++ b/paddle/math/CpuSparseMatrix.cpp @@ -163,15 +163,16 @@ MatrixPtr CpuSparseMatrix::getTranspose() { SparseValueType CpuSparseMatrix::getValueType() { return valueType_; } -void CpuSparseMatrix::mul(MatrixPtr a, MatrixPtr b, real scaleAB, real scaleT) { +void CpuSparseMatrix::mul(const Matrix& a, + const Matrix& b, + real scaleAB, + real scaleT) { CHECK(!isTransposed()) << "Not supported"; + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); - if (dynamic_cast(a.get()) && dynamic_cast(b.get())) { - CpuMatrix::mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - this, - scaleAB, - scaleT); + if (a_ptr && b_ptr) { + CpuMatrix::mul((CpuMatrix*)a_ptr, (CpuMatrix*)b_ptr, this, scaleAB, scaleT); } else { LOG(FATAL) << "not supported"; } diff --git a/paddle/math/CpuSparseMatrix.h b/paddle/math/CpuSparseMatrix.h index 9676f8864f..d3e8871cb5 100644 --- a/paddle/math/CpuSparseMatrix.h +++ b/paddle/math/CpuSparseMatrix.h @@ -203,7 +203,7 @@ public: /// mem MUST be alloced outside (memAlloc=false) void transpose(MatrixPtr matTrans, bool memAlloc); - void mul(MatrixPtr A, MatrixPtr B, real alpha, real beta); + void mul(const Matrix& A, const Matrix& B, real alpha, real beta); /** * @brief sparseMatrix += denseMatrix diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 3b3c1d7d48..0193f2f997 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -582,18 +582,16 @@ void GpuMatrix::mul(const GpuMatrix& a, } /* this = a*b */ -void GpuMatrix::mul(const MatrixPtr a, const MatrixPtr b) { - mul(a, b, 1.0, 0.0); -} +void GpuMatrix::mul(const Matrix& a, const Matrix& b) { mul(a, b, 1.0, 0.0); } -void GpuMatrix::mul(const MatrixPtr a, - const 
MatrixPtr b, +void GpuMatrix::mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { - GpuMatrixPtr a_ptr = std::dynamic_pointer_cast(a); - GpuMatrixPtr b_ptr = std::dynamic_pointer_cast(b); - GpuSparseMatrixPtr a_ptr_s = std::dynamic_pointer_cast(a); - GpuSparseMatrixPtr b_ptr_s = std::dynamic_pointer_cast(b); + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); + const auto a_ptr_s = dynamic_cast(&a); + const auto b_ptr_s = dynamic_cast(&b); if (a_ptr && b_ptr) { mul(*a_ptr, *b_ptr, scaleAB, scaleT); @@ -2598,29 +2596,22 @@ void CpuMatrix::sequenceAvgForward(Matrix& a, } /* this = scaleAB*(a*b) + scaleT*this*/ -void CpuMatrix::mul(const MatrixPtr a, - const MatrixPtr b, +void CpuMatrix::mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { CHECK(!isTransposed()) << "Not supported"; + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); + const auto a_ptr_s = dynamic_cast(&a); + const auto b_ptr_s = dynamic_cast(&b); - if (dynamic_cast(a.get()) && dynamic_cast(b.get())) { - mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - scaleAB, - scaleT); - } else if (dynamic_cast(a.get()) && - dynamic_cast(b.get())) { - mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - scaleAB, - scaleT); - } else if (dynamic_cast(a.get()) && - dynamic_cast(b.get())) { - mul(dynamic_cast(a.get()), - dynamic_cast(b.get()), - scaleAB, - scaleT); + if (a_ptr && b_ptr) { + mul((CpuMatrix*)a_ptr, (CpuMatrix*)b_ptr, scaleAB, scaleT); + } else if (a_ptr_s && b_ptr) { + mul((CpuSparseMatrix*)a_ptr_s, (CpuMatrix*)b_ptr, scaleAB, scaleT); + } else if (a_ptr && b_ptr_s) { + mul((CpuMatrix*)a_ptr, (CpuSparseMatrix*)b_ptr_s, scaleAB, scaleT); } else { LOG(FATAL) << "Not supported"; } @@ -3289,7 +3280,7 @@ void CpuMatrix::addColumnVector(const Matrix& b) { } /* this = a*b */ -void CpuMatrix::mul(const MatrixPtr a, const MatrixPtr b) { +void CpuMatrix::mul(const Matrix& a, const Matrix& b) { return mul(a, b, 1.0, 0.0); } 
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index b8c7adf948..dfcb0853df 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -444,8 +444,8 @@ public: * this = scaleAB*(a*b) + scaleT*this * @endcode */ - virtual void mul(const MatrixPtr a, - const MatrixPtr b, + virtual void mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { LOG(FATAL) << "Not implemented"; @@ -643,7 +643,7 @@ public: * this = a*b * @endcode */ - virtual void mul(const MatrixPtr a, const MatrixPtr b) { + virtual void mul(const Matrix& a, const Matrix& b) { LOG(FATAL) << "Not implemented"; } @@ -1272,14 +1272,14 @@ public: * this = scaleAB*(a*b) + scaleT*this * @endcode */ - void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); + void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT); /** * @code * this = a*b * @endcode */ - void mul(const MatrixPtr a, const MatrixPtr b); + void mul(const Matrix& a, const Matrix& b); void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT); @@ -1784,7 +1784,7 @@ public: void addColumnVector(const Matrix& b); - void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); + void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT); void mul(CpuMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); void mul(CpuMatrix* a, CpuSparseMatrix* b, real scaleAB, real scaleT); @@ -1807,7 +1807,7 @@ public: virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); - void mul(const MatrixPtr a, const MatrixPtr b); + void mul(const Matrix& a, const Matrix& b); void rightMul(Matrix& b, real scaleAB, real scaleT); void rightMul(Matrix& b); diff --git a/paddle/math/SparseMatrix.cpp b/paddle/math/SparseMatrix.cpp index 9154503c21..720a035ecb 100644 --- a/paddle/math/SparseMatrix.cpp +++ b/paddle/math/SparseMatrix.cpp @@ -571,49 +571,48 @@ void GpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) { hl_stream_synchronize(stream); } -void 
GpuSparseMatrix::mul(const GpuMatrixPtr a, - const GpuMatrixPtr b, +void GpuSparseMatrix::mul(const GpuMatrix& a, + const GpuMatrix& b, real scaleAB, real scaleT) { - CHECK(a->useGpu_ && b->useGpu_) << "type not match"; + CHECK(a.useGpu_ && b.useGpu_) << "type not match"; CHECK(!trans_) << "trans not supported"; - real* A_d = a->getData(); - real* B_d = b->getData(); + real* A_d = (real*)a.getData(); + real* B_d = (real*)b.getData(); hl_sparse_matrix_s C_d = sMatrix_.get(); - hl_trans_op_t a_trans = a->trans_ ? HPPL_OP_T : HPPL_OP_N; - hl_trans_op_t b_trans = b->trans_ ? HPPL_OP_T : HPPL_OP_N; - - if (!a->trans_ && !b->trans_) { - CHECK(height_ == a->getHeight()); - CHECK(width_ == b->getWidth()); - CHECK(a->getWidth() == b->getHeight()); - } else if (a->trans_ && !b->trans_) { - CHECK(height_ == a->getWidth()); - CHECK(width_ == b->getWidth()); - CHECK(a->getHeight() == b->getHeight()); - } else if (!a->trans_ && b->trans_) { - CHECK(height_ == a->getHeight()); - CHECK(width_ == b->getHeight()); - CHECK(a->getWidth() == b->getWidth()); + hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N; + hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N; + + if (!a.trans_ && !b.trans_) { + CHECK(height_ == a.getHeight()); + CHECK(width_ == b.getWidth()); + CHECK(a.getWidth() == b.getHeight()); + } else if (a.trans_ && !b.trans_) { + CHECK(height_ == a.getWidth()); + CHECK(width_ == b.getWidth()); + CHECK(a.getHeight() == b.getHeight()); + } else if (!a.trans_ && b.trans_) { + CHECK(height_ == a.getHeight()); + CHECK(width_ == b.getHeight()); + CHECK(a.getWidth() == b.getWidth()); } else { LOG(INFO) << "Not support"; } int dimM = height_; int dimN = width_; - int dimK = !b->trans_ ? b->getHeight() : b->getWidth(); + int dimK = !b.trans_ ? 
b.getHeight() : b.getWidth(); hl_sparse_matrix_mul( A_d, a_trans, B_d, b_trans, C_d, dimM, dimN, dimK, scaleAB, scaleT); } -void GpuSparseMatrix::mul(const MatrixPtr a, - const MatrixPtr b, +void GpuSparseMatrix::mul(const Matrix& a, + const Matrix& b, real scaleAB, real scaleT) { - if (std::dynamic_pointer_cast(a) && - std::dynamic_pointer_cast(b)) { - GpuMatrixPtr a_ptr = std::dynamic_pointer_cast(a); - GpuMatrixPtr b_ptr = std::dynamic_pointer_cast(b); - mul(a_ptr, b_ptr, scaleAB, scaleT); + const auto a_ptr = dynamic_cast(&a); + const auto b_ptr = dynamic_cast(&b); + if (a_ptr && b_ptr) { + mul(*a_ptr, *b_ptr, scaleAB, scaleT); } else { LOG(FATAL) << "not supported"; } diff --git a/paddle/math/SparseMatrix.h b/paddle/math/SparseMatrix.h index bd96a3301d..1d3801548e 100644 --- a/paddle/math/SparseMatrix.h +++ b/paddle/math/SparseMatrix.h @@ -104,10 +104,7 @@ public: size_t newNnz, SparseValueType valueType); - void mul(const GpuMatrixPtr a, - const GpuMatrixPtr b, - real scaleAB, - real scaleT); + void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT); /// B = A , B.trans = !A.trans MatrixPtr getTranspose(); @@ -218,7 +215,7 @@ protected: void copyRow(int offsets, size_t colNum, const sparse_float_value_t* row); public: - void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); + void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT); void copyFrom(CpuSparseMatrix& src, hl_stream_t stream); void copyFrom(GpuSparseMatrix& src, hl_stream_t stream); diff --git a/paddle/math/tests/test_SparseMatrix.cpp b/paddle/math/tests/test_SparseMatrix.cpp index 88b75b6d83..0949ab7ffb 100644 --- a/paddle/math/tests/test_SparseMatrix.cpp +++ b/paddle/math/tests/test_SparseMatrix.cpp @@ -33,8 +33,8 @@ TEST(Matrix, CopyCpuMatrixToSparseMatrix) { ret2(new CpuMatrix(HEIGHT, WIDTH_TEST)); ret1->zeroMem(); ret2->zeroMem(); - ret1->mul(testMatrix, mulCpuMatrix, 1.0, 1.0); - ret2->mul(testCpuMatrix, mulCpuMatrix, 1.0, 1.0); + 
ret1->mul(*testMatrix, *mulCpuMatrix, 1.0, 1.0); + ret2->mul(*testCpuMatrix, *mulCpuMatrix, 1.0, 1.0); checkMatrixEqual(ret1, ret2); } @@ -147,9 +147,9 @@ void test_sparse_matrix_mul(MatrixPara paraA, hl_stream_synchronize(stream); /*matrix mul*/ - cpuMatrixC->mul(cpuMatrixA, cpuMatrixB, 1.0, 1.0); - gpuMatrixC->mul(gpuMatrixA, gpuMatrixB, 1.0, 1.0); - cpuDenseC->mul(cpuDenseA, cpuDenseB, 1.0, 1.0); + cpuMatrixC->mul(*cpuMatrixA, *cpuMatrixB, 1.0, 1.0); + gpuMatrixC->mul(*gpuMatrixA, *gpuMatrixB, 1.0, 1.0); + cpuDenseC->mul(*cpuDenseA, *cpuDenseB, 1.0, 1.0); gpuMatrixC_d2h->copyFrom(*gpuMatrixC, stream); hl_stream_synchronize(stream); @@ -224,8 +224,8 @@ TEST(Matrix, CopySparseMatrixToGpuSparseMatrix) { MatrixPtr ret2(new GpuMatrix(HEIGHT, WIDTH_TEST)); ret1->zeroMem(); ret2->zeroMem(); - ret1->mul(testMatrix, mulCpuMatrix, 1.0, 1.0); - ret2->mul(testGpuMatrix, mulGpuMatrix, 1.0, 1.0); + ret1->mul(*testMatrix, *mulCpuMatrix, 1.0, 1.0); + ret2->mul(*testGpuMatrix, *mulGpuMatrix, 1.0, 1.0); checkMatrixEqual(ret1, ret2); } diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 10289940a4..c6fc849ba0 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -318,7 +318,7 @@ void testMatrixInverse(int height) { cpu->randomizeUniform(); MatrixPtr cpuT = cpu->getTranspose(); MatrixPtr outputCheck = std::make_shared(height, height); - outputCheck->mul(cpu, cpuT); + outputCheck->mul(*cpu, *cpuT); cpu->setDiag(1.0); cpu->add(*outputCheck); @@ -328,7 +328,7 @@ void testMatrixInverse(int height) { TensorCheckErr(*cpuI, *gpuI); - outputCheck->mul(cpu, cpuI); + outputCheck->mul(*cpu, *cpuI); cpu->setDiag(1.0); TensorCheckErr(*cpu, *outputCheck); } @@ -509,8 +509,8 @@ void testMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) { gpuB->copyFrom(*cpuB); gpuC->copyFrom(*cpuC); - cpuC->mul(cpuA, cpuB, alpha, beta); - gpuC->mul(gpuA, gpuB, alpha, beta); + cpuC->mul(*cpuA, 
*cpuB, alpha, beta); + gpuC->mul(*gpuA, *gpuB, alpha, beta); TensorCheckErr(*cpuC, *gpuC); } @@ -581,8 +581,8 @@ void testSubMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) { MatrixPtr subCpuC = cpuC->subMatrix(startM, endM, startN, endN); MatrixPtr subGpuC = gpuC->subMatrix(startM, endM, startN, endN); - subCpuC->mul(subCpuA, subCpuB, alpha, beta); - subGpuC->mul(subGpuA, subGpuB, alpha, beta); + subCpuC->mul(*subCpuA, *subCpuB, alpha, beta); + subGpuC->mul(*subGpuA, *subGpuB, alpha, beta); TensorCheckErr(*cpuC, *gpuC); } diff --git a/paddle/math/tests/test_sparseMatrixCompare.cpp b/paddle/math/tests/test_sparseMatrixCompare.cpp index 6f6de238ba..dcdbccffc3 100644 --- a/paddle/math/tests/test_sparseMatrixCompare.cpp +++ b/paddle/math/tests/test_sparseMatrixCompare.cpp @@ -102,8 +102,8 @@ void testSpMatrixMul(int M, int N, int K, real rate) { gpuC->copyFrom(*cpuC, stream); hl_stream_synchronize(stream); - cpuC->mul(cpuA, cpuB->getTranspose(), 1, 1); - gpuC->mul(gpuA, gpuB->getTranspose(), 1, 1); + cpuC->mul(*cpuA, *cpuB->getTranspose(), 1, 1); + gpuC->mul(*gpuA, *gpuB->getTranspose(), 1, 1); MatrixPtr outputCheck(new CpuSparseMatrix(M, N, nnz)); outputCheck->copyFrom(*gpuC, stream); From bf26679c3214f2c0c24f02218d3c15e720557a38 Mon Sep 17 00:00:00 2001 From: yangwenbo02 Date: Tue, 20 Dec 2016 13:51:56 +0800 Subject: [PATCH 21/26] update docker_install_en.rst --- doc/getstarted/build_and_install/docker_install_en.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 1cc23ac3aa..57725c0d85 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -44,8 +44,7 @@ The general development workflow with Docker and Bazel is as follows: cd paddle docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile . 
- Apt-get source errors may occur when building paddle docker image. - **You can specify the UBUNTU MIRROR with** :code:`--build-arg UBUNTU_MIRROR` **like the example below.** + Sometimes docker build might suffer from a slow network connection to the official Ubuntu apt-source servers. In such a case, we can specify an apt-source mirror server that is geographically nearer to us. In the following example, we specified an apt-source server that responds fast in China. You can specify the UBUNTU MIRROR with :code:`--build-arg UBUNTU_MIRROR` like the example below. .. code-block:: bash From 8fe3a3aa73be9b7d1f748e3809dff9f5323be719 Mon Sep 17 00:00:00 2001 From: Peng Li Date: Tue, 20 Dec 2016 16:46:42 +0800 Subject: [PATCH 22/26] Add excluded_chunk_types to ChunkEvaluator The chunks of types in excluded_chunk_types will not be counted in ChunkEvaluator. This is useful for tasks such as SRL, in which chunks of type V (verb) will not be taken into account in evaluation. --- paddle/gserver/evaluators/ChunkEvaluator.cpp | 17 ++++++-- proto/ModelConfig.proto | 10 ++++- python/paddle/trainer/config_parser.py | 6 ++- .../trainer_config_helpers/evaluators.py | 39 +++++++++++-------- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 3d8af5bcd4..15e0e95206 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
*/ +#include #include #include "paddle/math/Vector.h" @@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator { std::vector labelSegments_; std::vector outputSegments_; + std::set excludedChunkTypes_; public: virtual void init(const EvaluatorConfig& config) { @@ -105,6 +107,10 @@ public: } CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config"; otherChunkType_ = numChunkTypes_ = config.num_chunk_types(); + + // the chunks of types in excludedChunkTypes_ will not be counted + auto& tmp = config.excluded_chunk_types(); + excludedChunkTypes_.insert(tmp.begin(), tmp.end()); } virtual void start() { @@ -157,7 +163,8 @@ public: size_t i = 0, j = 0; while (i < outputSegments_.size() && j < labelSegments_.size()) { if (outputSegments_[i] == labelSegments_[j]) { - ++numCorrect_; + if (excludedChunkTypes_.count(outputSegments_[i].type) != 1) + ++numCorrect_; } if (outputSegments_[i].end < labelSegments_[j].end) { ++i; @@ -168,8 +175,12 @@ public: ++j; } } - numLabelSegments_ += labelSegments_.size(); - numOutputSegments_ += outputSegments_.size(); + for (auto& segment : labelSegments_) { + if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_; + } + for (auto& segment : outputSegments_) { + if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_; + } } void getSegments(int* label, int length, std::vector& segments) { diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 552af71e76..e24ed21fbb 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -433,8 +433,12 @@ message EvaluatorConfig { repeated string input_layers = 3; // Used by ChunkEvaluator - optional string chunk_scheme = 4; // one of "IOB", "IOE", "IOBES" - optional int32 num_chunk_types = 5; // number of chunk types other than "other" + // one of "IOB", "IOE", "IOBES" + optional string chunk_scheme = 4; + // number of chunk types other than "other" + optional int32 num_chunk_types = 5; + // chunk of these types are not counted + 
repeated int32 excluded_chunk_types = 12; // Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator // For multi binary labels: true if output > classification_threshold @@ -453,6 +457,8 @@ message EvaluatorConfig { // whether to delimit the sequence in the seq_text_printer optional bool delimited = 11 [default = true]; + + // NOTE: 12 has been occupied by excluded_chunk_types } message LinkConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index ea3e4308fe..39892d0533 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1240,7 +1240,8 @@ def Evaluator( dict_file=None, result_file=None, num_results=None, - delimited=None, ): + delimited=None, + excluded_chunk_types=None, ): evaluator = g_config.model_config.evaluators.add() evaluator.type = type evaluator.name = MakeLayerNameInSubmodel(name) @@ -1269,6 +1270,9 @@ def Evaluator( if delimited is not None: evaluator.delimited = delimited + if excluded_chunk_types: + evaluator.excluded_chunk_types.extend(excluded_chunk_types) + class LayerBase(object): def __init__( diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index 3e0e88972c..731e30d367 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -57,19 +57,21 @@ def evaluator(*attrs): return impl -def evaluator_base(input, - type, - label=None, - weight=None, - name=None, - chunk_scheme=None, - num_chunk_types=None, - classification_threshold=None, - positive_label=None, - dict_file=None, - result_file=None, - num_results=None, - delimited=None): +def evaluator_base( + input, + type, + label=None, + weight=None, + name=None, + chunk_scheme=None, + num_chunk_types=None, + classification_threshold=None, + positive_label=None, + dict_file=None, + result_file=None, + num_results=None, + delimited=None, + 
excluded_chunk_types=None, ): """ Evaluator will evaluate the network status while training/testing. @@ -127,7 +129,8 @@ def evaluator_base(input, positive_label=positive_label, dict_file=dict_file, result_file=result_file, - delimited=delimited) + delimited=delimited, + excluded_chunk_types=excluded_chunk_types, ) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) @@ -330,7 +333,8 @@ def chunk_evaluator( label, chunk_scheme, num_chunk_types, - name=None, ): + name=None, + excluded_chunk_types=None, ): """ Chunk evaluator is used to evaluate segment labelling accuracy for a sequence. It calculates the chunk detection F1 score. @@ -376,6 +380,8 @@ def chunk_evaluator( :param num_chunk_types: number of chunk types other than "other" :param name: The Evaluator name, it is optional. :type name: basename|None + :param excluded_chunk_types: chunks of these types are not considered + :type excluded_chunk_types: list of integer|[] """ evaluator_base( name=name, @@ -383,7 +389,8 @@ def chunk_evaluator( input=input, label=label, chunk_scheme=chunk_scheme, - num_chunk_types=num_chunk_types) + num_chunk_types=num_chunk_types, + excluded_chunk_types=excluded_chunk_types, ) @evaluator(EvaluatorAttribute.FOR_UTILS) From 5fddd99e18f3920ff0d8158fd4a9800d5566943e Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 20 Dec 2016 17:20:22 +0800 Subject: [PATCH 23/26] move TEST from test_matrixCompare.cpp to cross_map_normal_op_test.cpp --- cmake/util.cmake | 1 + paddle/function/CMakeLists.txt | 35 +++-- paddle/function/FunctionTest.h | 102 +++++++++++++ paddle/function/TestMain.cpp | 22 +++ paddle/function/cross_map_normal_op_test.cpp | 71 +++++++++ paddle/math/tests/test_matrixCompare.cpp | 144 ------------------- 6 files changed, 221 insertions(+), 154 deletions(-) create mode 100644 paddle/function/FunctionTest.h create mode 100644 paddle/function/TestMain.cpp create mode 100644 paddle/function/cross_map_normal_op_test.cpp diff --git a/cmake/util.cmake b/cmake/util.cmake index 
03734e7839..8a71b23c62 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -107,6 +107,7 @@ function(link_paddle_exe TARGET_NAME) paddle_parameter paddle_proto paddle_cuda + paddle_test_main ${METRIC_LIBS} ${PROTOBUF_LIBRARY} ${LIBGLOG_LIBRARY} diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 8fad0e3ebd..0697842bbe 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -1,12 +1,27 @@ -file(GLOB FUNCTION_HEADERS . *.h) - -if(NOT WITH_GPU) - file(GLOB FUNCTION_SOURCES . *.cpp) - add_library(paddle_function STATIC ${FUNCTION_SOURCES}) -else() - file(GLOB FUNCTION_SOURCES . *.cpp *.cu) - cuda_add_library(paddle_function ${FUNCTION_SOURCES}) +file(GLOB h_files . *_op.h) +file(GLOB cpp_files . *_op.cpp) + +list(APPEND h_files Function.h) +list(APPEND cpp_files Function.cpp) + +if(WITH_GPU) + file(GLOB cu_files . *_op_gpu.cu) + cuda_compile(cu_objs ${cu_files}) endif() -add_style_check_target(paddle_function ${FUNCTION_SOURCES}) -add_style_check_target(paddle_function ${FUNCTION_HEADERS}) +add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) + +add_library(paddle_test_main STATIC TestMain.cpp) + +if(WITH_GPU) + # TODO: + # file(GLOB test_files . *_op_test.cpp) + # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files}) + add_simple_unittest(cross_map_normal_op_test) +endif() + +add_style_check_target(paddle_function ${h_files}) +add_style_check_target(paddle_function ${cpp_files}) +if(WITH_GPU) + add_style_check_target(paddle_function ${cu_files}) +endif() diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h new file mode 100644 index 0000000000..a8c5e412bd --- /dev/null +++ b/paddle/function/FunctionTest.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Function.h" +#include "paddle/math/Vector.h" +#include "paddle/math/tests/TensorCheck.h" + +namespace paddle { + +class FunctionCompare { +public: + FunctionCompare(const std::string& name, const FuncConfig& config) + : cpu(FunctionBase::funcRegistrar_.createByType(name + "-CPU")), + gpu(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) { + cpu->init(config); + gpu->init(config); + } + + void cmpWithArg(const Arguments& inputs, + const Arguments& outputs, + const Arguments& inouts) { + // init cpu and gpu arguments + auto initArgs = [=]( + Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) { + for (auto arg : inArgs) { + size_t size = sizeof(real); + for (auto dim : arg.dims_) { + size *= dim; + } + cpuMemory.emplace_back(std::make_shared(size)); + gpuMemory.emplace_back(std::make_shared(size)); + cpuArgs.emplace_back( + Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_)); + gpuArgs.emplace_back( + Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_)); + + // will use an api to refactor this code. 
+ CpuVector cpuVector(size / sizeof(real), + (real*)cpuArgs.back().getData()); + GpuVector gpuVector(size / sizeof(real), + (real*)gpuArgs.back().getData()); + cpuVector.uniform(0.001, 1); + gpuVector.copyFrom(cpuVector); + } + }; + initArgs(cpuInputs, gpuInputs, inputs); + initArgs(cpuOutputs, gpuOutputs, outputs); + initArgs(cpuInouts, gpuInouts, inouts); + + // function calculate + cpu->calc(cpuInputs, cpuOutputs, cpuInouts); + gpu->calc(gpuInputs, gpuOutputs, gpuInouts); + + // check outputs and inouts + auto checkArgs = [=](const Arguments& cpuArgs, const Arguments& gpuArgs) { + for (size_t i = 0; i < cpuArgs.size(); i++) { + auto cpu = cpuArgs[i]; + auto gpu = gpuArgs[i]; + size_t size = 1; + for (auto dim : cpu.dims_) { + size *= dim; + } + CpuVector cpuVector(size, (real*)cpu.getData()); + GpuVector gpuVector(size, (real*)gpu.getData()); + + autotest::TensorCheckErr(cpuVector, gpuVector); + } + }; + checkArgs(cpuOutputs, gpuOutputs); + checkArgs(cpuInouts, gpuInouts); + } + +protected: + std::shared_ptr cpu; + std::shared_ptr gpu; + std::vector cpuMemory; + std::vector gpuMemory; + Arguments cpuInputs; + Arguments cpuOutputs; + Arguments cpuInouts; + Arguments gpuInputs; + Arguments gpuOutputs; + Arguments gpuInouts; +}; + +} // namespace paddle + +using paddle::FunctionCompare; +using paddle::FuncConfig; +using paddle::Dims; +using paddle::Tensor; diff --git a/paddle/function/TestMain.cpp b/paddle/function/TestMain.cpp new file mode 100644 index 0000000000..3e14532d18 --- /dev/null +++ b/paddle/function/TestMain.cpp @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/utils/Util.h" + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + paddle::initMain(argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/paddle/function/cross_map_normal_op_test.cpp b/paddle/function/cross_map_normal_op_test.cpp new file mode 100644 index 0000000000..22692691bd --- /dev/null +++ b/paddle/function/cross_map_normal_op_test.cpp @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "FunctionTest.h" + +TEST(CrossMapNormal, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + for (size_t size : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " size=" << size; + + FunctionCompare compare("CrossMapNormal", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); + Dims dims{numSamples, channels, imgSizeH, imgSizeW}; + compare.cmpWithArg({Tensor(nullptr, dims)}, + {Tensor(nullptr, dims), Tensor(nullptr, dims)}, + {}); + } + } + } + } + } +} + +TEST(CrossMapNormalGrad, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {1, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + for (size_t size : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " size=" << size; + + FunctionCompare compare("CrossMapNormalGrad", + FuncConfig() + .set("size", size) + .set("scale", (real)1.5) + .set("pow", (real)0.5)); + Dims dims{numSamples, channels, imgSizeH, imgSizeW}; + compare.cmpWithArg({Tensor(nullptr, dims), + Tensor(nullptr, dims), + Tensor(nullptr, dims), + Tensor(nullptr, dims)}, + {Tensor(nullptr, dims)}, + {}); + } + } + } + } + } +} diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index c89b7ff490..440534e722 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1263,150 +1263,6 @@ TEST(Matrix, MaxOutFwdBwd) { } } -void testCrossMapNormalFwd( - int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { - float scale = 1.5; - float pow = 0.5; - int width = imgSizeH * imgSizeW * channels; - CpuMatrix 
inputs(numSamples, width); - CpuMatrix denoms(numSamples, width); - CpuMatrix outputs(numSamples, width); - GpuMatrix inputsGpu(numSamples, width); - GpuMatrix denomsGpu(numSamples, width); - GpuMatrix outputsGpu(numSamples, width); - - inputs.randomizeUniform(); - outputs.randomizeUniform(); - inputsGpu.copyFrom(inputs); - outputsGpu.copyFrom(outputs); - - FunctionBase* cpu = - FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, CPU)); - FunctionBase* gpu = - FunctionBase::funcRegistrar_.createByType(FUNC_NAME(CrossMapNormal, GPU)); - cpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - gpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - - Dims dims{ - (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; - cpu->calc({Tensor(inputs.getData(), dims)}, - {Tensor(outputs.getData(), dims), Tensor(denoms.getData(), dims)}, - {}); - - gpu->calc( - {Tensor(inputsGpu.getData(), dims)}, - {Tensor(outputsGpu.getData(), dims), Tensor(denomsGpu.getData(), dims)}, - {}); - - TensorCheckErr(outputs, outputsGpu); - TensorCheckErr(denoms, denomsGpu); -} - -TEST(Matrix, crossMapNormalFwd) { - for (auto numSamples : {5, 32}) { - for (auto channels : {1, 5, 32}) { - for (auto imgSizeH : {5, 33, 100}) { - for (auto imgSizeW : {5, 32, 96}) { - for (auto sizeX : {1, 2, 3, 5, 7}) { - VLOG(3) << " numSamples=" << numSamples << " channels=" << channels - << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW - << " sizeX=" << sizeX; - testCrossMapNormalFwd( - numSamples, channels, imgSizeH, imgSizeW, sizeX); - } - } - } - } - } -} - -void testCrossMapNormalBwd( - int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { - float scale = 1.5; - float pow = 0.5; - size_t width = imgSizeH * imgSizeW * channels; - - CpuMatrix inputsGrad(numSamples, width); - CpuMatrix inputsValue(numSamples, width); - CpuMatrix outputsGrad(numSamples, width); - 
CpuMatrix outputsValue(numSamples, width); - CpuMatrix denoms(numSamples, width); - - outputsGrad.randomizeUniform(); - denoms.randomizeUniform(); - inputsValue.randomizeUniform(); - outputsValue.randomizeUniform(); - inputsGrad.randomizeUniform(); - denoms.add(0.01); - - GpuMatrix inputsGradGpu(numSamples, width); - GpuMatrix inputsValueGpu(numSamples, width); - GpuMatrix outputsGradGpu(numSamples, width); - GpuMatrix outputsValueGpu(numSamples, width); - GpuMatrix denomsGpu(numSamples, width); - - outputsGradGpu.copyFrom(outputsGrad); - denomsGpu.copyFrom(denoms); - inputsValueGpu.copyFrom(inputsValue); - outputsValueGpu.copyFrom(outputsValue); - inputsGradGpu.copyFrom(inputsGrad); - - FunctionBase* cpu = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, CPU)); - FunctionBase* gpu = FunctionBase::funcRegistrar_.createByType( - FUNC_NAME(CrossMapNormalGrad, GPU)); - cpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - gpu->init(FuncConfig() - .set("size", (size_t)sizeX) - .set("scale", scale) - .set("pow", pow)); - - Dims dims{ - (size_t)numSamples, (size_t)channels, (size_t)imgSizeH, (size_t)imgSizeW}; - cpu->calc({Tensor(inputsValue.getData(), dims), - Tensor(outputsValue.getData(), dims), - Tensor(outputsGrad.getData(), dims), - Tensor(denoms.getData(), dims)}, - {Tensor(inputsGrad.getData(), dims)}, - {}); - - gpu->calc({Tensor(inputsValueGpu.getData(), dims), - Tensor(outputsValueGpu.getData(), dims), - Tensor(outputsGradGpu.getData(), dims), - Tensor(denomsGpu.getData(), dims)}, - {Tensor(inputsGradGpu.getData(), dims)}, - {}); - - TensorCheckErr(inputsGrad, inputsGradGpu); -} - -TEST(Matrix, crossMapNormalBwd) { - for (auto numSamples : {5, 32}) { - for (auto channels : {1, 5, 32}) { - for (auto imgSizeH : {5, 33, 100}) { - for (auto imgSizeW : {5, 32, 96}) { - for (auto sizeX : {1, 2, 3, 5, 7}) { - VLOG(3) << " numSamples=" << numSamples << " channels=" << channels - << " imgSizeH=" << 
imgSizeH << " imgSizeW=" << imgSizeW - << " sizeX=" << sizeX; - testCrossMapNormalBwd( - numSamples, channels, imgSizeH, imgSizeW, sizeX); - } - } - } - } - } -} - int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); From 6e405a10c54fd8f5695832663668a51f4ed19c2b Mon Sep 17 00:00:00 2001 From: Peng Li Date: Tue, 20 Dec 2016 17:35:37 +0800 Subject: [PATCH 24/26] fix style issues --- paddle/gserver/evaluators/ChunkEvaluator.cpp | 6 +++--- proto/ModelConfig.proto | 6 +++--- python/paddle/trainer_config_helpers/evaluators.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 15e0e95206..13f02e51fe 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -162,9 +162,9 @@ public: getSegments(label, length, labelSegments_); size_t i = 0, j = 0; while (i < outputSegments_.size() && j < labelSegments_.size()) { - if (outputSegments_[i] == labelSegments_[j]) { - if (excludedChunkTypes_.count(outputSegments_[i].type) != 1) - ++numCorrect_; + if (outputSegments_[i] == labelSegments_[j] && + excludedChunkTypes_.count(outputSegments_[i].type) != 1) { + ++numCorrect_; } if (outputSegments_[i].end < labelSegments_[j].end) { ++i; diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index e24ed21fbb..be4d0041f9 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -437,8 +437,6 @@ message EvaluatorConfig { optional string chunk_scheme = 4; // number of chunk types other than "other" optional int32 num_chunk_types = 5; - // chunk of these types are not counted - repeated int32 excluded_chunk_types = 12; // Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator // For multi binary labels: true if output > classification_threshold @@ -458,7 +456,9 @@ message EvaluatorConfig { // whether to delimit the sequence in the seq_text_printer optional 
bool delimited = 11 [default = true]; - // NOTE: 12 has been occupied by excluded_chunk_types + // Used by ChunkEvaluator + // chunk of these types are not counted + repeated int32 excluded_chunk_types = 12; } message LinkConfig { diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index 731e30d367..bd247ea9af 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -381,7 +381,7 @@ def chunk_evaluator( :param name: The Evaluator name, it is optional. :type name: basename|None :param excluded_chunk_types: chunks of these types are not considered - :type excluded_chunk_types: list of integer|[] + :type excluded_chunk_types: list of integer|None """ evaluator_base( name=name, From 0d1703d91ff02bb1ba51d164db8f29a4b9ed161c Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Dec 2016 17:41:41 +0800 Subject: [PATCH 25/26] Add const in ParameterUpdater init --- paddle/parameter/ParameterUpdaterBase.cpp | 2 +- paddle/parameter/ParameterUpdaterBase.h | 4 ++-- paddle/trainer/ParameterUpdater.cpp | 3 ++- paddle/trainer/ParameterUpdater.h | 4 ++-- paddle/trainer/RemoteParameterUpdater.cpp | 7 ++++--- paddle/trainer/RemoteParameterUpdater.h | 6 +++--- paddle/trainer/ThreadParameterUpdater.cpp | 2 +- paddle/trainer/ThreadParameterUpdater.h | 2 +- 8 files changed, 16 insertions(+), 14 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterBase.cpp b/paddle/parameter/ParameterUpdaterBase.cpp index 49e2ae2b39..458cae886a 100644 --- a/paddle/parameter/ParameterUpdaterBase.cpp +++ b/paddle/parameter/ParameterUpdaterBase.cpp @@ -19,7 +19,7 @@ limitations under the License. 
*/ namespace paddle { -void ParameterUpdater::init(std::vector& parameters) { +void ParameterUpdater::init(const std::vector& parameters) { parameters_ = parameters; for (ParameterType type : getParameterTypes()) { for (auto& para : parameters) { diff --git a/paddle/parameter/ParameterUpdaterBase.h b/paddle/parameter/ParameterUpdaterBase.h index 5401046f67..88148d9b76 100644 --- a/paddle/parameter/ParameterUpdaterBase.h +++ b/paddle/parameter/ParameterUpdaterBase.h @@ -32,7 +32,7 @@ public: parameterTypes_.push_back(type); } - virtual void init(std::vector& parameters); + virtual void init(const std::vector& parameters); // called by Trainer when starting a new pass virtual void startPass() {} @@ -105,7 +105,7 @@ public: ParameterUpdaterComposite() {} virtual ~ParameterUpdaterComposite() {} - virtual void init(std::vector& parameters) = 0; + virtual void init(const std::vector& parameters) = 0; virtual void startPass() { syncThreadPool_->execPlusOwner( diff --git a/paddle/trainer/ParameterUpdater.cpp b/paddle/trainer/ParameterUpdater.cpp index 8b5b95da5b..4e9e890c85 100644 --- a/paddle/trainer/ParameterUpdater.cpp +++ b/paddle/trainer/ParameterUpdater.cpp @@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager( updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); }); } -void SgdUpdaterWithCpuAverager::init(std::vector& parameters) { +void SgdUpdaterWithCpuAverager::init( + const std::vector& parameters) { SgdLocalUpdater::init(parameters); averager_->init(parameters_.size(), nullptr); copyEvents_.resize(parameters_.size()); diff --git a/paddle/trainer/ParameterUpdater.h b/paddle/trainer/ParameterUpdater.h index e52b5cd318..4dae77567f 100644 --- a/paddle/trainer/ParameterUpdater.h +++ b/paddle/trainer/ParameterUpdater.h @@ -64,7 +64,7 @@ public: * be initialized. * @param parameters The parameter need to be initialized. 
*/ - virtual void init(std::vector& parameters) { + virtual void init(const std::vector& parameters) { ParameterUpdater::init(parameters); optimizer_->init(parameters_.size(), nullptr); // check no L1 decay in parameter configs @@ -208,7 +208,7 @@ public: * @brief init. Initialize cpu parameters, model average optimizer. * @param parameters */ - virtual void init(std::vector& parameters); + virtual void init(const std::vector& parameters); virtual PassType startBatch(int64_t batchSize) { averager_->startBatch(-1UL); diff --git a/paddle/trainer/RemoteParameterUpdater.cpp b/paddle/trainer/RemoteParameterUpdater.cpp index 974e78fa17..630f55d998 100644 --- a/paddle/trainer/RemoteParameterUpdater.cpp +++ b/paddle/trainer/RemoteParameterUpdater.cpp @@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater( addParameterType(PARAMETER_MOMENTUM); } -void RemoteParameterUpdater::init(std::vector& parameters) { +void RemoteParameterUpdater::init(const std::vector& parameters) { ParameterUpdater::init(parameters); if (localUpdater_) { @@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater( testing_(testing), useApplyInPserver_(false) {} -void SparseRemoteParameterUpdater::init(std::vector& parameters) { +void SparseRemoteParameterUpdater::init( + const std::vector& parameters) { ParameterUpdater::init(parameters); parameterClient_.reset(new ParameterClient2( @@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote( } void SparseRemoteParameterUpdaterComposite::init( - std::vector& parameters) { + const std::vector& parameters) { parameters_ = parameters; std::vector parametersArray[NUMBER_UPDATERS]; diff --git a/paddle/trainer/RemoteParameterUpdater.h b/paddle/trainer/RemoteParameterUpdater.h index 66055c778e..ec6ed443d3 100644 --- a/paddle/trainer/RemoteParameterUpdater.h +++ b/paddle/trainer/RemoteParameterUpdater.h @@ -67,7 +67,7 @@ public: /** * initialize the internal parameter client and itself. 
*/ - virtual void init(std::vector& parameters); + virtual void init(const std::vector& parameters); /** * @brief start batch * @@ -274,7 +274,7 @@ public: } /// initialization - virtual void init(std::vector& parameters); + virtual void init(const std::vector& parameters); /// stateful batch control virtual PassType startBatch(int64_t batchSize); @@ -360,7 +360,7 @@ public: } /// initialization of dense and sparse updaters - virtual void init(std::vector& parameters); + virtual void init(const std::vector& parameters); }; class ParameterUpdaterCreators { diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp index 049022b1f1..2a76d5723c 100644 --- a/paddle/trainer/ThreadParameterUpdater.cpp +++ b/paddle/trainer/ThreadParameterUpdater.cpp @@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig) } } -void SgdThreadUpdater::init(std::vector& parameters) { +void SgdThreadUpdater::init(const std::vector& parameters) { ParameterUpdater::init(parameters); // calc max parameter id diff --git a/paddle/trainer/ThreadParameterUpdater.h b/paddle/trainer/ThreadParameterUpdater.h index d01ac689f9..198435c0f3 100644 --- a/paddle/trainer/ThreadParameterUpdater.h +++ b/paddle/trainer/ThreadParameterUpdater.h @@ -49,7 +49,7 @@ public: // Use the finishPass() function of the base optimizer. virtual bool finishPass(real cost); - virtual void init(std::vector& parameters); + virtual void init(const std::vector& parameters); virtual PassType startBatch(int64_t batchSize); // Call finishBatch for each optimizer. 
virtual void finishBatch(real cost); From f1a94e3ff7fce800f6c846da2ae6ad4312c4acfc Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 20 Dec 2016 20:30:13 +0800 Subject: [PATCH 26/26] follow comments --- paddle/function/cross_map_normal_op.cpp | 22 +++++++++++----------- paddle/function/cross_map_normal_op.h | 10 +++++----- paddle/function/cross_map_normal_op_gpu.cu | 10 +++++----- paddle/math/tests/test_matrixCompare.cpp | 1 - 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/paddle/function/cross_map_normal_op.cpp b/paddle/function/cross_map_normal_op.cpp index a18c0bb750..a9c7693830 100644 --- a/paddle/function/cross_map_normal_op.cpp +++ b/paddle/function/cross_map_normal_op.cpp @@ -20,7 +20,7 @@ namespace paddle { template <> void CrossMapNormal(real* outputs, real* denoms, - real* inputs, + const real* inputs, size_t numSamples, size_t channels, size_t height, @@ -32,7 +32,7 @@ void CrossMapNormal(real* outputs, size_t oneSample = channels * oneImage; CpuVector outputsV(numSamples * oneSample, outputs); - CpuVector inputsV(numSamples * oneSample, inputs); + CpuVector inputsV(numSamples * oneSample, const_cast(inputs)); CpuVector denomsV(numSamples * oneSample, denoms); // f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow) @@ -44,7 +44,7 @@ void CrossMapNormal(real* outputs, const int end = (int)size + start; for (size_t i = 0; i < numSamples; i++) { real* oneDenom = denoms + i * oneSample; - real* oneInput = inputs + i * oneSample; + real* oneInput = const_cast(inputs) + i * oneSample; for (int c = 0; c < (int)channels; c++) { CpuVector denom(oneImage, oneDenom + c * oneImage); for (int s = start; s < end; s++) { @@ -61,10 +61,10 @@ void CrossMapNormal(real* outputs, template <> void CrossMapNormalGrad(real* inputsGrad, - real* inputsValue, - real* outputsValue, - real* outputsGrad, - real* denoms, + const real* inputsValue, + const real* outputsValue, + const real* outputsGrad, + const real* denoms, size_t numSamples, size_t channels, size_t height, 
@@ -84,10 +84,10 @@ void CrossMapNormalGrad(real* inputsGrad, for (size_t i = 0; i < numSamples; i++) { size_t sOffset = i * oneSample; real* oneInputGrad = inputsGrad + sOffset; - real* oneInputValue = inputsValue + sOffset; - real* oneDenom = denoms + sOffset; - real* oneOutputGrad = outputsGrad + sOffset; - real* oneOutputValue = outputsValue + sOffset; + real* oneInputValue = const_cast(inputsValue) + sOffset; + real* oneDenom = const_cast(denoms) + sOffset; + real* oneOutputGrad = const_cast(outputsGrad) + sOffset; + real* oneOutputValue = const_cast(outputsValue) + sOffset; for (int c = 0; c < (int)channels; c++) { size_t cOffset = c * height * width; diff --git a/paddle/function/cross_map_normal_op.h b/paddle/function/cross_map_normal_op.h index e935b26e12..b1e401ad0a 100644 --- a/paddle/function/cross_map_normal_op.h +++ b/paddle/function/cross_map_normal_op.h @@ -37,7 +37,7 @@ namespace paddle { template void CrossMapNormal(real* outputs, real* denoms, - real* inputs, + const real* inputs, size_t numSamples, size_t channels, size_t height, @@ -66,10 +66,10 @@ void CrossMapNormal(real* outputs, */ template void CrossMapNormalGrad(real* inputsGrad, - real* inputsValue, - real* outputsValue, - real* outputsGrad, - real* denoms, + const real* inputsValue, + const real* outputsValue, + const real* outputsGrad, + const real* denoms, size_t numSamples, size_t channels, size_t height, diff --git a/paddle/function/cross_map_normal_op_gpu.cu b/paddle/function/cross_map_normal_op_gpu.cu index 6339c04194..aae4f461b6 100644 --- a/paddle/function/cross_map_normal_op_gpu.cu +++ b/paddle/function/cross_map_normal_op_gpu.cu @@ -63,7 +63,7 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in, template <> void CrossMapNormal(real* outputs, real* denoms, - real* inputs, + const real* inputs, size_t numSamples, size_t channels, size_t height, @@ -132,10 +132,10 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, template <> void 
CrossMapNormalGrad(real* inputsGrad, - real* inputsValue, - real* outputsValue, - real* outputsGrad, - real* denoms, + const real* inputsValue, + const real* outputsValue, + const real* outputsGrad, + const real* denoms, size_t numSamples, size_t channels, size_t height, diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 440534e722..62de5b25e4 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "TensorCheck.h" -#include "paddle/function/Function.h" #include "paddle/gserver/tests/TestUtil.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h"