From db694172bed8ee621e468546e9bf4c4c42e92602 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 2 Nov 2017 10:28:27 +0800 Subject: [PATCH 01/30] Add edit distance operator --- paddle/operators/ctc_edit_distance_op.cc | 74 ++++++++++++++++++ paddle/operators/ctc_edit_distance_op.h | 78 +++++++++++++++++++ .../framework/tests/test_ctc_edit_distance.py | 60 ++++++++++++++ 3 files changed, 212 insertions(+) create mode 100644 paddle/operators/ctc_edit_distance_op.cc create mode 100644 paddle/operators/ctc_edit_distance_op.h create mode 100644 python/paddle/v2/framework/tests/test_ctc_edit_distance.py diff --git a/paddle/operators/ctc_edit_distance_op.cc b/paddle/operators/ctc_edit_distance_op.cc new file mode 100644 index 0000000000..7b45ccc72e --- /dev/null +++ b/paddle/operators/ctc_edit_distance_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/ctc_edit_distance_op.h" + +namespace paddle { +namespace operators { + +class CTCEditDistanceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + ctx->SetOutputDim("Out", {1}); + } +}; + +class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CTCEditDistanceOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X1", + "(2-D tensor with shape [M x 1]) The indices for " + "hypothesis string"); + AddInput("X2", + "(2-D tensor with shape [batch_size x 1]) The indices " + "for reference string."); + AddAttr("normalized", + "(bool, default false) Indicated whether " + "normalize. the Output(Out) by the length of reference " + "string (X2).") + .SetDefault(false); + AddOutput("Out", + "(2-D tensor with shape [1 x 1]) " + "The output distance of CTCEditDistance operator."); + AddComment(R"DOC( + +CTCEditDistance operator computes the edit distance of two sequences, one named +hypothesis and another named reference. + +Edit distance measures how dissimilar two strings, one is hypothesis and another +is reference, are by counting the minimum number of operations to transform +one string into anthor. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp, + ops::CTCEditDistanceOpMaker); +REGISTER_OP_CPU_KERNEL( + ctc_edit_distance, + ops::CTCEditDistanceKernel, + ops::CTCEditDistanceKernel); diff --git a/paddle/operators/ctc_edit_distance_op.h b/paddle/operators/ctc_edit_distance_op.h new file mode 100644 index 0000000000..d0494b4b1b --- /dev/null +++ b/paddle/operators/ctc_edit_distance_op.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class CTCEditDistanceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + + auto* x1_t = ctx.Input("X1"); + auto* x2_t = ctx.Input("X2"); + + out_t->mutable_data(ctx.GetPlace()); + + auto normalized = ctx.Attr("normalized"); + + auto m = x1_t->numel(); + auto n = x2_t->numel(); + float distance = 0.0; + if (m == 0) { + distance = n; + } else if (n == 0) { + distance = m; + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data(); + auto x2 = x2_t->data(); + for (int i = 0; i < m + 1; ++i) { + dist[i * (n + 1)] = i; // dist[i][0] = i; + } + for (int j = 0; j < n + 1; ++j) { + dist[j] = j; // dist[0][j] = j; + } + for (int i = 1; i < m + 1; ++i) { + for (int j = 1; j < n + 1; ++j) { + int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; + int deletions = dist[(i - 1) * (n + 1) + j] + 1; + int insertions = dist[i * (n + 1) + (j - 1)] + 1; + int substitutions = dist[(i - 1) * (n + 1) + (j - 1)] + cost; + dist[i * (n + 1) + j] = + std::min(deletions, std::min(insertions, substitutions)); + } + } + distance = dist[m * (n + 1) + n]; + } + + if (normalized) { + distance = distance / n; + } + auto out = out_t->data(); + out[0] = distance; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_ctc_edit_distance.py b/python/paddle/v2/framework/tests/test_ctc_edit_distance.py new file mode 100644 index 0000000000..a6d9dfdf06 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_ctc_edit_distance.py @@ -0,0 +1,60 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def Levenshtein(hyp, ref): + """ Compute the Levenshtein distance between two strings. + + :param hyp: + :type hyp: list + :param ref: + :type ref: list + """ + m = len(hyp) + n = len(ref) + if m == 0: + return n + if n == 0: + return m + + dist = np.zeros((m + 1, n + 1)) + for i in range(0, m + 1): + dist[i][0] = i + for j in range(0, n + 1): + dist[0][j] = j + + for i in range(1, m + 1): + for j in range(1, n + 1): + cost = 0 if hyp[i - 1] == ref[j - 1] else 1 + deletion = dist[i - 1][j] + 1 + insertion = dist[i][j - 1] + 1 + substitution = dist[i - 1][j - 1] + cost + dist[i][j] = min(deletion, insertion, substitution) + return dist[m][n] + + +class TestCTCEditDistanceOp(OpTest): + def setUp(self): + self.op_type = "ctc_edit_distance" + normalized = True + x1 = np.array([0, 12, 3, 5]).astype("int64") + x2 = np.array([0, 12, 4, 7, 8]).astype("int64") + + distance = Levenshtein(hyp=x1, ref=x2) + if normalized is True: + distance = distance / len(x2) + print "distance = ", distance + self.attrs = {'normalized': normalized} + self.inputs = {'X1': x1, 'X2': x2} + self.outputs = {'Out': distance} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + #x1 = ['c', 'a', 'f', 'e'] + #x2 = ['c', 'o', 'f', 'f', 'e', 'e'] + #print Levenshtein(x1, x2) + unittest.main() From b7a4e3d72c58a6ca39b6434888c8144558d0bdbb Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 2 Nov 2017 12:12:31 +0800 Subject: [PATCH 02/30] rename some variables in ctc_edit_distance_op --- paddle/operators/ctc_edit_distance_op.cc | 4 ++-- paddle/operators/ctc_edit_distance_op.h | 21 +++++++++---------- .../framework/tests/test_ctc_edit_distance.py | 8 ++----- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/paddle/operators/ctc_edit_distance_op.cc b/paddle/operators/ctc_edit_distance_op.cc index 7b45ccc72e..fae5cfc117 100644 --- a/paddle/operators/ctc_edit_distance_op.cc +++ b/paddle/operators/ctc_edit_distance_op.cc @@ -38,11 +38,11 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { "(2-D tensor with shape [M x 1]) The indices for " "hypothesis string"); AddInput("X2", - "(2-D tensor with shape [batch_size x 1]) The indices " + "(2-D tensor with shape [N x 1]) The indices " "for reference string."); AddAttr("normalized", "(bool, default false) Indicated whether " - "normalize. the Output(Out) by the length of reference " + "normalize the Output(Out) by the length of reference " "string (X2).") .SetDefault(false); AddOutput("Out", diff --git a/paddle/operators/ctc_edit_distance_op.h b/paddle/operators/ctc_edit_distance_op.h index d0494b4b1b..a52960f1ef 100644 --- a/paddle/operators/ctc_edit_distance_op.h +++ b/paddle/operators/ctc_edit_distance_op.h @@ -47,20 +47,19 @@ class CTCEditDistanceKernel : public framework::OpKernel { auto dist = dist_t.data(); auto x1 = x1_t->data(); auto x2 = x2_t->data(); - for (int i = 0; i < m + 1; ++i) { - dist[i * (n + 1)] = i; // dist[i][0] = i; + for (size_t i = 0; i < m + 1; ++i) { + dist[i * (n + 1)] = i; } - for (int j = 0; j < n + 1; ++j) { - dist[j] = j; // dist[0][j] = j; + for (size_t j = 0; j < n + 1; ++j) { + dist[j] = j; } - for (int i = 1; i < m + 1; ++i) { - for (int j = 1; j < n + 1; ++j) { + for (size_t i = 1; i < m + 1; ++i) { + for (size_t j = 1; j < n + 1; ++j) { int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; - int deletions = dist[(i - 1) * (n + 1) + j] + 1; - int insertions = dist[i * (n + 1) + (j - 1)] + 1; - int substitutions = dist[(i - 1) * (n + 1) + (j - 1)] + cost; - dist[i * (n + 1) + j] = - std::min(deletions, std::min(insertions, substitutions)); + int dels = dist[(i - 1) * (n + 1) + j] + 1; + int ins = dist[i * (n + 1) + (j - 1)] + 1; + int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; + dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); } } distance = dist[m * (n + 1) + n]; diff --git a/python/paddle/v2/framework/tests/test_ctc_edit_distance.py b/python/paddle/v2/framework/tests/test_ctc_edit_distance.py index a6d9dfdf06..8a6b0b5390 100644 --- a/python/paddle/v2/framework/tests/test_ctc_edit_distance.py +++ b/python/paddle/v2/framework/tests/test_ctc_edit_distance.py @@ -6,9 +6,9 @@ from op_test import OpTest def Levenshtein(hyp, ref): """ Compute the Levenshtein distance between two strings. - :param hyp: + :param hyp: hypothesis string in index :type hyp: list - :param ref: + :param ref: reference string in index :type ref: list """ m = len(hyp) @@ -44,7 +44,6 @@ class TestCTCEditDistanceOp(OpTest): distance = Levenshtein(hyp=x1, ref=x2) if normalized is True: distance = distance / len(x2) - print "distance = ", distance self.attrs = {'normalized': normalized} self.inputs = {'X1': x1, 'X2': x2} self.outputs = {'Out': distance} @@ -54,7 +53,4 @@ class TestCTCEditDistanceOp(OpTest): if __name__ == '__main__': - #x1 = ['c', 'a', 'f', 'e'] - #x2 = ['c', 'o', 'f', 'f', 'e', 'e'] - #print Levenshtein(x1, x2) unittest.main() From 6bc6ccd187b3a0dcb6980a9d0c5090f3e5d16150 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 29 Nov 2017 05:49:44 +0000 Subject: [PATCH 03/30] add gpu kernel for ctc_edit_distance_op --- paddle/operators/ctc_edit_distance_op.cu | 130 ++++++++++++++++++ .../tests/test_ctc_edit_distance_op.py} | 8 +- 2 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 paddle/operators/ctc_edit_distance_op.cu rename python/paddle/v2/{framework/tests/test_ctc_edit_distance.py => fluid/tests/test_ctc_edit_distance_op.py} (84%) diff --git a/paddle/operators/ctc_edit_distance_op.cu b/paddle/operators/ctc_edit_distance_op.cu new file mode 100644 index 0000000000..872268296e --- /dev/null +++ b/paddle/operators/ctc_edit_distance_op.cu @@ -0,0 +1,130 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/framework/op_registry.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void FillFirstRow(T* dist, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < N + 1) { + dist[idx] = idx; + } +} + +template +__global__ void FillFirstColumn(T* dist, const int M, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < M + 1) { + dist[idx * (N + 1)] = idx; + } +} + +template +__global__ void Levenshtein(T* dist, const T* x1, const T* x2, const int M, + const int N, const int start) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = N; + int index = start + idx * offset; + int row = index / (N + 1); + int col = index % (N + 1); + if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { + int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; + int dels = dist[(row - 1) * (N + 1) + col] + 1; + int ins = dist[row * (N + 1) + col - 1] + 1; + int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; + dist[index] = min(dels, min(ins, subs)); + } +} + +template +class CTCEditDistanceGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + + auto* x1_t = ctx.Input("X1"); + auto* x2_t = ctx.Input("X2"); + + out_t->mutable_data(ctx.GetPlace()); + + auto normalized = ctx.Attr("normalized"); + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + + auto m = x1_t->numel(); + auto n = x2_t->numel(); + T distance = 0; + if (m == 0) { + distance = n; + } else if (n == 0) { + distance = m; + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data(); + auto x2 = x2_t->data(); + + FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); + + FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); + // compute the elements of distance matrix in the anti-diagonal diretion + for (size_t slice = 2; slice < m + n + 1; ++slice) { + int z_m = slice < m + 1 ? 0 : slice - m; + int z_n = slice < n + 1 ? 0 : slice - n; + // number of elments in the same anti-diagonal line + int size = slice - (z_m + z_n) + 1; + int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; + Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, + n, start); + } + + Place gpu_place = boost::get(ctx.GetPlace()); + memory::Copy(platform::CPUPlace(), &distance, gpu_place, + dist + m * (n + 1) + n, sizeof(T), stream); + } + + if (normalized) { + distance = distance / n; + } + auto out = out_t->data(); + Place gpu_place = boost::get(ctx.GetPlace()); + float dist_f = distance; + memory::Copy(gpu_place, out, platform::CPUPlace(), &dist_f, sizeof(float), + stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + ctc_edit_distance, + ops::CTCEditDistanceGPUKernel, + ops::CTCEditDistanceGPUKernel); diff --git a/python/paddle/v2/framework/tests/test_ctc_edit_distance.py b/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py similarity index 84% rename from python/paddle/v2/framework/tests/test_ctc_edit_distance.py rename to python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py index 8a6b0b5390..6694a6ee29 100644 --- a/python/paddle/v2/framework/tests/test_ctc_edit_distance.py +++ b/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py @@ -37,9 +37,11 @@ def Levenshtein(hyp, ref): class TestCTCEditDistanceOp(OpTest): def setUp(self): self.op_type = "ctc_edit_distance" - normalized = True - x1 = np.array([0, 12, 3, 5]).astype("int64") - x2 = np.array([0, 12, 4, 7, 8]).astype("int64") + normalized = False + #x1 = np.array([0, 12, 3, 5]).astype("int64") + #x2 = np.array([0, 12, 4, 7, 8]).astype("int64") + x1 = np.array([0, 12, 5]).astype("int64") + x2 = np.array([0, 12, 4]).astype("int64") distance = Levenshtein(hyp=x1, ref=x2) if normalized is True: From 116687a8ee8dab5938f8783428b4b5f416a443f5 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 29 Nov 2017 09:07:15 +0000 Subject: [PATCH 04/30] clean up code in ctc_edit_distance_op --- paddle/operators/ctc_edit_distance_op.cc | 10 +++- paddle/operators/ctc_edit_distance_op.cu | 59 ++++++++++--------- paddle/operators/ctc_edit_distance_op.h | 16 ++--- .../fluid/tests/test_ctc_edit_distance_op.py | 8 +-- 4 files changed, 49 insertions(+), 44 deletions(-) diff --git a/paddle/operators/ctc_edit_distance_op.cc b/paddle/operators/ctc_edit_distance_op.cc index fae5cfc117..d2f4ce67c2 100644 --- a/paddle/operators/ctc_edit_distance_op.cc +++ b/paddle/operators/ctc_edit_distance_op.cc @@ -27,6 +27,13 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); ctx->SetOutputDim("Out", {1}); } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(framework::DataType::FP32, + ctx.device_context()); + } }; class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { @@ -70,5 +77,4 @@ REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp, ops::CTCEditDistanceOpMaker); REGISTER_OP_CPU_KERNEL( ctc_edit_distance, - ops::CTCEditDistanceKernel, - ops::CTCEditDistanceKernel); + ops::CTCEditDistanceKernel); diff --git a/paddle/operators/ctc_edit_distance_op.cu b/paddle/operators/ctc_edit_distance_op.cu index 872268296e..22871acc4e 100644 --- a/paddle/operators/ctc_edit_distance_op.cu +++ b/paddle/operators/ctc_edit_distance_op.cu @@ -39,7 +39,7 @@ __global__ void FillFirstColumn(T* dist, const int M, const int N) { } template -__global__ void Levenshtein(T* dist, const T* x1, const T* x2, const int M, +__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M, const int N, const int start) { int idx = blockDim.x * blockIdx.x + threadIdx.x; int offset = N; @@ -55,6 +55,15 @@ __global__ void Levenshtein(T* dist, const T* x1, const T* x2, const int M, } } +template +__global__ void SetOutput(T* out, const T* dist, const int M, const int N, + bool normalized) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx == 0) { + out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; + } +} + template class CTCEditDistanceGPUKernel : public framework::OpKernel { public: @@ -64,7 +73,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel { auto* x1_t = ctx.Input("X1"); auto* x2_t = ctx.Input("X2"); - out_t->mutable_data(ctx.GetPlace()); + out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); auto normalized = ctx.Attr("normalized"); auto stream = reinterpret_cast( @@ -73,49 +83,41 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel { auto m = x1_t->numel(); auto n = x2_t->numel(); - T distance = 0; - if (m == 0) { - distance = n; - } else if (n == 0) { - distance = m; + T distance = 0.0; + if (m == 0 || n == 0) { + distance = std::max(m, n); + if (normalized) { + distance = distance / n; + } + memory::Copy(boost::get(ctx.GetPlace()), out, platform::CPUPlace(), + &distance, sizeof(T), stream); } else { framework::Tensor dist_t; dist_t.Resize({m + 1, n + 1}); dist_t.mutable_data(ctx.GetPlace()); auto dist = dist_t.data(); - auto x1 = x1_t->data(); - auto x2 = x2_t->data(); + auto x1 = x1_t->data(); + auto x2 = x2_t->data(); FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); - // compute the elements of distance matrix in the anti-diagonal diretion - for (size_t slice = 2; slice < m + n + 1; ++slice) { + // Compute the elements of distance matrix in the anti-diagonal diretion + for (int64_t slice = 2; slice < m + n + 1; ++slice) { int z_m = slice < m + 1 ? 0 : slice - m; int z_n = slice < n + 1 ? 0 : slice - n; - // number of elments in the same anti-diagonal line - int size = slice - (z_m + z_n) + 1; - int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; + int size = slice - (z_m + z_n) + 1; // number of elments in the same + // anti-diagonal line to update + int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; // start index + Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, n, start); } - - Place gpu_place = boost::get(ctx.GetPlace()); - memory::Copy(platform::CPUPlace(), &distance, gpu_place, - dist + m * (n + 1) + n, sizeof(T), stream); - } - - if (normalized) { - distance = distance / n; + SetOutput<<<1, 1, 0, stream>>>(out, dist, m, n, normalized); } - auto out = out_t->data(); - Place gpu_place = boost::get(ctx.GetPlace()); - float dist_f = distance; - memory::Copy(gpu_place, out, platform::CPUPlace(), &dist_f, sizeof(float), - stream); } }; @@ -126,5 +128,4 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( ctc_edit_distance, - ops::CTCEditDistanceGPUKernel, - ops::CTCEditDistanceGPUKernel); + ops::CTCEditDistanceGPUKernel); diff --git a/paddle/operators/ctc_edit_distance_op.h b/paddle/operators/ctc_edit_distance_op.h index a52960f1ef..08f29cf24a 100644 --- a/paddle/operators/ctc_edit_distance_op.h +++ b/paddle/operators/ctc_edit_distance_op.h @@ -35,7 +35,7 @@ class CTCEditDistanceKernel : public framework::OpKernel { auto m = x1_t->numel(); auto n = x2_t->numel(); - float distance = 0.0; + T distance = 0.0; if (m == 0) { distance = n; } else if (n == 0) { @@ -45,16 +45,16 @@ class CTCEditDistanceKernel : public framework::OpKernel { dist_t.Resize({m + 1, n + 1}); dist_t.mutable_data(ctx.GetPlace()); auto dist = dist_t.data(); - auto x1 = x1_t->data(); - auto x2 = x2_t->data(); - for (size_t i = 0; i < m + 1; ++i) { + auto x1 = x1_t->data(); + auto x2 = x2_t->data(); + for (int64_t i = 0; i < m + 1; ++i) { dist[i * (n + 1)] = i; } - for (size_t j = 0; j < n + 1; ++j) { + for (int64_t j = 0; j < n + 1; ++j) { dist[j] = j; } - for (size_t i = 1; i < m + 1; ++i) { - for (size_t j = 1; j < n + 1; ++j) { + for (int64_t i = 1; i < m + 1; ++i) { + for (int64_t j = 1; j < n + 1; ++j) { int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; int dels = dist[(i - 1) * (n + 1) + j] + 1; int ins = dist[i * (n + 1) + (j - 1)] + 1; @@ -68,7 +68,7 @@ class CTCEditDistanceKernel : public framework::OpKernel { if (normalized) { distance = distance / n; } - auto out = out_t->data(); + auto out = out_t->data(); out[0] = distance; } }; diff --git a/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py index 6694a6ee29..62c233b34f 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py @@ -37,11 +37,9 @@ def Levenshtein(hyp, ref): class TestCTCEditDistanceOp(OpTest): def setUp(self): self.op_type = "ctc_edit_distance" - normalized = False - #x1 = np.array([0, 12, 3, 5]).astype("int64") - #x2 = np.array([0, 12, 4, 7, 8]).astype("int64") - x1 = np.array([0, 12, 5]).astype("int64") - x2 = np.array([0, 12, 4]).astype("int64") + normalized = True + x1 = np.array([0, 12, 3, 5]).astype("int32") + x2 = np.array([0, 12, 4, 7, 8]).astype("int32") distance = Levenshtein(hyp=x1, ref=x2) if normalized is True: From b82049bdca55aa596ecaf4e7390f96ef7e3982c7 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 29 Nov 2017 09:40:49 +0000 Subject: [PATCH 05/30] revise the doc in ctc_edit_distance_op --- paddle/operators/ctc_edit_distance_op.cc | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/paddle/operators/ctc_edit_distance_op.cc b/paddle/operators/ctc_edit_distance_op.cc index d2f4ce67c2..11e9983e24 100644 --- a/paddle/operators/ctc_edit_distance_op.cc +++ b/paddle/operators/ctc_edit_distance_op.cc @@ -58,12 +58,19 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( CTCEditDistance operator computes the edit distance of two sequences, one named -hypothesis and another named reference. +hypothesis with length M and another named reference with length N. -Edit distance measures how dissimilar two strings, one is hypothesis and another -is reference, are by counting the minimum number of operations to transform -one string into anthor. +Edit distance, also called Levenshtein distance, measures how dissimilar two strings +are by counting the minimum number of operations to transform one string into anthor. +Here the operations include insertion, deletion, and substitution. For example, +given hypothesis string A = "kitten" and reference B = "sitting", the edit distance +is 3 for A will be transformed into B at least after two substitutions and one +insertion: + + "kitten" -> "sitten" -> "sittin" -> "sitting" +If Attr(normalized) is true, the edit distance will be divided by the length of +reference string N. )DOC"); } }; From 2c1adb060469e0b55dae966ec1edc260e1a2bfeb Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 29 Dec 2017 07:16:40 +0000 Subject: [PATCH 06/30] Rename ctc_edit_distance_op to edit_distance_op --- ...dit_distance_op.cc => edit_distance_op.cc} | 24 +++++++++---------- ...dit_distance_op.cu => edit_distance_op.cu} | 12 +++++----- ..._edit_distance_op.h => edit_distance_op.h} | 2 +- ...istance_op.py => test_edit_distance_op.py} | 2 +- 4 files changed, 19 insertions(+), 21 deletions(-) rename paddle/operators/{ctc_edit_distance_op.cc => edit_distance_op.cc} (77%) rename paddle/operators/{ctc_edit_distance_op.cu => edit_distance_op.cu} (93%) rename paddle/operators/{ctc_edit_distance_op.h => edit_distance_op.h} (97%) rename python/paddle/v2/fluid/tests/{test_ctc_edit_distance_op.py => test_edit_distance_op.py} (97%) diff --git a/paddle/operators/ctc_edit_distance_op.cc b/paddle/operators/edit_distance_op.cc similarity index 77% rename from paddle/operators/ctc_edit_distance_op.cc rename to paddle/operators/edit_distance_op.cc index 11e9983e24..843a6844cd 100644 --- a/paddle/operators/ctc_edit_distance_op.cc +++ b/paddle/operators/edit_distance_op.cc @@ -12,12 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/ctc_edit_distance_op.h" +#include "paddle/operators/edit_distance_op.h" namespace paddle { namespace operators { -class CTCEditDistanceOp : public framework::OperatorWithKernel { +class EditDistanceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -29,17 +29,16 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetKernelType( + framework::OpKernelType GetActualKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(framework::DataType::FP32, + return framework::OpKernelType(framework::proto::DataType::FP32, ctx.device_context()); } }; -class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { +class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { public: - CTCEditDistanceOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X1", "(2-D tensor with shape [M x 1]) The indices for " @@ -54,10 +53,10 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(false); AddOutput("Out", "(2-D tensor with shape [1 x 1]) " - "The output distance of CTCEditDistance operator."); + "The output distance of EditDistance operator."); AddComment(R"DOC( -CTCEditDistance operator computes the edit distance of two sequences, one named +EditDistance operator computes the edit distance of two sequences, one named hypothesis with length M and another named reference with length N. Edit distance, also called Levenshtein distance, measures how dissimilar two strings @@ -80,8 +79,7 @@ reference string N. namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp, - ops::CTCEditDistanceOpMaker); +REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker, + paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - ctc_edit_distance, - ops::CTCEditDistanceKernel); + edit_distance, ops::EditDistanceKernel); diff --git a/paddle/operators/ctc_edit_distance_op.cu b/paddle/operators/edit_distance_op.cu similarity index 93% rename from paddle/operators/ctc_edit_distance_op.cu rename to paddle/operators/edit_distance_op.cu index 22871acc4e..7fa6a60df4 100644 --- a/paddle/operators/ctc_edit_distance_op.cu +++ b/paddle/operators/edit_distance_op.cu @@ -65,7 +65,7 @@ __global__ void SetOutput(T* out, const T* dist, const int M, const int N, } template -class CTCEditDistanceGPUKernel : public framework::OpKernel { +class EditDistanceGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* out_t = ctx.Output("Out"); @@ -110,8 +110,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel { int z_n = slice < n + 1 ? 0 : slice - n; int size = slice - (z_m + z_n) + 1; // number of elments in the same // anti-diagonal line to update - int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; // start index - + // the start index at which computes from + int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, n, start); @@ -126,6 +126,6 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel { namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - ctc_edit_distance, - ops::CTCEditDistanceGPUKernel); +REGISTER_OP_CUDA_KERNEL( + edit_distance, + ops::EditDistanceGPUKernel); diff --git a/paddle/operators/ctc_edit_distance_op.h b/paddle/operators/edit_distance_op.h similarity index 97% rename from paddle/operators/ctc_edit_distance_op.h rename to paddle/operators/edit_distance_op.h index 08f29cf24a..182a6e3bf5 100644 --- a/paddle/operators/ctc_edit_distance_op.h +++ b/paddle/operators/edit_distance_op.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { template -class CTCEditDistanceKernel : public framework::OpKernel { +class EditDistanceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* out_t = ctx.Output("Out"); diff --git a/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py similarity index 97% rename from python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py rename to python/paddle/v2/fluid/tests/test_edit_distance_op.py index 62c233b34f..8866922f2e 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -36,7 +36,7 @@ def Levenshtein(hyp, ref): class TestCTCEditDistanceOp(OpTest): def setUp(self): - self.op_type = "ctc_edit_distance" + self.op_type = "edit_distance" normalized = True x1 = np.array([0, 12, 3, 5]).astype("int32") x2 = np.array([0, 12, 4, 7, 8]).astype("int32") From 2e49facae9baa9cc161d23d064e792abbc4e6e84 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 1 Jan 2018 15:23:00 +0000 Subject: [PATCH 07/30] Rename inputs & format license --- paddle/operators/edit_distance_op.cc | 28 +++++++++---------- paddle/operators/edit_distance_op.cu | 22 +++++++-------- paddle/operators/edit_distance_op.h | 22 +++++++-------- .../v2/fluid/tests/test_edit_distance_op.py | 2 +- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc index 843a6844cd..6022a7a4bd 100644 --- a/paddle/operators/edit_distance_op.cc +++ b/paddle/operators/edit_distance_op.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "paddle/operators/edit_distance_op.h" @@ -22,8 +22,8 @@ class EditDistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null."); - PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Hyp"), "Input(Hyp) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Ref"), "Input(Ref) shouldn't be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); ctx->SetOutputDim("Out", {1}); } @@ -40,16 +40,16 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { public: EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X1", + AddInput("Hyp", "(2-D tensor with shape [M x 1]) The indices for " "hypothesis string"); - AddInput("X2", + AddInput("Ref", "(2-D tensor with shape [N x 1]) The indices " "for reference string."); AddAttr("normalized", "(bool, default false) Indicated whether " "normalize the Output(Out) by the length of reference " - "string (X2).") + "string (Ref).") .SetDefault(false); AddOutput("Out", "(2-D tensor with shape [1 x 1]) " diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu index 7fa6a60df4..fed91ffb43 100644 --- a/paddle/operators/edit_distance_op.cu +++ b/paddle/operators/edit_distance_op.cu @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include #include "paddle/framework/op_registry.h" @@ -70,8 +70,8 @@ class EditDistanceGPUKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out_t = ctx.Output("Out"); - auto* x1_t = ctx.Input("X1"); - auto* x2_t = ctx.Input("X2"); + auto* x1_t = ctx.Input("Hyp"); + auto* x2_t = ctx.Input("Ref"); out_t->mutable_data(ctx.GetPlace()); auto out = out_t->data(); diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h index 182a6e3bf5..abde4fe97c 100644 --- a/paddle/operators/edit_distance_op.h +++ b/paddle/operators/edit_distance_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include @@ -26,8 +26,8 @@ class EditDistanceKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out_t = ctx.Output("Out"); - auto* x1_t = ctx.Input("X1"); - auto* x2_t = ctx.Input("X2"); + auto* x1_t = ctx.Input("Hyp"); + auto* x2_t = ctx.Input("Ref"); out_t->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py index 8866922f2e..df1ac620e7 100644 --- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -45,7 +45,7 @@ class TestCTCEditDistanceOp(OpTest): if normalized is True: distance = distance / len(x2) self.attrs = {'normalized': normalized} - self.inputs = {'X1': x1, 'X2': x2} + self.inputs = {'Hyp': x1, 'Ref': x2} self.outputs = {'Out': distance} def test_check_output(self): From 0250e54c2dc8ae4687a2ede661cd25dadfb66ce9 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 3 Jan 2018 09:48:29 +0000 Subject: [PATCH 08/30] Enable batch input in edit_distance_op --- paddle/operators/edit_distance_op.cc | 49 ++++++---- paddle/operators/edit_distance_op.cu | 98 +++++++++++-------- paddle/operators/edit_distance_op.h | 91 ++++++++++------- .../v2/fluid/tests/test_edit_distance_op.py | 52 ++++++++-- 4 files changed, 189 insertions(+), 101 deletions(-) diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc index 6022a7a4bd..7b92148f0e 100644 --- a/paddle/operators/edit_distance_op.cc +++ b/paddle/operators/edit_distance_op.cc @@ -22,10 +22,18 @@ class EditDistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Hyp"), "Input(Hyp) shouldn't be null."); - PADDLE_ENFORCE(ctx->HasInput("Ref"), "Input(Ref) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); - ctx->SetOutputDim("Out", {1}); + auto hyp_dims = ctx->GetInputDim("Hyps"); + auto ref_dims = ctx->GetInputDim("Refs"); + PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1, + "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1, + "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + ctx->SetOutputDim("Out", ctx->GetInputDim("Refs")); } protected: @@ -40,24 +48,23 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { public: EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Hyp", - "(2-D tensor with shape [M x 1]) The indices for " - "hypothesis string"); - AddInput("Ref", - "(2-D tensor with shape [N x 1]) The indices " - "for reference string."); + AddInput("Hyps", + "(2-D LoDTensor, 2nd dim. equal to 1) " + "The indices for hypothesis strings."); + AddInput("Refs", + "(2-D LoDTensor, 2nd dim. equal to 1) " + "The indices for reference strings."); AddAttr("normalized", - "(bool, default false) Indicated whether " - "normalize the Output(Out) by the length of reference " - "string (Ref).") + "(bool, default false) Indicated whether to normalize " + "the edit distance by the length of reference string.") .SetDefault(false); AddOutput("Out", - "(2-D tensor with shape [1 x 1]) " - "The output distance of EditDistance operator."); + "(2-D Tensor with shape [`batch_size` x 1]) " + "The output edit distances of EditDistance operator."); AddComment(R"DOC( -EditDistance operator computes the edit distance of two sequences, one named -hypothesis with length M and another named reference with length N. +EditDistance operator computes the edit distances between a batch of hypothesis +strings and their references. Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations to transform one string into anthor. @@ -68,8 +75,14 @@ insertion: "kitten" -> "sitten" -> "sittin" -> "sitting" -If Attr(normalized) is true, the edit distance will be divided by the length of -reference string N. +Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total +number denoted by `batch_size`, and the separation is specified by the LoD information. +And the `batch_size` reference strings are arranged in order in the same way in the +LoDTensor Input(Refs). + +Output(Out) contains the `batch_size` results and each stands for the edit stance +for a pair of strings respectively. If Attr(normalized) is true, the edit distance +will be divided by the length of reference string. )DOC"); } }; diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu index fed91ffb43..b548345986 100644 --- a/paddle/operators/edit_distance_op.cu +++ b/paddle/operators/edit_distance_op.cu @@ -70,53 +70,71 @@ class EditDistanceGPUKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out_t = ctx.Output("Out"); - auto* x1_t = ctx.Input("Hyp"); - auto* x2_t = ctx.Input("Ref"); - - out_t->mutable_data(ctx.GetPlace()); - auto out = out_t->data(); + auto* x1_t = ctx.Input("Hyps"); + auto* x2_t = ctx.Input("Refs"); auto normalized = ctx.Attr("normalized"); auto stream = reinterpret_cast( ctx.device_context()) .stream(); - auto m = x1_t->numel(); - auto n = x2_t->numel(); - T distance = 0.0; - if (m == 0 || n == 0) { - distance = std::max(m, n); - if (normalized) { - distance = distance / n; - } - memory::Copy(boost::get(ctx.GetPlace()), out, platform::CPUPlace(), - &distance, sizeof(T), stream); - } else { - framework::Tensor dist_t; - dist_t.Resize({m + 1, n + 1}); - dist_t.mutable_data(ctx.GetPlace()); - auto dist = dist_t.data(); - auto x1 = x1_t->data(); - auto x2 = x2_t->data(); - - FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); - - FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); - // Compute the elements of distance matrix in the anti-diagonal diretion - for (int64_t slice = 2; slice < m + n + 1; ++slice) { - int z_m = slice < m + 1 ? 0 : slice - m; - int z_n = slice < n + 1 ? 0 : slice - n; - int size = slice - (z_m + z_n) + 1; // number of elments in the same - // anti-diagonal line to update - // the start index at which computes from - int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; - Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, - n, start); + auto hyp_lod = x1_t->lod()[0]; + auto ref_lod = x2_t->lod()[0]; + PADDLE_ENFORCE( + hyp_lod.size() == ref_lod.size(), + "Input(Hyps) and Input(Refs) must have the same batch size."); + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } + + auto num_strs = hyp_lod.size() - 1; + out_t->Resize({static_cast(num_strs), 1}); + out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); + + std::vector distance(num_strs, 0.0); + for (size_t num = 0; num < num_strs; num++) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); + if (m == 0 || n == 0) { + distance[num] = std::max(m, n); + if (normalized) { + PADDLE_ENFORCE(n > 0, + "The reference string (#%d) cannot be empty " + "when Attr(normalized) is enabled.", + n); + distance[num] = distance[num] / n; + } + memory::Copy(boost::get(ctx.GetPlace()), out + num, + platform::CPUPlace(), &distance[num], sizeof(T), stream); + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data() + hyp_lod[num]; + auto x2 = x2_t->data() + ref_lod[num]; + + FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); + + FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); + // Compute the elements of distance matrix in the anti-diagonal diretion + for (int64_t slice = 2; slice < m + n + 1; ++slice) { + int z_m = slice < m + 1 ? 0 : slice - m; + int z_n = slice < n + 1 ? 0 : slice - n; + int size = slice - (z_m + z_n) + 1; // number of elments in the same + // anti-diagonal line to update + // the start index at which computes from + int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; + Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, + m, n, start); + } + SetOutput<<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized); } - SetOutput<<<1, 1, 0, stream>>>(out, dist, m, n, normalized); } } }; diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h index abde4fe97c..6284f230e5 100644 --- a/paddle/operators/edit_distance_op.h +++ b/paddle/operators/edit_distance_op.h @@ -26,50 +26,69 @@ class EditDistanceKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out_t = ctx.Output("Out"); - auto* x1_t = ctx.Input("Hyp"); - auto* x2_t = ctx.Input("Ref"); + auto* x1_t = ctx.Input("Hyps"); + auto* x2_t = ctx.Input("Refs"); + auto normalized = ctx.Attr("normalized"); + + auto hyp_lod = x1_t->lod()[0]; + auto ref_lod = x2_t->lod()[0]; + PADDLE_ENFORCE( + hyp_lod.size() == ref_lod.size(), + "Input(Hyps) and Input(Refs) must have the same batch size."); + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } + auto num_strs = hyp_lod.size() - 1; + + out_t->Resize({static_cast(num_strs), 1}); out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); - auto normalized = ctx.Attr("normalized"); + std::vector distance(num_strs, 0.0); + for (size_t num = 0; num < num_strs; ++num) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); - auto m = x1_t->numel(); - auto n = x2_t->numel(); - T distance = 0.0; - if (m == 0) { - distance = n; - } else if (n == 0) { - distance = m; - } else { - framework::Tensor dist_t; - dist_t.Resize({m + 1, n + 1}); - dist_t.mutable_data(ctx.GetPlace()); - auto dist = dist_t.data(); - auto x1 = x1_t->data(); - auto x2 = x2_t->data(); - for (int64_t i = 0; i < m + 1; ++i) { - dist[i * (n + 1)] = i; - } - for (int64_t j = 0; j < n + 1; ++j) { - dist[j] = j; - } - for (int64_t i = 1; i < m + 1; ++i) { - for (int64_t j = 1; j < n + 1; ++j) { - int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; - int dels = dist[(i - 1) * (n + 1) + j] + 1; - int ins = dist[i * (n + 1) + (j - 1)] + 1; - int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; - dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); + if (m == 0) { + distance[num] = n; + } else if (n == 0) { + distance[num] = m; + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data() + hyp_lod[num]; + auto x2 = x2_t->data() + ref_lod[num]; + for (int64_t i = 0; i < m + 1; ++i) { + dist[i * (n + 1)] = i; + } + for (int64_t j = 0; j < n + 1; ++j) { + dist[j] = j; + } + for (int64_t i = 1; i < m + 1; ++i) { + for (int64_t j = 1; j < n + 1; ++j) { + int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; + int dels = dist[(i - 1) * (n + 1) + j] + 1; + int ins = dist[i * (n + 1) + (j - 1)] + 1; + int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; + dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); + } } + distance[num] = dist[m * (n + 1) + n]; } - distance = dist[m * (n + 1) + n]; - } - if (normalized) { - distance = distance / n; + if (normalized) { + PADDLE_ENFORCE(n > 0, + "The reference string (#%d) cannot be empty " + "when Attr(normalized) is enabled.", + n); + distance[num] = distance[num] / n; + } + out[num] = distance[num]; } - auto out = out_t->data(); - out[0] = distance; } }; diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py index df1ac620e7..24f2f0c5c2 100644 --- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -18,7 +18,7 @@ def Levenshtein(hyp, ref): if n == 0: return m - dist = np.zeros((m + 1, n + 1)) + dist = np.zeros((m + 1, n + 1)).astype("float32") for i in range(0, m + 1): dist[i][0] = i for j in range(0, n + 1): @@ -35,17 +35,55 @@ def Levenshtein(hyp, ref): class TestCTCEditDistanceOp(OpTest): + def setUp(self): + self.op_type = "edit_distance" + normalized = False + x1 = np.array([[0, 12, 3, 5, 8, 2]]).astype("int32") + x2 = np.array([[0, 12, 4, 7, 8]]).astype("int32") + x1 = np.transpose(x1) + x2 = np.transpose(x2) + x1_lod = [0, 1, 5] + x2_lod = [0, 3, 4] + + num_strs = len(x1_lod) - 1 + distance = np.zeros((num_strs, 1)).astype("float32") + for i in range(0, num_strs): + distance[i] = Levenshtein( + hyp=x1[x1_lod[i]:x1_lod[i + 1]], + ref=x2[x2_lod[i]:x2_lod[i + 1]]) + if normalized is True: + len_ref = x2_lod[i + 1] - x2_lod[i] + distance[i] = distance[i] / len_ref + self.attrs = {'normalized': normalized} + self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} + self.outputs = {'Out': distance} + + def test_check_output(self): + self.check_output() + + +class TestCTCEditDistanceOpNormalized(OpTest): def setUp(self): self.op_type = "edit_distance" normalized = True - x1 = np.array([0, 12, 3, 5]).astype("int32") - x2 = np.array([0, 12, 4, 7, 8]).astype("int32") + x1 = np.array([[0, 10, 3, 6, 5, 8, 2]]).astype("int32") + x2 = np.array([[0, 10, 4, 6, 7, 8]]).astype("int32") + x1 = np.transpose(x1) + x2 = np.transpose(x2) + x1_lod = [0, 1, 3, 6] + x2_lod = [0, 2, 3, 5] - distance = Levenshtein(hyp=x1, ref=x2) - if normalized is True: - distance = distance / len(x2) + num_strs = len(x1_lod) - 1 + distance = np.zeros((num_strs, 1)).astype("float32") + for i in range(0, num_strs): + distance[i] = Levenshtein( + hyp=x1[x1_lod[i]:x1_lod[i + 1]], + ref=x2[x2_lod[i]:x2_lod[i + 1]]) + if normalized is True: + len_ref = x2_lod[i + 1] - x2_lod[i] + distance[i] = distance[i] / len_ref self.attrs = {'normalized': normalized} - self.inputs = {'Hyp': x1, 'Ref': x2} + self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} self.outputs = {'Out': distance} def test_check_output(self): From 2b3ba40e6a504b5526208a945e8fff6594dd7904 Mon Sep 17 00:00:00 2001 From: gx_wind Date: Sat, 6 Jan 2018 17:33:34 +0800 Subject: [PATCH 09/30] add adversarial sample --- adversarial/advbox/__init__.py | 17 +++ adversarial/advbox/attacks/base.py | 42 +++++++ adversarial/advbox/attacks/gradientsign.py | 36 ++++++ adversarial/advbox/models/__init__.py | 16 +++ adversarial/advbox/models/base.py | 91 ++++++++++++++ adversarial/advbox/models/paddle.py | 106 ++++++++++++++++ .../advbox/tutorials/tutorial_model.py | 32 +++++ adversarial/fluid_mnist.py | 91 ++++++++++++++ adversarial/mnist_fgsm.py | 113 ++++++++++++++++++ adversarial/mnist_tutorial_fgsm.py | 94 +++++++++++++++ 10 files changed, 638 insertions(+) create mode 100644 adversarial/advbox/__init__.py create mode 100644 adversarial/advbox/attacks/base.py create mode 100644 adversarial/advbox/attacks/gradientsign.py create mode 100644 adversarial/advbox/models/__init__.py create mode 100644 adversarial/advbox/models/base.py create mode 100644 adversarial/advbox/models/paddle.py create mode 100644 adversarial/advbox/tutorials/tutorial_model.py create mode 100644 adversarial/fluid_mnist.py create mode 100644 adversarial/mnist_fgsm.py create mode 100644 adversarial/mnist_tutorial_fgsm.py diff --git a/adversarial/advbox/__init__.py b/adversarial/advbox/__init__.py new file mode 100644 index 0000000000..4beb6be0a2 --- /dev/null +++ b/adversarial/advbox/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + A set of tools for generating adversarial example on paddle platform +""" diff --git a/adversarial/advbox/attacks/base.py b/adversarial/advbox/attacks/base.py new file mode 100644 index 0000000000..9cc2bfb854 --- /dev/null +++ b/adversarial/advbox/attacks/base.py @@ -0,0 +1,42 @@ +""" +The base model of the model. +""" +from abc import ABCMeta +#from advbox.base import Model +import abc + +abstractmethod = abc.abstractmethod + +class Attack(object): + """ + Abstract base class for adversarial attacks. `Attack` represent an adversarial attack + which search an adversarial example. subclass should implement the _apply() method. + + Args: + model(Model): an instance of the class advbox.base.Model. + + """ + __metaclass__ = ABCMeta + + def __init__(self, model): + self.model = model + + def __call__(self, image_batch): + """ + Generate the adversarial sample. + + Args: + image_batch(list): The image and label tuple list. + """ + adv_img = self._apply(image_batch) + return adv_img + + @abstractmethod + def _apply(self, image_batch): + """ + Search an adversarial example. + + Args: + image_batch(list): The image and label tuple list. + """ + raise NotImplementedError diff --git a/adversarial/advbox/attacks/gradientsign.py b/adversarial/advbox/attacks/gradientsign.py new file mode 100644 index 0000000000..6c188f6249 --- /dev/null +++ b/adversarial/advbox/attacks/gradientsign.py @@ -0,0 +1,36 @@ +""" +This module provide the attack method for FGSM's implement. +""" +from __future__ import division +import numpy as np +from collections import Iterable +from .base import Attack + +class GradientSignAttack(Attack): + """ + This attack was originally implemented by Goodfellow et al. (2015) with the + infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called + the Fast Gradient Method. + Paper link: https://arxiv.org/abs/1412.6572 + """ + + def _apply(self, image_batch, epsilons=1000): + pre_label = np.argmax(self.model.predict(image_batch)) + + min_, max_ = self.model.bounds() + gradient = self.model.gradient(image_batch) + gradient_sign = np.sign(gradient) * (max_ - min_) + + if not isinstance(epsilons, Iterable): + epsilons = np.linspace(0, 1, num = epsilons + 1) + + for epsilon in epsilons: + adv_img = image_batch[0][0].reshape(gradient_sign.shape) + epsilon * gradient_sign + adv_img = np.clip(adv_img, min_, max_) + adv_label = np.argmax(self.model.predict([(adv_img, 0)])) + #print("pre_label="+str(pre_label)+ " adv_label="+str(adv_label)) + if pre_label != adv_label: + #print(epsilon, pre_label, adv_label) + return adv_img + +FGSM = GradientSignAttack diff --git a/adversarial/advbox/models/__init__.py b/adversarial/advbox/models/__init__.py new file mode 100644 index 0000000000..eee0f6efd4 --- /dev/null +++ b/adversarial/advbox/models/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Paddle model for target of attack +""" diff --git a/adversarial/advbox/models/base.py b/adversarial/advbox/models/base.py new file mode 100644 index 0000000000..91b6fe4a3c --- /dev/null +++ b/adversarial/advbox/models/base.py @@ -0,0 +1,91 @@ +""" +The base model of the model. +""" +from abc import ABCMeta +import abc + +abstractmethod = abc.abstractmethod + +class Model(object): + + """ + Base class of model to provide attack. + + + Args: + bounds(tuple): The lower and upper bound for the image pixel. + channel_axis(int): The index of the axis that represents the color channel. + preprocess(tuple): Two element tuple used to preprocess the input. First + substract the first element, then divide the second element. + """ + __metaclass__ = ABCMeta + + def __init__(self, bounds, channel_axis, preprocess=None): + assert len(bounds) == 2 + assert channel_axis in [0, 1, 2, 3] + + if preprocess is None: + preprocess = (0, 1) + self._bounds = bounds + self._channel_axis = channel_axis + self._preprocess = preprocess + + def bounds(self): + """ + Return the upper and lower bounds of the model. + """ + return self._bounds + + def channel_axis(self): + """ + Return the channel axis of the model. + """ + return self._channel_axis + + def _process_input(self, input_): + res = input_ + sub, div = self._preprocess + if sub != 0: + res = input_ - sub + assert div != 0 + if div != 1: + res /= div + return res + + @abstractmethod + def predict(self, image_batch): + """ + Calculate the prediction of the image batch. + + Args: + image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels). + + Return: + numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes). + """ + raise NotImplementedError + + @abstractmethod + def num_classes(self): + """ + Determine the number of the classes + + Return: + int: the number of the classes + """ + raise NotImplementedError + + @abstractmethod + def gradient(self, image_batch): + """ + Calculate the gradient of the cross-entropy loss w.r.t the image. + + Args: + image(numpy.ndarray): image with shape (height, width, channel) + label(int): image label used to cal gradient. + + Return: + numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with + the shape (height, width, channel). + """ + raise NotImplementedError diff --git a/adversarial/advbox/models/paddle.py b/adversarial/advbox/models/paddle.py new file mode 100644 index 0000000000..831fa6a362 --- /dev/null +++ b/adversarial/advbox/models/paddle.py @@ -0,0 +1,106 @@ +from __future__ import absolute_import + +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +from paddle.v2.fluid.framework import program_guard + +from .base import Model + +class PaddleModel(Model): + """ + Create a PaddleModel instance. + When you need to generate a adversarial sample, you should construct an instance of PaddleModel. + + Args: + program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample. + input_name(string): The name of the input. + logits_name(string): The name of the logits. + predict_name(string): The name of the predict. + cost_name(string): The name of the loss in the program. + """ + + def __init__(self, + program, + input_name, + logits_name, + predict_name, + cost_name, + bounds, + channel_axis=3, + preprocess=None): + super(PaddleModel, self).__init__( + bounds=bounds, + channel_axis=channel_axis, + preprocess=preprocess) + + if preprocess is None: + preprocess = (0, 1) + + self._program = program + self._place = fluid.CPUPlace() + self._exe = fluid.Executor(self._place) + + self._input_name = input_name + self._logits_name = logits_name + self._predict_name = predict_name + self._cost_name = cost_name + + # gradient + loss = self._program.block(0).var(self._cost_name) + param_grads = fluid.backward.append_backward(loss, parameter_list=[self._input_name]) + self._gradient = param_grads[0][1] + + def predict(self, image_batch): + """ + Predict the label of the image_batch. + + Args: + image_batch(list): The image and label tuple list. + Return: + numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes). + """ + feeder = fluid.DataFeeder( + feed_list=[self._input_name, self._logits_name], + place=self._place, + program=self._program + ) + predict_var = self._program.block(0).var(self._predict_name) + predict = self._exe.run( + self._program, + feed=feeder.feed(image_batch), + fetch_list=[predict_var] + ) + return predict + + def num_classes(self): + """ + Calculate the number of classes of the output label. + + Return: + int: the number of classes + """ + predict_var = self._program.block(0).var(self._predict_name) + assert len(predict_var.shape) == 2 + return predict_var.shape[1] + + def gradient(self, image_batch): + """ + Calculate the gradient of the loss w.r.t the input. + + Args: + image_batch(list): The image and label tuple list. + Return: + list: The list of the gradient of the image. + """ + feeder = fluid.DataFeeder( + feed_list=[self._input_name, self._logits_name], + place=self._place, + program=self._program + ) + + grad, = self._exe.run( + self._program, + feed=feeder.feed(image_batch), + fetch_list=[self._gradient]) + return grad diff --git a/adversarial/advbox/tutorials/tutorial_model.py b/adversarial/advbox/tutorials/tutorial_model.py new file mode 100644 index 0000000000..425f09a056 --- /dev/null +++ b/adversarial/advbox/tutorials/tutorial_model.py @@ -0,0 +1,32 @@ +################################################################################ +# +# Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved +# +################################################################################ +""" + +A pure Paddlepaddle implementation of a neural network. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +from advbox import Model + +def main(): + """ + example main function + """ + model_dir = "./mnist_model" + place = fluid.CPUPlace() + exe = fluid.Executor(place) + program, feed_var_names, fetch_vars = fluid.io.load_inferfence_model(model_dir, exe) + print(program) + +if __name__ == "__main__": + main() diff --git a/adversarial/fluid_mnist.py b/adversarial/fluid_mnist.py new file mode 100644 index 0000000000..d46defda55 --- /dev/null +++ b/adversarial/fluid_mnist.py @@ -0,0 +1,91 @@ +""" +CNN on mnist data using fluid api of paddlepaddle +""" +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + +def mnist_cnn_model(img): + """ + Mnist cnn model + + Args: + img(Varaible): the input image to be recognized + + Returns: + Variable: the label prediction + """ + #conv1 = fluid.nets.conv2d() + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + logits = fluid.layers.fc( + input=conv_pool_2, + size=10, + act='softmax') + return logits + + +def main(): + """ + Train the cnn model on mnist datasets + """ + img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + logits = mnist_cnn_model(img) + cost = fluid.layers.cross_entropy(input=logits, label=label) + avg_cost = fluid.layers.mean(x=cost) + optimizer = fluid.optimizer.Adam(learning_rate=0.01) + optimizer.minimize(avg_cost) + + accuracy = fluid.evaluator.Accuracy(input=logits, label=label) + + BATCH_SIZE = 50 + PASS_NUM = 3 + ACC_THRESHOLD = 0.98 + LOSS_THRESHOLD = 10.0 + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[img, label], place=place) + exe.run(fluid.default_startup_program()) + + for pass_id in range(PASS_NUM): + accuracy.reset(exe) + for data in train_reader(): + loss, acc = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost] + accuracy.metrics) + pass_acc = accuracy.eval(exe) + print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + + str(pass_acc)) + # print loss, acc + if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD: + # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. + break +# exit(0) + + pass_acc = accuracy.eval(exe) + print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc)) + fluid.io.save_params(exe, dirname='./mnist', main_program=fluid.default_main_program()) + print('train mnist done') + exit(1) + +if __name__ == '__main__': + main() diff --git a/adversarial/mnist_fgsm.py b/adversarial/mnist_fgsm.py new file mode 100644 index 0000000000..187f37b82e --- /dev/null +++ b/adversarial/mnist_fgsm.py @@ -0,0 +1,113 @@ +""" +This attack was originally implemented by Goodfellow et al. (2015) with the +infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called +the Fast Gradient Method. +Paper link: https://arxiv.org/abs/1412.6572 +""" + +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + +BATCH_SIZE = 50 +PASS_NUM = 1 +EPS = 0.3 +CLIP_MIN = -1 +CLIP_MAX = 1 +PASS_NUM = 1 + +def mnist_cnn_model(img): + """ + Mnist cnn model + + Args: + img(Varaible): the input image to be recognized + + Returns: + Variable: the label prediction + """ + #conv1 = fluid.nets.conv2d() + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + logits = fluid.layers.fc( + input=conv_pool_2, + size=10, + act='softmax') + return logits + + +def main(): + """ + Generate adverserial example and evaluate accuracy on mnist using FGSM + """ + + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32') + # The gradient should flow + images.stop_gradient = False + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + predict = mnist_cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Cal gradient of input + params_grads = fluid.backward.append_backward_ops(avg_cost, parameter_list=['pixel']) + # data batch + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + + accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + accuracy.reset(exe) + #exe.run(fluid.default_startup_program()) + feeder = fluid.DataFeeder(feed_list=[images, label], place=place) + for pass_id in range(PASS_NUM): + fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program()) + for data in train_reader(): + # cal gradient and eval accuracy + ps, acc = exe.run( + fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[params_grads[0][1]]+accuracy.metrics) + labels = [] + for idx, _ in enumerate(data): + labels.append(data[idx][1]) + # generate adversarial example + batch_num = ps.shape[0] + new_data = [] + for i in range(batch_num): + adv_img = np.reshape(data[0][0], (1, 28, 28)) + EPS * np.sign(ps[i]) + adv_img = np.clip(adv_img, CLIP_MIN, CLIP_MAX) + #adv_imgs.append(adv_img) + t = (adv_img, data[0][1]) + new_data.append(t) + + # predict label + predict_label, = exe.run( + fluid.default_main_program(), + feed=feeder.feed(new_data), + fetch_list=[predict]) + adv_labels = np.argmax(predict_label, axis=1) + batch_accuracy = np.mean(np.equal(labels, adv_labels)) + print "pass_id=" + str(pass_id) + " acc=" + str(acc)+ " adv_acc=" + str(batch_accuracy) + + +if __name__ == "__main__": + main() diff --git a/adversarial/mnist_tutorial_fgsm.py b/adversarial/mnist_tutorial_fgsm.py new file mode 100644 index 0000000000..665062afd0 --- /dev/null +++ b/adversarial/mnist_tutorial_fgsm.py @@ -0,0 +1,94 @@ +""" +FGSM demos on mnist using advbox tool. +""" +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +import matplotlib.pyplot as plt +import numpy as np + +from advbox.models.paddle import PaddleModel +from advbox.attacks.gradientsign import GradientSignAttack + +def cnn_model(img): + """ + Mnist cnn model + Args: + img(Varaible): the input image to be recognized + Returns: + Variable: the label prediction + """ + #conv1 = fluid.nets.conv2d() + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + logits = fluid.layers.fc( + input=conv_pool_2, + size=10, + act='softmax') + return logits + + +def main(): + """ + Advbox demo which demonstrate how to use advbox. + """ + IMG_NAME = 'img' + LABEL_NAME = 'label' + + img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32') + # gradient should flow + img.stop_gradient = False + label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64') + logits = cnn_model(img) + cost = fluid.layers.cross_entropy(input=logits, label=label) + avg_cost = fluid.layers.mean(x=cost) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + BATCH_SIZE = 1 + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + feeder = fluid.DataFeeder( + feed_list=[IMG_NAME, LABEL_NAME], + place=place, + program=fluid.default_main_program() + ) + + fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program()) + + # advbox demo + m = PaddleModel( + fluid.default_main_program(), + IMG_NAME, + LABEL_NAME, + logits.name, + avg_cost.name, + (-1, 1) + ) + att = GradientSignAttack(m) + for data in train_reader(): + # fgsm attack + adv_img = att(data) + plt.imshow(n[0][0], cmap='Greys_r') + plt.show() + #np.save('adv_img', adv_img) + break + +if __name__ == '__main__': + main() From 35210a044d046a69a6869c53022ed7c95236f382 Mon Sep 17 00:00:00 2001 From: gx_wind Date: Sat, 6 Jan 2018 17:35:41 +0800 Subject: [PATCH 10/30] delete unused files --- .../advbox/tutorials/tutorial_model.py | 32 ----- adversarial/mnist_fgsm.py | 113 ------------------ 2 files changed, 145 deletions(-) delete mode 100644 adversarial/advbox/tutorials/tutorial_model.py delete mode 100644 adversarial/mnist_fgsm.py diff --git a/adversarial/advbox/tutorials/tutorial_model.py b/adversarial/advbox/tutorials/tutorial_model.py deleted file mode 100644 index 425f09a056..0000000000 --- a/adversarial/advbox/tutorials/tutorial_model.py +++ /dev/null @@ -1,32 +0,0 @@ -################################################################################ -# -# Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved -# -################################################################################ -""" - -A pure Paddlepaddle implementation of a neural network. - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import paddle.v2 as paddle -import paddle.v2.fluid as fluid -from advbox import Model - -def main(): - """ - example main function - """ - model_dir = "./mnist_model" - place = fluid.CPUPlace() - exe = fluid.Executor(place) - program, feed_var_names, fetch_vars = fluid.io.load_inferfence_model(model_dir, exe) - print(program) - -if __name__ == "__main__": - main() diff --git a/adversarial/mnist_fgsm.py b/adversarial/mnist_fgsm.py deleted file mode 100644 index 187f37b82e..0000000000 --- a/adversarial/mnist_fgsm.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -This attack was originally implemented by Goodfellow et al. (2015) with the -infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called -the Fast Gradient Method. -Paper link: https://arxiv.org/abs/1412.6572 -""" - -import numpy as np -import paddle.v2 as paddle -import paddle.v2.fluid as fluid - -BATCH_SIZE = 50 -PASS_NUM = 1 -EPS = 0.3 -CLIP_MIN = -1 -CLIP_MAX = 1 -PASS_NUM = 1 - -def mnist_cnn_model(img): - """ - Mnist cnn model - - Args: - img(Varaible): the input image to be recognized - - Returns: - Variable: the label prediction - """ - #conv1 = fluid.nets.conv2d() - conv_pool_1 = fluid.nets.simple_img_conv_pool( - input=img, - num_filters=20, - filter_size=5, - pool_size=2, - pool_stride=2, - act='relu') - - conv_pool_2 = fluid.nets.simple_img_conv_pool( - input=conv_pool_1, - num_filters=50, - filter_size=5, - pool_size=2, - pool_stride=2, - act='relu') - - logits = fluid.layers.fc( - input=conv_pool_2, - size=10, - act='softmax') - return logits - - -def main(): - """ - Generate adverserial example and evaluate accuracy on mnist using FGSM - """ - - images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32') - # The gradient should flow - images.stop_gradient = False - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - - predict = mnist_cnn_model(images) - cost = fluid.layers.cross_entropy(input=predict, label=label) - avg_cost = fluid.layers.mean(x=cost) - - # Cal gradient of input - params_grads = fluid.backward.append_backward_ops(avg_cost, parameter_list=['pixel']) - # data batch - train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.mnist.train(), buf_size=500), - batch_size=BATCH_SIZE) - - accuracy = fluid.evaluator.Accuracy(input=predict, label=label) - place = fluid.CPUPlace() - exe = fluid.Executor(place) - accuracy.reset(exe) - #exe.run(fluid.default_startup_program()) - feeder = fluid.DataFeeder(feed_list=[images, label], place=place) - for pass_id in range(PASS_NUM): - fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program()) - for data in train_reader(): - # cal gradient and eval accuracy - ps, acc = exe.run( - fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[params_grads[0][1]]+accuracy.metrics) - labels = [] - for idx, _ in enumerate(data): - labels.append(data[idx][1]) - # generate adversarial example - batch_num = ps.shape[0] - new_data = [] - for i in range(batch_num): - adv_img = np.reshape(data[0][0], (1, 28, 28)) + EPS * np.sign(ps[i]) - adv_img = np.clip(adv_img, CLIP_MIN, CLIP_MAX) - #adv_imgs.append(adv_img) - t = (adv_img, data[0][1]) - new_data.append(t) - - # predict label - predict_label, = exe.run( - fluid.default_main_program(), - feed=feeder.feed(new_data), - fetch_list=[predict]) - adv_labels = np.argmax(predict_label, axis=1) - batch_accuracy = np.mean(np.equal(labels, adv_labels)) - print "pass_id=" + str(pass_id) + " acc=" + str(acc)+ " adv_acc=" + str(batch_accuracy) - - -if __name__ == "__main__": - main() From bbb03fceb38b63ece9f864e1cf7aa346ba207056 Mon Sep 17 00:00:00 2001 From: gx_wind Date: Sat, 6 Jan 2018 18:02:05 +0800 Subject: [PATCH 11/30] add readme --- adversarial/README.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 adversarial/README.md diff --git a/adversarial/README.md b/adversarial/README.md new file mode 100644 index 0000000000..7c9502828f --- /dev/null +++ b/adversarial/README.md @@ -0,0 +1,3 @@ +# Advbox + +Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle. \ No newline at end of file From e85c51330700a4125f8574e8c0927407c6d9e3d0 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 7 Jan 2018 13:14:54 +0000 Subject: [PATCH 12/30] Add sequencee erase operator --- paddle/operators/sequence_erase_op.cc | 61 ++++++++++++++ paddle/operators/sequence_erase_op.h | 80 +++++++++++++++++++ .../v2/fluid/tests/test_sequence_erase_op.py | 58 ++++++++++++++ 3 files changed, 199 insertions(+) create mode 100644 paddle/operators/sequence_erase_op.cc create mode 100644 paddle/operators/sequence_erase_op.h create mode 100644 python/paddle/v2/fluid/tests/test_sequence_erase_op.py diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc new file mode 100644 index 0000000000..e611ef0571 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_erase_op.h" + +namespace paddle { +namespace operators { + +class SequenceEraseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceEraseOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceEraseOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } +}; + +class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor) 2-D input LoDTensor with the 2-nd dimension " + "of length 1."); + AddOutput("Out", + "(LoDTensor) 2-D output LoDTensor with the 2-nd dimension " + "of length 1."); + AddAttr>("tokens", + "(vector) " + "Tokens to be removed from input."); + AddComment(R"DOC( +Sequence Erase Operator. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp, + ops::SequenceEraseOpMaker); +REGISTER_OP_CPU_KERNEL( + sequence_erase, + ops::SequenceEraseKernel); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h new file mode 100644 index 0000000000..937b9870aa --- /dev/null +++ b/paddle/operators/sequence_erase_op.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/softmax.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class SequenceEraseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + // auto dims = x->dims(); + /* + const size_t level = lod.size() - 1; + PADDLE_ENFORCE_EQ(dims[0], static_cast(lod[level].back()), + "The first dimension of Input(X) should be equal to the " + "sum of all sequences' lengths."); + PADDLE_ENFORCE_EQ(dims[0], x->numel(), + "The width of each timestep in Input(X) of " + "SequenceEraseOp should be 1."); + out->mutable_data(ctx.GetPlace()); + */ + auto tokens = ctx.Attr>("tokens"); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + std::vector num_erased(in_len + 1, 0); + for (int64_t i = 1; i < in_len + 1; ++i) { + num_erased[i] = num_erased[i - 1]; + if (std::find(tokens.begin(), tokens.end(), in_dat[i - 1]) != + tokens.end()) { + num_erased[i] += 1; + } + } + + std::vector out_lod0(lod0.size(), 0); + for (size_t i = 1; i < lod0.size(); ++i) { + out_lod0[i] = lod0[i] - num_erased[lod0[i]]; + } + + auto out_len = in_len - num_erased[in_len]; + out->Resize({static_cast(out_len), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + + for (size_t i = 0; i < in_len; ++i) { + if (num_erased[i] == num_erased[i + 1]) { + out_dat[i - num_erased[i]] = in_dat[i]; + } + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py new file mode 100644 index 0000000000..74274cf0ad --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -0,0 +1,58 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def sequence_erase(in_seq, lod0, tokens): + # num_erased[i]: the number of elments to be removed before #i elements + num_erased = [0] * (len(in_seq) + 1) + for i in range(1, len(in_seq) + 1): + num_erased[i] = num_erased[i - 1] + if in_seq[i - 1] in tokens: + num_erased[i] += 1 + + # recalculate lod information + new_lod0 = [0] * len(lod0) + for i in range(1, len(lod0)): + new_lod0[i] = lod0[i] - num_erased[lod0[i]] + + out_seq = np.zeros( + (len(in_seq) - num_erased[len(in_seq)], 1)).astype("int32") + for i in range(0, len(in_seq)): + if num_erased[i] == num_erased[i + 1]: + out_seq[i - num_erased[i]] = in_seq[i] + # else in_seq[i] needs to be removed + return out_seq, new_lod0 + + +class TestSequenceEraseOp(OpTest): + def setUp(self): + self.op_type = "sequence_erase" + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod = [[0, 5, 15, 30]] + tokens = [2, 5] + out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) + + self.attrs = {'tokens': tokens} + self.inputs = {'X': (in_seq, lod)} + self.outputs = {'Out': (out_seq, [new_lod0])} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + """ + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod0 = [0, 5, 15, 30] + tokens = [2, 5] + out_seq, new_lod = sequence_erase(in_seq, lod0, tokens) + + print lod0, new_lod + print("compare") + for i in range(0, len(lod0)-1): + print(np.transpose(in_seq[lod0[i] : lod0[i+1]])) + print(np.transpose(out_seq[new_lod[i] : new_lod[i+1]])) + print("\n") + """ + unittest.main() From 8e8e5a89f8f015b8a8c0e0a3a6e129c8276f92b1 Mon Sep 17 00:00:00 2001 From: gx_wind Date: Mon, 8 Jan 2018 14:06:17 +0800 Subject: [PATCH 13/30] fix coding standard --- adversarial/README.md | 8 ++++- adversarial/advbox/__init__.py | 1 - adversarial/advbox/attacks/base.py | 1 + adversarial/advbox/attacks/gradientsign.py | 9 ++++-- adversarial/advbox/models/base.py | 2 +- adversarial/advbox/models/paddle.py | 37 ++++++++++------------ 6 files changed, 31 insertions(+), 27 deletions(-) diff --git a/adversarial/README.md b/adversarial/README.md index 7c9502828f..51da21918a 100644 --- a/adversarial/README.md +++ b/adversarial/README.md @@ -1,3 +1,9 @@ # Advbox -Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle. \ No newline at end of file +Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle. + +## How to use + +1. train a model and save it's parameters. (like fluid_mnist.py) +2. load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py) +3. use advbox to generate the adversarial sample. diff --git a/adversarial/advbox/__init__.py b/adversarial/advbox/__init__.py index 4beb6be0a2..f56f14f18d 100644 --- a/adversarial/advbox/__init__.py +++ b/adversarial/advbox/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ A set of tools for generating adversarial example on paddle platform """ diff --git a/adversarial/advbox/attacks/base.py b/adversarial/advbox/attacks/base.py index 9cc2bfb854..dab1dbbeb0 100644 --- a/adversarial/advbox/attacks/base.py +++ b/adversarial/advbox/attacks/base.py @@ -7,6 +7,7 @@ import abc abstractmethod = abc.abstractmethod + class Attack(object): """ Abstract base class for adversarial attacks. `Attack` represent an adversarial attack diff --git a/adversarial/advbox/attacks/gradientsign.py b/adversarial/advbox/attacks/gradientsign.py index 6c188f6249..37fbdb1132 100644 --- a/adversarial/advbox/attacks/gradientsign.py +++ b/adversarial/advbox/attacks/gradientsign.py @@ -5,7 +5,8 @@ from __future__ import division import numpy as np from collections import Iterable from .base import Attack - + + class GradientSignAttack(Attack): """ This attack was originally implemented by Goodfellow et al. (2015) with the @@ -22,10 +23,11 @@ class GradientSignAttack(Attack): gradient_sign = np.sign(gradient) * (max_ - min_) if not isinstance(epsilons, Iterable): - epsilons = np.linspace(0, 1, num = epsilons + 1) + epsilons = np.linspace(0, 1, num=epsilons + 1) for epsilon in epsilons: - adv_img = image_batch[0][0].reshape(gradient_sign.shape) + epsilon * gradient_sign + adv_img = image_batch[0][0].reshape( + gradient_sign.shape) + epsilon * gradient_sign adv_img = np.clip(adv_img, min_, max_) adv_label = np.argmax(self.model.predict([(adv_img, 0)])) #print("pre_label="+str(pre_label)+ " adv_label="+str(adv_label)) @@ -33,4 +35,5 @@ class GradientSignAttack(Attack): #print(epsilon, pre_label, adv_label) return adv_img + FGSM = GradientSignAttack diff --git a/adversarial/advbox/models/base.py b/adversarial/advbox/models/base.py index 91b6fe4a3c..2e5c397dc4 100644 --- a/adversarial/advbox/models/base.py +++ b/adversarial/advbox/models/base.py @@ -6,8 +6,8 @@ import abc abstractmethod = abc.abstractmethod -class Model(object): +class Model(object): """ Base class of model to provide attack. diff --git a/adversarial/advbox/models/paddle.py b/adversarial/advbox/models/paddle.py index 831fa6a362..a72eb148bc 100644 --- a/adversarial/advbox/models/paddle.py +++ b/adversarial/advbox/models/paddle.py @@ -7,6 +7,7 @@ from paddle.v2.fluid.framework import program_guard from .base import Model + class PaddleModel(Model): """ Create a PaddleModel instance. @@ -30,9 +31,7 @@ class PaddleModel(Model): channel_axis=3, preprocess=None): super(PaddleModel, self).__init__( - bounds=bounds, - channel_axis=channel_axis, - preprocess=preprocess) + bounds=bounds, channel_axis=channel_axis, preprocess=preprocess) if preprocess is None: preprocess = (0, 1) @@ -48,7 +47,8 @@ class PaddleModel(Model): # gradient loss = self._program.block(0).var(self._cost_name) - param_grads = fluid.backward.append_backward(loss, parameter_list=[self._input_name]) + param_grads = fluid.backward.append_backward( + loss, parameter_list=[self._input_name]) self._gradient = param_grads[0][1] def predict(self, image_batch): @@ -61,16 +61,13 @@ class PaddleModel(Model): numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes). """ feeder = fluid.DataFeeder( - feed_list=[self._input_name, self._logits_name], - place=self._place, - program=self._program - ) + feed_list=[self._input_name, self._logits_name], + place=self._place, + program=self._program) predict_var = self._program.block(0).var(self._predict_name) - predict = self._exe.run( - self._program, - feed=feeder.feed(image_batch), - fetch_list=[predict_var] - ) + predict = self._exe.run(self._program, + feed=feeder.feed(image_batch), + fetch_list=[predict_var]) return predict def num_classes(self): @@ -95,12 +92,10 @@ class PaddleModel(Model): """ feeder = fluid.DataFeeder( feed_list=[self._input_name, self._logits_name], - place=self._place, - program=self._program - ) - - grad, = self._exe.run( - self._program, - feed=feeder.feed(image_batch), - fetch_list=[self._gradient]) + place=self._place, + program=self._program) + + grad, = self._exe.run(self._program, + feed=feeder.feed(image_batch), + fetch_list=[self._gradient]) return grad From 343b32a0d143f5a3bffa43e418dfecf274ffec58 Mon Sep 17 00:00:00 2001 From: gx_wind Date: Mon, 8 Jan 2018 14:51:54 +0800 Subject: [PATCH 14/30] fix coding standard --- adversarial/fluid_mnist.py | 14 ++++---- adversarial/mnist_tutorial_fgsm.py | 53 +++++++++++++----------------- 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/adversarial/fluid_mnist.py b/adversarial/fluid_mnist.py index d46defda55..031928e994 100644 --- a/adversarial/fluid_mnist.py +++ b/adversarial/fluid_mnist.py @@ -4,6 +4,7 @@ CNN on mnist data using fluid api of paddlepaddle import paddle.v2 as paddle import paddle.v2.fluid as fluid + def mnist_cnn_model(img): """ Mnist cnn model @@ -31,10 +32,7 @@ def mnist_cnn_model(img): pool_stride=2, act='relu') - logits = fluid.layers.fc( - input=conv_pool_2, - size=10, - act='softmax') + logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') return logits @@ -73,17 +71,19 @@ def main(): feed=feeder.feed(data), fetch_list=[avg_cost] + accuracy.metrics) pass_acc = accuracy.eval(exe) - print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + - str(pass_acc)) + print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + + str(pass_acc)) # print loss, acc if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD: # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. break + # exit(0) pass_acc = accuracy.eval(exe) print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc)) - fluid.io.save_params(exe, dirname='./mnist', main_program=fluid.default_main_program()) + fluid.io.save_params( + exe, dirname='./mnist', main_program=fluid.default_main_program()) print('train mnist done') exit(1) diff --git a/adversarial/mnist_tutorial_fgsm.py b/adversarial/mnist_tutorial_fgsm.py index 665062afd0..8b29346b8c 100644 --- a/adversarial/mnist_tutorial_fgsm.py +++ b/adversarial/mnist_tutorial_fgsm.py @@ -9,6 +9,7 @@ import numpy as np from advbox.models.paddle import PaddleModel from advbox.attacks.gradientsign import GradientSignAttack + def cnn_model(img): """ Mnist cnn model @@ -19,25 +20,22 @@ def cnn_model(img): """ #conv1 = fluid.nets.conv2d() conv_pool_1 = fluid.nets.simple_img_conv_pool( - input=img, - num_filters=20, - filter_size=5, - pool_size=2, - pool_stride=2, - act='relu') + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') conv_pool_2 = fluid.nets.simple_img_conv_pool( - input=conv_pool_1, - num_filters=50, - filter_size=5, - pool_size=2, - pool_stride=2, - act='relu') + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') - logits = fluid.layers.fc( - input=conv_pool_2, - size=10, - act='softmax') + logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') return logits @@ -65,22 +63,16 @@ def main(): paddle.dataset.mnist.train(), buf_size=500), batch_size=BATCH_SIZE) feeder = fluid.DataFeeder( - feed_list=[IMG_NAME, LABEL_NAME], - place=place, - program=fluid.default_main_program() - ) + feed_list=[IMG_NAME, LABEL_NAME], + place=place, + program=fluid.default_main_program()) - fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program()) + fluid.io.load_params( + exe, "./mnist/", main_program=fluid.default_main_program()) # advbox demo - m = PaddleModel( - fluid.default_main_program(), - IMG_NAME, - LABEL_NAME, - logits.name, - avg_cost.name, - (-1, 1) - ) + m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME, + logits.name, avg_cost.name, (-1, 1)) att = GradientSignAttack(m) for data in train_reader(): # fgsm attack @@ -89,6 +81,7 @@ def main(): plt.show() #np.save('adv_img', adv_img) break - + + if __name__ == '__main__': main() From 37f933b8ad85fb17fa903e59074ab6225ef4eec3 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 8 Jan 2018 15:25:45 +0000 Subject: [PATCH 15/30] Add gpu kernel for sequence_erase_op --- paddle/operators/sequence_erase_op.cu | 136 ++++++++++++++++++ paddle/operators/sequence_erase_op.h | 4 +- .../v2/fluid/tests/test_sequence_erase_op.py | 1 - 3 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 paddle/operators/sequence_erase_op.cu diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu new file mode 100644 index 0000000000..5d314586d4 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cu @@ -0,0 +1,136 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/operators/sequence_erase_op.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +__global__ void LabelErasedIdx(const T* in_dat, const int in_len, + const T* tokens, const int tokens_len, + int* num_erased) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + int erased = 0; + for (int i = 0; i < tokens_len; ++i) { + if (in_dat[index] == tokens[i]) { + erased = 1; + } + } + num_erased[index + 1] = erased; + if (index == 0) { + num_erased[0] = 0; + } + } +} + +template +__global__ void GetOutLod(const T* num_erased, const int* in_lod, + const int lod_len, int* out_lod0) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < lod_len) { + out_lod0[index] = in_lod[index] - num_erased[in_lod[index]]; + } +} + +template +__global__ void SetOutput(const T* in_dat, const int in_len, + const int* num_erased, T* out_dat) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + if (in_dat[index] != in_dat[index + 1]) { + out_dat[index - num_erased[index]] = in_dat[index]; + } + } +} + +template +class SequenceEraseOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + auto tokens = ctx.Attr>("tokens"); + auto tokens_len = tokens.size(); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + + thrust::host_vector host_tokens(tokens_len); + for (size_t i = 0; i < tokens.size(); ++i) { + host_tokens[i] = tokens[i]; + } + thrust::device_vector dev_tokens = host_tokens; + thrust::device_vector num_erased(in_len + 1); + + T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); + int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); + + auto stream = ctx.cuda_device_context().stream(); + LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr); + thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(), + num_erased.begin() + 1); + + // Reset LoD + auto lod_len = lod0.size(); + thrust::host_vector host_lod(lod_len); + for (size_t i = 0; i < lod_len; ++i) { + host_lod[i] = lod0[i]; + } + thrust::device_vector dev_in_lod = host_lod; + thrust::device_vector dev_out_lod(lod_len); + int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); + int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data()); + GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); + thrust::host_vector host_out_lod = dev_out_lod; + std::vector out_lod0(lod_len, 0); + for (size_t i = 0; i < lod_len; i++) { + out_lod0[i] = host_out_lod[i]; + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + + out->Resize({out_lod0.back(), 1}); + // Set output + auto out_dat = out->mutable_data(ctx.GetPlace()); + SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len, + num_erased_ptr, out_dat); + // Set LoD + out->set_lod(out_lod); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(sequence_erase, + paddle::operators::SequenceEraseOpCUDAKernel); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h index 937b9870aa..caf168a93d 100644 --- a/paddle/operators/sequence_erase_op.h +++ b/paddle/operators/sequence_erase_op.h @@ -27,8 +27,8 @@ template class SequenceEraseKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* out = ctx.Output("Out"); + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); auto lod = in->lod(); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py index 74274cf0ad..e730f2f4b7 100644 --- a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -32,7 +32,6 @@ class TestSequenceEraseOp(OpTest): lod = [[0, 5, 15, 30]] tokens = [2, 5] out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) - self.attrs = {'tokens': tokens} self.inputs = {'X': (in_seq, lod)} self.outputs = {'Out': (out_seq, [new_lod0])} From 7b9d5b325c7c513815085e9ab1f59a42600aecfc Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 8 Jan 2018 16:56:31 +0000 Subject: [PATCH 16/30] Add document for sequence_erase_op --- paddle/operators/sequence_erase_op.cc | 42 +++++++++++++++---- paddle/operators/sequence_erase_op.cu | 11 ++--- paddle/operators/sequence_erase_op.h | 17 +------- .../v2/fluid/tests/test_sequence_erase_op.py | 6 +-- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc index e611ef0571..331970b3f8 100644 --- a/paddle/operators/sequence_erase_op.cc +++ b/paddle/operators/sequence_erase_op.cc @@ -26,7 +26,11 @@ class SequenceEraseOp : public framework::OperatorWithKernel { "Input(X) of SequenceEraseOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SequenceEraseOp should not be null."); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(x_dims.size() == 2 && x_dims[1] == 1, + "Input(X) of SequenceEraseOp should be a 2-D LoDTensor " + "with the 2nd dimension equal to 1."); + ctx->SetOutputDim("Out", x_dims); } }; @@ -35,17 +39,41 @@ class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "(LoDTensor) 2-D input LoDTensor with the 2-nd dimension " - "of length 1."); + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Input LoDTensor of SequenceEraseOp."); AddOutput("Out", - "(LoDTensor) 2-D output LoDTensor with the 2-nd dimension " - "of length 1."); + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Output LoDTensor of SequenceEraseOp."); AddAttr>("tokens", - "(vector) " - "Tokens to be removed from input."); + "(vector) Tokens need to be erased from " + "input sequences."); AddComment(R"DOC( Sequence Erase Operator. +Sequence erase operator erases tokens specified by Attr(tokens) in the input +sequences Input(X), and outputs the remaining data and modifies the LoD +information at the same time. For example, given a 2-D LoDTensor + + X = [[2, 2, 6, 1, 3, 9, 6, 1, 0, 1]]^T + +with lod = [[0, 3, 6, 10]], there are three sequences in the input: + + X1 = [[2, 2, 6]]^T, X2 = [[1, 3, 9]]^T and X3 = [[6, 1, 0, 1]]^T. + +If the tokens to be erased are Attr(tokens) = [2, 3, 5], after the erasing +operation, the three sequences become + + X1' = [[6]]^T, X2' = [[1, 9]]^T and X3' = [[6, 1, 0, 1]]^T. + +Hence the LoDTensor Output(Out) should be + + Out = [[6, 1, 9, 6, 1, 0, 1]]^T, + +with lod = [[0, 1, 3, 7]]. + +An example usage for this operator is to remove the special tokens when +computing the edit distance between two strings, such as blank, start token, +and end token. )DOC"); } }; diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu index 5d314586d4..3695a24cb7 100644 --- a/paddle/operators/sequence_erase_op.cu +++ b/paddle/operators/sequence_erase_op.cu @@ -13,17 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include #include -#include #include "paddle/operators/sequence_erase_op.h" #include "paddle/platform/cuda_helper.h" -#include "paddle/platform/gpu_info.h" namespace paddle { namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; -using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template @@ -97,7 +93,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(), num_erased.begin() + 1); - // Reset LoD + // Calc LoD auto lod_len = lod0.size(); thrust::host_vector host_lod(lod_len); for (size_t i = 0; i < lod_len; ++i) { @@ -117,15 +113,14 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { } framework::LoD out_lod; out_lod.push_back(out_lod0); + out->set_lod(out_lod); - out->Resize({out_lod0.back(), 1}); // Set output + out->Resize({out_lod0.back(), 1}); auto out_dat = out->mutable_data(ctx.GetPlace()); SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len, num_erased_ptr, out_dat); - // Set LoD - out->set_lod(out_lod); } }; diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h index caf168a93d..92aa4a82b0 100644 --- a/paddle/operators/sequence_erase_op.h +++ b/paddle/operators/sequence_erase_op.h @@ -15,14 +15,10 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" -#include "paddle/operators/math/softmax.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - template class SequenceEraseKernel : public framework::OpKernel { public: @@ -32,17 +28,6 @@ class SequenceEraseKernel : public framework::OpKernel { auto lod = in->lod(); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); - // auto dims = x->dims(); - /* - const size_t level = lod.size() - 1; - PADDLE_ENFORCE_EQ(dims[0], static_cast(lod[level].back()), - "The first dimension of Input(X) should be equal to the " - "sum of all sequences' lengths."); - PADDLE_ENFORCE_EQ(dims[0], x->numel(), - "The width of each timestep in Input(X) of " - "SequenceEraseOp should be 1."); - out->mutable_data(ctx.GetPlace()); - */ auto tokens = ctx.Attr>("tokens"); auto in_len = in->numel(); auto in_dat = in->data(); @@ -65,7 +50,7 @@ class SequenceEraseKernel : public framework::OpKernel { out->Resize({static_cast(out_len), 1}); auto out_dat = out->mutable_data(ctx.GetPlace()); - for (size_t i = 0; i < in_len; ++i) { + for (int64_t i = 0; i < in_len; ++i) { if (num_erased[i] == num_erased[i + 1]) { out_dat[i - num_erased[i]] = in_dat[i]; } diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py index e730f2f4b7..78105334f5 100644 --- a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -28,9 +28,9 @@ def sequence_erase(in_seq, lod0, tokens): class TestSequenceEraseOp(OpTest): def setUp(self): self.op_type = "sequence_erase" - in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") - lod = [[0, 5, 15, 30]] - tokens = [2, 5] + in_seq = np.random.randint(0, 10, (10, 1)).astype("int32") + lod = [[0, 3, 6, 10]] + tokens = [2, 3, 5] out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) self.attrs = {'tokens': tokens} self.inputs = {'X': (in_seq, lod)} From bf1e03721bb84b7390231fe6c16e646edd7a5a76 Mon Sep 17 00:00:00 2001 From: gx_wind Date: Tue, 9 Jan 2018 12:58:44 +0800 Subject: [PATCH 17/30] delete comment --- adversarial/advbox/attacks/base.py | 6 +----- adversarial/advbox/attacks/gradientsign.py | 2 -- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/adversarial/advbox/attacks/base.py b/adversarial/advbox/attacks/base.py index dab1dbbeb0..2058e3cfeb 100644 --- a/adversarial/advbox/attacks/base.py +++ b/adversarial/advbox/attacks/base.py @@ -1,11 +1,7 @@ """ The base model of the model. """ -from abc import ABCMeta -#from advbox.base import Model -import abc - -abstractmethod = abc.abstractmethod +from abc import ABCMeta, abstractmethod class Attack(object): diff --git a/adversarial/advbox/attacks/gradientsign.py b/adversarial/advbox/attacks/gradientsign.py index 37fbdb1132..dff518811e 100644 --- a/adversarial/advbox/attacks/gradientsign.py +++ b/adversarial/advbox/attacks/gradientsign.py @@ -30,9 +30,7 @@ class GradientSignAttack(Attack): gradient_sign.shape) + epsilon * gradient_sign adv_img = np.clip(adv_img, min_, max_) adv_label = np.argmax(self.model.predict([(adv_img, 0)])) - #print("pre_label="+str(pre_label)+ " adv_label="+str(adv_label)) if pre_label != adv_label: - #print(epsilon, pre_label, adv_label) return adv_img From 97724c2a1462625ab7dad7d2f1cbf8f7fc2a325f Mon Sep 17 00:00:00 2001 From: gx_wind Date: Tue, 9 Jan 2018 16:35:52 +0800 Subject: [PATCH 18/30] fix bugs and modify func param name --- adversarial/advbox/attacks/base.py | 10 +++++----- adversarial/advbox/attacks/gradientsign.py | 9 +++++---- adversarial/advbox/models/base.py | 3 +-- adversarial/advbox/models/paddle.py | 2 +- adversarial/fluid_mnist.py | 7 +------ 5 files changed, 13 insertions(+), 18 deletions(-) diff --git a/adversarial/advbox/attacks/base.py b/adversarial/advbox/attacks/base.py index 2058e3cfeb..98a65f2fdd 100644 --- a/adversarial/advbox/attacks/base.py +++ b/adversarial/advbox/attacks/base.py @@ -18,22 +18,22 @@ class Attack(object): def __init__(self, model): self.model = model - def __call__(self, image_batch): + def __call__(self, image_label): """ Generate the adversarial sample. Args: - image_batch(list): The image and label tuple list. + image_label(list): The image and label tuple list with one element. """ - adv_img = self._apply(image_batch) + adv_img = self._apply(image_label) return adv_img @abstractmethod - def _apply(self, image_batch): + def _apply(self, image_label): """ Search an adversarial example. Args: - image_batch(list): The image and label tuple list. + image_batch(list): The image and label tuple list with one element. """ raise NotImplementedError diff --git a/adversarial/advbox/attacks/gradientsign.py b/adversarial/advbox/attacks/gradientsign.py index dff518811e..15b1d176cb 100644 --- a/adversarial/advbox/attacks/gradientsign.py +++ b/adversarial/advbox/attacks/gradientsign.py @@ -15,18 +15,19 @@ class GradientSignAttack(Attack): Paper link: https://arxiv.org/abs/1412.6572 """ - def _apply(self, image_batch, epsilons=1000): - pre_label = np.argmax(self.model.predict(image_batch)) + def _apply(self, image_label, epsilons=1000): + assert len(image_label) == 1 + pre_label = np.argmax(self.model.predict(image_label)) min_, max_ = self.model.bounds() - gradient = self.model.gradient(image_batch) + gradient = self.model.gradient(image_label) gradient_sign = np.sign(gradient) * (max_ - min_) if not isinstance(epsilons, Iterable): epsilons = np.linspace(0, 1, num=epsilons + 1) for epsilon in epsilons: - adv_img = image_batch[0][0].reshape( + adv_img = image_label[0][0].reshape( gradient_sign.shape) + epsilon * gradient_sign adv_img = np.clip(adv_img, min_, max_) adv_label = np.argmax(self.model.predict([(adv_img, 0)])) diff --git a/adversarial/advbox/models/base.py b/adversarial/advbox/models/base.py index 2e5c397dc4..74e1045def 100644 --- a/adversarial/advbox/models/base.py +++ b/adversarial/advbox/models/base.py @@ -81,8 +81,7 @@ class Model(object): Calculate the gradient of the cross-entropy loss w.r.t the image. Args: - image(numpy.ndarray): image with shape (height, width, channel) - label(int): image label used to cal gradient. + image_batch(list): The image and label tuple list. Return: numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with diff --git a/adversarial/advbox/models/paddle.py b/adversarial/advbox/models/paddle.py index a72eb148bc..33b2a3d5c6 100644 --- a/adversarial/advbox/models/paddle.py +++ b/adversarial/advbox/models/paddle.py @@ -49,7 +49,7 @@ class PaddleModel(Model): loss = self._program.block(0).var(self._cost_name) param_grads = fluid.backward.append_backward( loss, parameter_list=[self._input_name]) - self._gradient = param_grads[0][1] + self._gradient = dict(param_grads)[self._input_name] def predict(self, image_batch): """ diff --git a/adversarial/fluid_mnist.py b/adversarial/fluid_mnist.py index 031928e994..db4d4b5186 100644 --- a/adversarial/fluid_mnist.py +++ b/adversarial/fluid_mnist.py @@ -15,7 +15,6 @@ def mnist_cnn_model(img): Returns: Variable: the label prediction """ - #conv1 = fluid.nets.conv2d() conv_pool_1 = fluid.nets.simple_img_conv_pool( input=img, num_filters=20, @@ -73,19 +72,15 @@ def main(): pass_acc = accuracy.eval(exe) print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + str(pass_acc)) - # print loss, acc if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD: - # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. break -# exit(0) - pass_acc = accuracy.eval(exe) print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc)) fluid.io.save_params( exe, dirname='./mnist', main_program=fluid.default_main_program()) print('train mnist done') - exit(1) + if __name__ == '__main__': main() From 10779460c5c127c257203d95d5f4740db4d55cad Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 10 Jan 2018 08:03:08 +0000 Subject: [PATCH 19/30] Simplify calc in test_sequence_erase_op --- paddle/operators/sequence_erase_op.cc | 2 +- paddle/operators/sequence_erase_op.cu | 2 + paddle/operators/sequence_erase_op.h | 25 ++++++---- .../v2/fluid/tests/test_sequence_erase_op.py | 46 +++++-------------- 4 files changed, 30 insertions(+), 45 deletions(-) diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc index 331970b3f8..d17b268623 100644 --- a/paddle/operators/sequence_erase_op.cc +++ b/paddle/operators/sequence_erase_op.cc @@ -50,7 +50,7 @@ class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Sequence Erase Operator. -Sequence erase operator erases tokens specified by Attr(tokens) in the input +Sequence erase operator erases tokens specified by Attr(tokens) from the input sequences Input(X), and outputs the remaining data and modifies the LoD information at the same time. For example, given a 2-D LoDTensor diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu index 3695a24cb7..5da8eba3e1 100644 --- a/paddle/operators/sequence_erase_op.cu +++ b/paddle/operators/sequence_erase_op.cu @@ -70,6 +70,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { auto lod = in->lod(); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); auto tokens = ctx.Attr>("tokens"); auto tokens_len = tokens.size(); auto in_len = in->numel(); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h index 92aa4a82b0..cb2d7be009 100644 --- a/paddle/operators/sequence_erase_op.h +++ b/paddle/operators/sequence_erase_op.h @@ -28,22 +28,27 @@ class SequenceEraseKernel : public framework::OpKernel { auto lod = in->lod(); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); auto tokens = ctx.Attr>("tokens"); auto in_len = in->numel(); auto in_dat = in->data(); auto lod0 = lod[0]; + std::vector num_erased(in_len + 1, 0); - for (int64_t i = 1; i < in_len + 1; ++i) { - num_erased[i] = num_erased[i - 1]; - if (std::find(tokens.begin(), tokens.end(), in_dat[i - 1]) != - tokens.end()) { - num_erased[i] += 1; + std::vector out_lod0(1, 0); + for (size_t i = 0; i < lod0.size() - 1; ++i) { + size_t num_out = 0; + for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) { + num_erased[j] = num_erased[j - 1]; + if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) != + tokens.end()) { + num_erased[j] += 1; + } else { + num_out += 1; + } } - } - - std::vector out_lod0(lod0.size(), 0); - for (size_t i = 1; i < lod0.size(); ++i) { - out_lod0[i] = lod0[i] - num_erased[lod0[i]]; + out_lod0.push_back(out_lod0.back() + num_out); } auto out_len = in_len - num_erased[in_len]; diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py index 78105334f5..bf257fefea 100644 --- a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -4,32 +4,23 @@ from op_test import OpTest def sequence_erase(in_seq, lod0, tokens): - # num_erased[i]: the number of elments to be removed before #i elements - num_erased = [0] * (len(in_seq) + 1) - for i in range(1, len(in_seq) + 1): - num_erased[i] = num_erased[i - 1] - if in_seq[i - 1] in tokens: - num_erased[i] += 1 - - # recalculate lod information - new_lod0 = [0] * len(lod0) - for i in range(1, len(lod0)): - new_lod0[i] = lod0[i] - num_erased[lod0[i]] - - out_seq = np.zeros( - (len(in_seq) - num_erased[len(in_seq)], 1)).astype("int32") - for i in range(0, len(in_seq)): - if num_erased[i] == num_erased[i + 1]: - out_seq[i - num_erased[i]] = in_seq[i] - # else in_seq[i] needs to be removed - return out_seq, new_lod0 + new_lod0 = [0] + out_seq = [] + for i in range(0, len(lod0) - 1): + num_out = 0 + for dat in in_seq[lod0[i]:lod0[i + 1]]: + if dat not in tokens: + out_seq.append(dat) + num_out += 1 + new_lod0.append(new_lod0[-1] + num_out) + return np.array(out_seq).astype("int32"), new_lod0 class TestSequenceEraseOp(OpTest): def setUp(self): self.op_type = "sequence_erase" - in_seq = np.random.randint(0, 10, (10, 1)).astype("int32") - lod = [[0, 3, 6, 10]] + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod = [[0, 9, 13, 24, 30]] tokens = [2, 3, 5] out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) self.attrs = {'tokens': tokens} @@ -41,17 +32,4 @@ class TestSequenceEraseOp(OpTest): if __name__ == '__main__': - """ - in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") - lod0 = [0, 5, 15, 30] - tokens = [2, 5] - out_seq, new_lod = sequence_erase(in_seq, lod0, tokens) - - print lod0, new_lod - print("compare") - for i in range(0, len(lod0)-1): - print(np.transpose(in_seq[lod0[i] : lod0[i+1]])) - print(np.transpose(out_seq[new_lod[i] : new_lod[i+1]])) - print("\n") - """ unittest.main() From 929d22c62213e3fa1e05109e2f88788f7b1f7501 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 10 Jan 2018 16:39:07 +0800 Subject: [PATCH 20/30] auto set openblas env --- paddle/scripts/submit_local.sh.in | 3 +++ python/paddle/v2/__init__.py | 35 +++++++++++++++++++------------ python/setup.py.in | 7 ++++++- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/paddle/scripts/submit_local.sh.in b/paddle/scripts/submit_local.sh.in index 8a352b0078..bb47ad614e 100755 --- a/paddle/scripts/submit_local.sh.in +++ b/paddle/scripts/submit_local.sh.in @@ -92,6 +92,9 @@ function threads_config() { if [ -z "$OPENBLAS_NUM_THREADS" ]; then export OPENBLAS_NUM_THREADS=$threads fi + if [ $threads -gt 1 ] && [ -z "$OPENBLAS_MAIN_FREE" ]; then + export OPENBLAS_MAIN_FREE=1 + fi fi } diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 0de417df2c..df710c33d0 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -62,12 +62,15 @@ __all__ = [ cp.begin_parse() -def set_omp_mkl_env_vars(trainer_count): +def set_env_vars(trainer_count): '''Auto set CPU environment if have not set before. - export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status. - export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count. + For MKL: + export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status. + export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count. + For OpenBLAS: + export OPENBLAS_NUM_THREADS, OPENBLAS_MAIN_FREE according to trainer_count. ''' - import platform + import platform, paddle if not platform.system() in ['Linux', 'Darwin']: return @@ -103,16 +106,22 @@ def set_omp_mkl_env_vars(trainer_count): num_cores = num_physical_cores() num_processors = num_logical_processors() - if num_processors > num_cores: # Hyper Threading is enabled - set_env("OMP_DYNAMIC", "true") - set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") - else: - set_env("OMP_DYNAMIC", "false") - set_env("KMP_AFFINITY", "granularity=fine,compact,0,0") + if paddle.version.mkl() == 'ON': + if num_processors > num_cores: # Hyper Threading is enabled + set_env("OMP_DYNAMIC", "true") + set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") + else: + set_env("OMP_DYNAMIC", "false") + set_env("KMP_AFFINITY", "granularity=fine,compact,0,0") threads = num_processors / trainer_count threads = '1' if threads < 1 else str(threads) - set_env("OMP_NUM_THREADS", threads) - set_env("MKL_NUM_THREADS", threads) + if paddle.version.mkl() == 'ON': + set_env("OMP_NUM_THREADS", threads) + set_env("MKL_NUM_THREADS", threads) + else: + set_env("OPENBLAS_NUM_THREADS", threads) + if threads > 1: + set_env("OPENBLAS_MAIN_FREE", '1') def init(**kwargs): @@ -129,7 +138,7 @@ def init(**kwargs): for key in args_dict.keys(): args.append('--%s=%s' % (key, str(args_dict[key]))) - set_omp_mkl_env_vars(kwargs.get('trainer_count', 1)) + set_env_vars(kwargs.get('trainer_count', 1)) if 'use_gpu' in kwargs: cp.g_command_config_args['use_gpu'] = kwargs['use_gpu'] diff --git a/python/setup.py.in b/python/setup.py.in index 66ccfe8087..65ec58ecf9 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -31,6 +31,7 @@ patch = '%(patch)d' rc = '%(rc)d' istaged = %(istaged)s commit = '%(commit)s' +with_mkl = '%(with_mkl)s' def show(): if istaged: @@ -41,6 +42,9 @@ def show(): print 'rc:', rc else: print 'commit:', commit + +def mkl(): + return with_mkl ''' commit = git_commit() with open(filename, 'w') as f: @@ -51,7 +55,8 @@ def show(): 'rc': RC, 'version': '${PADDLE_VERSION}', 'commit': commit, - 'istaged': ISTAGED}) + 'istaged': ISTAGED, + 'with_mkl': '@WITH_MKL@'}) write_version_py(filename='@PADDLE_SOURCE_DIR@/python/paddle/version.py') From f594ca436939b1ef0133727eadf0d5470ff74f67 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 10 Jan 2018 09:17:03 +0000 Subject: [PATCH 21/30] Reuse the usable variable in edit_distance_op --- paddle/operators/edit_distance_op.cc | 4 ++-- paddle/operators/edit_distance_op.cu | 8 ++++---- paddle/operators/edit_distance_op.h | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc index 7b92148f0e..441ae2aa00 100644 --- a/paddle/operators/edit_distance_op.cc +++ b/paddle/operators/edit_distance_op.cc @@ -49,10 +49,10 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Hyps", - "(2-D LoDTensor, 2nd dim. equal to 1) " + "(2-D LoDTensor, 2nd dim. equal to 1) " "The indices for hypothesis strings."); AddInput("Refs", - "(2-D LoDTensor, 2nd dim. equal to 1) " + "(2-D LoDTensor, 2nd dim. equal to 1) " "The indices for reference strings."); AddAttr("normalized", "(bool, default false) Indicated whether to normalize " diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu index b548345986..cf5ebc5c38 100644 --- a/paddle/operators/edit_distance_op.cu +++ b/paddle/operators/edit_distance_op.cu @@ -93,21 +93,21 @@ class EditDistanceGPUKernel : public framework::OpKernel { out_t->mutable_data(ctx.GetPlace()); auto out = out_t->data(); - std::vector distance(num_strs, 0.0); + T distance = 0.0; for (size_t num = 0; num < num_strs; num++) { auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); if (m == 0 || n == 0) { - distance[num] = std::max(m, n); + distance = std::max(m, n); if (normalized) { PADDLE_ENFORCE(n > 0, "The reference string (#%d) cannot be empty " "when Attr(normalized) is enabled.", n); - distance[num] = distance[num] / n; + distance = distance / n; } memory::Copy(boost::get(ctx.GetPlace()), out + num, - platform::CPUPlace(), &distance[num], sizeof(T), stream); + platform::CPUPlace(), &distance, sizeof(T), stream); } else { framework::Tensor dist_t; dist_t.Resize({m + 1, n + 1}); diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h index 6284f230e5..537e70281a 100644 --- a/paddle/operators/edit_distance_op.h +++ b/paddle/operators/edit_distance_op.h @@ -46,15 +46,15 @@ class EditDistanceKernel : public framework::OpKernel { out_t->mutable_data(ctx.GetPlace()); auto out = out_t->data(); - std::vector distance(num_strs, 0.0); + T distance = 0.0; for (size_t num = 0; num < num_strs; ++num) { auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); if (m == 0) { - distance[num] = n; + distance = n; } else if (n == 0) { - distance[num] = m; + distance = m; } else { framework::Tensor dist_t; dist_t.Resize({m + 1, n + 1}); @@ -77,7 +77,7 @@ class EditDistanceKernel : public framework::OpKernel { dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); } } - distance[num] = dist[m * (n + 1) + n]; + distance = dist[m * (n + 1) + n]; } if (normalized) { @@ -85,9 +85,9 @@ class EditDistanceKernel : public framework::OpKernel { "The reference string (#%d) cannot be empty " "when Attr(normalized) is enabled.", n); - distance[num] = distance[num] / n; + distance = distance / n; } - out[num] = distance[num]; + out[num] = distance; } } }; From a1935b23c48e3b7e46f70d568442b6bf5340b999 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 10 Jan 2018 09:26:53 +0000 Subject: [PATCH 22/30] Remove unnecessary prefix in test name of edit_distance_op --- python/paddle/v2/fluid/tests/test_edit_distance_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py index 24f2f0c5c2..38e87728b3 100644 --- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -34,7 +34,7 @@ def Levenshtein(hyp, ref): return dist[m][n] -class TestCTCEditDistanceOp(OpTest): +class TestEditDistanceOp(OpTest): def setUp(self): self.op_type = "edit_distance" normalized = False @@ -62,7 +62,7 @@ class TestCTCEditDistanceOp(OpTest): self.check_output() -class TestCTCEditDistanceOpNormalized(OpTest): +class TestEditDistanceOpNormalized(OpTest): def setUp(self): self.op_type = "edit_distance" normalized = True From fe0ef91a3f3db8a806a462a030392a57b208d4ad Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 10 Jan 2018 11:26:50 +0000 Subject: [PATCH 23/30] fix ci error in edit_distance_op --- paddle/operators/edit_distance_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc index 441ae2aa00..e383f07fa9 100644 --- a/paddle/operators/edit_distance_op.cc +++ b/paddle/operators/edit_distance_op.cc @@ -37,7 +37,7 @@ class EditDistanceOp : public framework::OperatorWithKernel { } protected: - framework::OpKernelType GetActualKernelType( + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType(framework::proto::DataType::FP32, ctx.device_context()); From a7e847b648fb0084990ee914e1365b017436ce4e Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 10 Jan 2018 16:31:05 +0800 Subject: [PATCH 24/30] fix ds2 issue --- paddle/gserver/layers/MKLDNNLayer.cpp | 2 ++ paddle/gserver/layers/MKLDNNLayer.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 6fbf3c7fde..2d0fff608c 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -132,6 +132,8 @@ void MKLDNNLayer::reshapeInput(int& batchsize, if (w != 0) { width = w; } + height = height != 0 ? height : 1; + width = width != 0 ? width : 1; } void MKLDNNLayer::reshapeOutput(size_t height, size_t width) { diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index e48b9b5a91..3ba39f18b6 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -98,6 +98,8 @@ protected: public: explicit MKLDNNLayer(const LayerConfig& config) : Layer(config), + ih_(0), + iw_(0), condition_(0), needResetBwd_(true), outputOnlyMKLDNN_(false), From da3087ada1de62caea7ea2b2f819eb24ea5a6088 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 11 Jan 2018 09:28:24 +0800 Subject: [PATCH 25/30] Async GRPC sendrecv (#7133) Async GRPC sendrecv --- paddle/operators/detail/CMakeLists.txt | 2 +- paddle/operators/detail/grpc_client.cc | 147 ++++++++++++ paddle/operators/detail/grpc_client.h | 147 ++++++++++++ paddle/operators/detail/grpc_server.cc | 237 ++++++++++++++++++++ paddle/operators/detail/grpc_server.h | 91 ++++++++ paddle/operators/detail/recv_impl.cc | 65 ------ paddle/operators/detail/send_impl.cc | 67 ------ paddle/operators/detail/send_recv.proto | 2 - paddle/operators/detail/send_recv_impl.h | 141 ------------ paddle/operators/detail/sendrecvop_utils.cc | 68 ++++++ paddle/operators/detail/sendrecvop_utils.h | 42 ++++ paddle/operators/recv_op.cc | 43 ++-- paddle/operators/send_op.cc | 56 ++--- paddle/operators/send_recv_op_test.cc | 2 +- 14 files changed, 775 insertions(+), 335 deletions(-) create mode 100644 paddle/operators/detail/grpc_client.cc create mode 100644 paddle/operators/detail/grpc_client.h create mode 100644 paddle/operators/detail/grpc_server.cc create mode 100644 paddle/operators/detail/grpc_server.h delete mode 100644 paddle/operators/detail/recv_impl.cc delete mode 100644 paddle/operators/detail/send_impl.cc delete mode 100644 paddle/operators/detail/send_recv_impl.h create mode 100644 paddle/operators/detail/sendrecvop_utils.cc create mode 100644 paddle/operators/detail/sendrecvop_utils.h diff --git a/paddle/operators/detail/CMakeLists.txt b/paddle/operators/detail/CMakeLists.txt index f6bdc63cc2..571a75c9dc 100644 --- a/paddle/operators/detail/CMakeLists.txt +++ b/paddle/operators/detail/CMakeLists.txt @@ -1 +1 @@ -grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) +grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc new file mode 100644 index 0000000000..5a4db2d7e6 --- /dev/null +++ b/paddle/operators/detail/grpc_client.cc @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "grpc_client.h" +namespace paddle { +namespace operators { +namespace detail { + +bool RPCClient::AsyncSendVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { + sendrecv::VariableMessage req; + auto* var = scope.FindVar(var_name); + SerializeToMessage(var_name, var, ctx, &req); + + // varhandle + VarHandle var_h; + var_h.ep = ep; + var_h.scope = &scope; + var_h.name = var_name; + var_h.ctx = &ctx; + + // stub context + auto ch = GetChannel(ep); + SendProcessor* s = new SendProcessor(ch); + s->Prepare(var_h, time_out); + s->response_call_back_ = NULL; + + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + + req_count_++; + + return true; +} + +void ProcGetResponse(const VarHandle& var_h, + const sendrecv::VariableMessage& ret_msg) { + auto* outvar = var_h.scope->FindVar(var_h.name); + + std::istringstream iss(ret_msg.serialized()); + DeserializeFromMessage(ret_msg, *var_h.ctx, outvar); +} + +bool RPCClient::AsyncGetVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { + sendrecv::VariableMessage req; + req.set_varname(var_name); + + auto* var = scope.FindVar(var_name); + SerializeToMessage(var_name, var, ctx, &req); + + // varhandle + VarHandle var_h; + var_h.ep = ep; + var_h.scope = &scope; + var_h.name = var_name; + var_h.ctx = &ctx; + + // stub context + auto ch = GetChannel(ep); + GetProcessor* s = new GetProcessor(ch); + s->Prepare(var_h, time_out); + s->response_call_back_ = ProcGetResponse; + + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + + req_count_++; + + return true; +} + +bool RPCClient::wait() { + bool ok = true; + + while (true) { + if (req_count_ <= 0) { + break; + } + + if (!Proceed()) { + LOG(ERROR) << "Get meets CompletionQueue error"; + return false; + } + } + + return ok; +} + +bool RPCClient::Proceed() { + void* tag = NULL; + bool ok = false; + + // request counts. + if (!cq_.Next(&tag, &ok)) { + return false; + } + req_count_--; + + GPR_ASSERT(ok); + PADDLE_ENFORCE(tag); + + // TODO(gongwb): add more retries. + ClientBase* c = static_cast(tag); + if (!c->status_.ok()) { + delete c; + return true; + } + + c->Process(); + delete c; + return true; +} + +std::shared_ptr RPCClient::GetChannel(const std::string& ep) { + auto it = channels_.find(ep); + if (it != channels_.end()) { + return it->second; + } + + auto ch = std::shared_ptr( + grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())); + + channels_[ep] = ch; + return ch; +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h new file mode 100644 index 0000000000..d27b5ced9e --- /dev/null +++ b/paddle/operators/detail/grpc_client.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/framework/data_type.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/operators/detail/sendrecvop_utils.h" +#include "paddle/operators/detail/simple_block_queue.h" + +namespace paddle { +namespace operators { +namespace detail { + +struct VarHandle { + std::string ep; + const platform::DeviceContext* ctx; + const framework::Scope* scope; + std::string name; + + std::string String() const { + std::ostringstream s; + s << "name:[" << name << "] ep:[" << ep << "]"; + return s.str(); + } +}; + +void ProcGetResponse(const VarHandle& var_h, + const sendrecv::VariableMessage& msg); + +class ClientBase { + public: + explicit ClientBase(std::shared_ptr ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + context_ = NULL; + } + + virtual ~ClientBase() {} + + virtual void Prepare(const VarHandle& var_info, int64_t time_out) { + context_.reset(new grpc::ClientContext()); + var_h_ = var_info; + + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); + + context_->set_deadline(deadline); + } + + virtual void Process() = 0; + + std::unique_ptr stub_; + std::unique_ptr context_; + grpc::Status status_; + VarHandle var_h_; +}; + +typedef std::function + RequestSendCallBack; + +class SendProcessor : public ClientBase { + public: + explicit SendProcessor(std::shared_ptr ch) : ClientBase(ch) {} + + virtual ~SendProcessor() {} + + virtual void Process() { + if (response_call_back_) { + response_call_back_(var_h_, reply_); + } + } + + sendrecv::VoidMessage reply_; + RequestSendCallBack response_call_back_ = NULL; +}; + +typedef std::function + RequestGetCallBack; + +class GetProcessor : public ClientBase { + public: + explicit GetProcessor(std::shared_ptr ch) : ClientBase(ch) {} + + virtual ~GetProcessor() {} + + virtual void Process() { + if (response_call_back_) { + response_call_back_(var_h_, reply_); + } + } + + sendrecv::VariableMessage reply_; + RequestGetCallBack response_call_back_ = ProcGetResponse; +}; + +class RPCClient { + public: + bool AsyncSendVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = 600 * 1000); + + bool AsyncGetVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = 600 * 1000); + bool wait(); + + private: + bool Proceed(); + std::shared_ptr GetChannel(const std::string& ep); + + private: + grpc::CompletionQueue cq_; + std::map> channels_; + int64_t req_count_ = 0; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc new file mode 100644 index 0000000000..e8d561a57f --- /dev/null +++ b/paddle/operators/detail/grpc_server.cc @@ -0,0 +1,237 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detail/grpc_server.h" + +using grpc::ServerAsyncResponseWriter; + +namespace paddle { +namespace operators { +namespace detail { + +enum CallStatus { PROCESS = 0, FINISH }; + +// reference: +// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server +class RequestBase { + public: + explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, + grpc::ServerCompletionQueue* cq) + : service_(service), cq_(cq), status_(PROCESS) {} + virtual ~RequestBase() {} + virtual void Process() { assert(false); } + + CallStatus Status() { return status_; } + void SetStatus(CallStatus status) { status_ = status; } + + protected: + grpc::ServerContext ctx_; + sendrecv::SendRecvService::AsyncService* service_; + grpc::ServerCompletionQueue* cq_; + CallStatus status_; +}; + +typedef std::pair MessageWithName; + +class RequestSend final : public RequestBase { + public: + explicit RequestSend(sendrecv::SendRecvService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + SimpleBlockQueue* queue) + : RequestBase(service, cq), queue_(queue), responder_(&ctx_) { + service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_, + this); + } + + virtual ~RequestSend() {} + + virtual void Process() { + MessageWithName msg_with_name = + std::make_pair(request_.varname(), std::move(request_)); + queue_->Push(std::move(msg_with_name)); + // TODO(gongwb): check var's info. + responder_.Finish(reply_, grpc::Status::OK, this); + } + + protected: + sendrecv::VariableMessage request_; + sendrecv::VoidMessage reply_; + SimpleBlockQueue* queue_; + ServerAsyncResponseWriter responder_; +}; + +class RequestGet final : public RequestBase { + public: + explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, + grpc::ServerCompletionQueue* cq, framework::Scope* scope) + : RequestBase(service, cq), responder_(&ctx_), scope_(scope) { + service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + virtual ~RequestGet() {} + + virtual void Process() { + // proc request. + std::string var_name = request_.varname(); + auto* var = scope_->FindVar(var_name); + SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_); + // TODO(gongwb): check var's info. + responder_.Finish(reply_, grpc::Status::OK, this); + } + + protected: + sendrecv::VariableMessage request_; + sendrecv::VariableMessage reply_; + ServerAsyncResponseWriter responder_; + framework::Scope* scope_; +}; + +void AsyncGRPCServer::RunSyncUpdate() { + grpc::ServerBuilder builder; + builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); + builder.RegisterService(&service_); + + cq_send_ = builder.AddCompletionQueue(); + cq_get_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + LOG(INFO) << "Server listening on " << address_ << std::endl; + + std::function send_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); + std::function get_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); + + t_send_.reset( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false, + cq_send_.get(), "cq_send", send_register))); + + t_get_.reset( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true, + cq_get_.get(), "cq_get", get_register))); + + // wait server + server_->Wait(); + t_send_->join(); + t_get_->join(); +} + +void AsyncGRPCServer::ShutdownQueue() { + std::unique_lock lock(cq_mutex_); + cq_send_->Shutdown(); + cq_get_->Shutdown(); + is_shut_down_ = true; +} + +// This URL explains why shutdown is complicate: +// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c +void AsyncGRPCServer::ShutDown() { + server_->Shutdown(); + ShutdownQueue(); +} + +void AsyncGRPCServer::TryToRegisterNewSendOne() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + return; + } + RequestSend* send = + new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); + VLOG(4) << "create RequestSend status:" << send->Status(); +} + +void AsyncGRPCServer::TryToRegisterNewGetOne() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + return; + } + RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_); + VLOG(4) << "create Requestget status:" << get->Status(); +} + +void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + delete last; + last = NULL; + return; + } + + last->SetStatus(FINISH); + return; +} + +void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, + std::string cq_name, + std::function TryToRegisterNewOne) { + TryToRegisterNewOne(); + + void* tag = NULL; + bool ok = false; + while (true) { + if (!cq->Next(&tag, &ok)) { + LOG(INFO) << cq_name << " get CompletionQueue shutdown!"; + break; + } + + if (wait && !done_) { + Wait(); + } + + RequestBase* base = (RequestBase*)tag; + if (!ok) { + VLOG(4) << cq_name << " recv no regular event"; + TryToRegisterNewOne(); + delete base; + continue; + } + + switch (base->Status()) { + case PROCESS: { + VLOG(4) << cq_name << " status:" << base->Status(); + TryToRegisterNewOne(); + base->Process(); + SetFinishOrDelete(base); + break; + } + case FINISH: { + VLOG(4) << cq_name << " status:" << base->Status(); + delete base; + break; + } + default: { assert(false); } + } + } +} + +void AsyncGRPCServer::Wait() { + std::unique_lock lock(this->mutex_); + condition_.wait(lock, [=] { return this->done_ == true; }); +} + +void AsyncGRPCServer::Reset() { + std::lock_guard lock(this->mutex_); + done_ = false; +} + +void AsyncGRPCServer::Done() { + { + std::lock_guard lock(this->mutex_); + done_ = true; + } + condition_.notify_all(); +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h new file mode 100644 index 0000000000..041fe05b2e --- /dev/null +++ b/paddle/operators/detail/grpc_server.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/framework/var_type.h" +#include "paddle/operators/detail/simple_block_queue.h" + +#include "paddle/operators/detail/send_recv.grpc.pb.h" +#include "paddle/operators/detail/send_recv.pb.h" + +#include +#include +#include +#include "paddle/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +typedef std::pair MessageWithName; +class RequestBase; + +class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { + public: + explicit AsyncGRPCServer(std::string address) { address_ = address; } + + void RunSyncUpdate(); + + void Reset(); + + void Done(); + + void SetScope(framework::Scope *scope) { scope_ = scope; } + + const MessageWithName Get() { return this->var_recv_queue_.Pop(); } + + void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } + + void ShutDown(); + + protected: + void Wait(); + void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, + std::string cq_name, + std::function TryToRegisterNewOne); + void TryToRegisterNewSendOne(); + void TryToRegisterNewGetOne(); + void SetFinishOrDelete(RequestBase *&last); + void ShutdownQueue(); + + private: + std::mutex cq_mutex_; + volatile bool is_shut_down_ = false; + std::unique_ptr cq_send_; + std::unique_ptr cq_get_; + + sendrecv::SendRecvService::AsyncService service_; + std::unique_ptr server_; + + std::string address_; + framework::Scope *scope_; + // received variable from RPC, operators fetch variable from this queue. + SimpleBlockQueue var_recv_queue_; + + // condition of the sub program + std::mutex mutex_; + volatile mutable bool done_; + std::condition_variable condition_; + + std::unique_ptr t_send_; + std::unique_ptr t_get_; +}; + +}; // namespace detail +}; // namespace operators +}; // namespace paddle diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc deleted file mode 100644 index 319404e56a..0000000000 --- a/paddle/operators/detail/recv_impl.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "send_recv_impl.h" - -namespace paddle { -namespace operators { -namespace detail { - -Status SendRecvServerImpl::SendVariable(ServerContext *context, - const VariableMessage *in_var, - VoidMessage *out_var) { - MessageWithName msg_with_name = - std::make_pair(in_var->varname(), std::move(*in_var)); - var_recv_queue_.Push(std::move(msg_with_name)); - return Status::OK; -} - -Status SendRecvServerImpl::GetVariable(ServerContext *context, - const VariableMessage *in_var, - VariableMessage *out_var) { - std::string get_var_name = in_var->varname(); - auto *var = scope_->FindVar(get_var_name); - - SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var); - return Status::OK; -} - -Status SendRecvServerImpl::Wait(ServerContext *context, - const VoidMessage *in_var, - VoidMessage *out_var) { - { - std::unique_lock lock(this->mutex_); - condition_.wait(lock, [=] { return this->done_ == true; }); - } - return Status::OK; -} - -void SendRecvServerImpl::Reset() { - std::lock_guard lock(this->mutex_); - done_ = false; -} - -void SendRecvServerImpl::Done() { - { - std::lock_guard lock(this->mutex_); - done_ = true; - } - condition_.notify_all(); -} - -} // namespace detail -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/detail/send_impl.cc b/paddle/operators/detail/send_impl.cc deleted file mode 100644 index ae85cf2cec..0000000000 --- a/paddle/operators/detail/send_impl.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "send_recv_impl.h" - -namespace paddle { -namespace operators { -namespace detail { - -bool RPCClient::SendVariable(const framework::Scope& scope, - const std::string& inname) { - ClientContext context; - VariableMessage msg; - VoidMessage out_msg; - // FIXME(typhoonzero): pass device context to here. - auto ctx = platform::CPUDeviceContext(); - auto* var = scope.FindVar(inname); - PADDLE_ENFORCE(var); - SerializeToMessage(inname, var, ctx, &msg); - - Status status = stub_->SendVariable(&context, msg, &out_msg); - if (!status.ok()) { - LOG(ERROR) << "gRPC error: " << status.error_message(); - return false; - } - return true; -} - -bool RPCClient::GetVariable(const framework::Scope& scope, - const std::string& outname) { - ClientContext context; - VariableMessage call_msg, ret_msg; - call_msg.set_varname(outname); - auto ctx = platform::CPUDeviceContext(); - Status status = stub_->GetVariable(&context, call_msg, &ret_msg); - auto* outvar = scope.FindVar(outname); - if (!status.ok()) { - LOG(ERROR) << "gRPC error: " << status.error_message(); - return false; - } - - std::istringstream iss(ret_msg.serialized()); - DeserializeFromMessage(ret_msg, ctx, outvar); - - return true; -} - -void RPCClient::Wait() { - ClientContext context; - VoidMessage call_msg, ret_msg; - stub_->Wait(&context, call_msg, &ret_msg); -} - -} // namespace detail -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index f141c755ce..8f962b4c69 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -21,8 +21,6 @@ service SendRecvService { rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} - // wait for one execution of the program - rpc Wait(VoidMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h deleted file mode 100644 index 1fe54f1f05..0000000000 --- a/paddle/operators/detail/send_recv_impl.h +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/lod_tensor.h" -#include "paddle/framework/scope.h" -#include "paddle/framework/selected_rows.h" -#include "paddle/framework/var_type.h" -#include "paddle/operators/detail/simple_block_queue.h" - -#include "paddle/operators/detail/send_recv.grpc.pb.h" -#include "paddle/operators/detail/send_recv.pb.h" - -#include - -using grpc::Channel; -using grpc::Server; -using grpc::ServerContext; -using grpc::ServerReader; -using grpc::ServerBuilder; - -using grpc::ClientContext; -using grpc::ClientReader; -using grpc::ClientReaderWriter; -using grpc::ClientWriter; -using grpc::Status; -using sendrecv::SendRecvService; -using sendrecv::VariableMessage; -using sendrecv::VoidMessage; - -namespace paddle { -namespace operators { -namespace detail { - -typedef std::pair MessageWithName; - -class SendRecvServerImpl final : public SendRecvService::Service { - public: - explicit SendRecvServerImpl() {} - - Status SendVariable(ServerContext *context, const VariableMessage *in_var, - VoidMessage *out_var) override; - Status GetVariable(ServerContext *context, const VariableMessage *in_var, - VariableMessage *out_var) override; - Status Wait(ServerContext *context, const VoidMessage *in_var, - VoidMessage *out_var) override; - void Reset(); - void Done(); - void SetScope(framework::Scope *scope) { scope_ = scope; }; - - const MessageWithName Get() { return this->var_recv_queue_.Pop(); } - - void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } - - private: - // received variable from RPC, operators fetch variable from this queue. - SimpleBlockQueue var_recv_queue_; - framework::Scope *scope_; - // condition of the sub program - std::mutex mutex_; - bool done_; - std::condition_variable condition_; -}; - -// RPCClient is a class to send tensors to pserver sub-network -// using different hashing methods. -class RPCClient { - public: - RPCClient(std::shared_ptr channel) - : stub_(SendRecvService::NewStub(channel)) {} - - bool SendVariable(const framework::Scope &scope, const std::string &inname); - bool GetVariable(const framework::Scope &scope, const std::string &outname); - void Wait(); - - private: - std::unique_ptr stub_; -}; - -inline void SerializeToMessage(const std::string &name, - const framework::Variable *var, - const platform::DeviceContext &ctx, - VariableMessage *msg) { - msg->set_varname(name); - std::ostringstream oss; - switch (framework::ToVarType(var->Type())) { - case framework::proto::VarDesc_VarType_LOD_TENSOR: - msg->set_type(sendrecv::VarType::LOD_TENSOR); - framework::SerializeToStream(oss, var->Get(), ctx); - break; - case framework::proto::VarDesc_VarType_SELECTED_ROWS: - msg->set_type(sendrecv::VarType::SELECTED_ROWS); - framework::SerializeToStream(oss, var->Get(), - ctx); - break; - default: { - PADDLE_THROW("Serialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } - msg->set_serialized(oss.str()); -} - -inline void DeserializeFromMessage(const VariableMessage &msg, - const platform::DeviceContext &ctx, - framework::Variable *var) { - using namespace paddle::framework::proto; - std::istringstream iss(msg.serialized()); - switch (msg.type()) { - case sendrecv::VarType::LOD_TENSOR: - DeserializeFromStream(iss, var->GetMutable(), ctx); - break; - case sendrecv::VarType::SELECTED_ROWS: { - DeserializeFromStream(iss, var->GetMutable(), - ctx); - break; - } - default: { - PADDLE_THROW("Deserialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } -} - -} // namespace detail -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/detail/sendrecvop_utils.cc b/paddle/operators/detail/sendrecvop_utils.cc new file mode 100644 index 0000000000..7635b9e8db --- /dev/null +++ b/paddle/operators/detail/sendrecvop_utils.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +void SerializeToMessage(const std::string& name, const framework::Variable* var, + const platform::DeviceContext& ctx, + sendrecv::VariableMessage* msg) { + msg->set_varname(name); + std::ostringstream oss; + switch (framework::ToVarType(var->Type())) { + case framework::proto::VarDesc_VarType_LOD_TENSOR: + msg->set_type(sendrecv::VarType::LOD_TENSOR); + framework::SerializeToStream(oss, var->Get(), ctx); + break; + case framework::proto::VarDesc_VarType_SELECTED_ROWS: + msg->set_type(sendrecv::VarType::SELECTED_ROWS); + framework::SerializeToStream(oss, var->Get(), + ctx); + break; + default: { + PADDLE_THROW("Serialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + } + msg->set_serialized(oss.str()); +} + +void DeserializeFromMessage(const sendrecv::VariableMessage& msg, + const platform::DeviceContext& ctx, + framework::Variable* var) { + std::istringstream iss(msg.serialized()); + switch (msg.type()) { + case sendrecv::VarType::LOD_TENSOR: + DeserializeFromStream(iss, var->GetMutable(), ctx); + break; + case sendrecv::VarType::SELECTED_ROWS: { + DeserializeFromStream(iss, var->GetMutable(), + ctx); + break; + } + default: { + PADDLE_THROW("Deserialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + } +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/sendrecvop_utils.h b/paddle/operators/detail/sendrecvop_utils.h new file mode 100644 index 0000000000..bc6581afab --- /dev/null +++ b/paddle/operators/detail/sendrecvop_utils.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/framework/data_type.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/framework/var_type.h" + +#include "paddle/operators/detail/send_recv.grpc.pb.h" +#include "paddle/operators/detail/send_recv.pb.h" + +namespace paddle { +namespace operators { +namespace detail { + +void SerializeToMessage(const std::string& name, const framework::Variable* var, + const platform::DeviceContext& ctx, + sendrecv::VariableMessage* msg); + +void DeserializeFromMessage(const sendrecv::VariableMessage& msg, + const platform::DeviceContext& ctx, + framework::Variable* var); +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 9331c7b563..55b33343af 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -24,7 +24,8 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/proto_desc.h" -#include "paddle/operators/detail/send_recv_impl.h" +#include "paddle/operators/detail/grpc_server.h" +#include "paddle/operators/detail/sendrecvop_utils.h" #include "paddle/operators/detail/simple_block_queue.h" #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" @@ -32,6 +33,11 @@ limitations under the License. */ namespace paddle { namespace operators { +void RunServer(std::shared_ptr service) { + service->RunSyncUpdate(); + VLOG(4) << "RunServer thread end"; +} + static void CreateTensorFromMessageType(framework::Variable *var, sendrecv::VarType var_type) { if (var_type == sendrecv::VarType::LOD_TENSOR) { @@ -46,18 +52,6 @@ static void CreateTensorFromMessageType(framework::Variable *var, } } -void RunServer(Server **rpc_server, - std::shared_ptr service, - const std::string &server_address) { - ServerBuilder builder; - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - builder.RegisterService(service.get()); - std::unique_ptr server(builder.BuildAndStart()); - *rpc_server = server.get(); - LOG(INFO) << "Server listening on " << server_address; - server->Wait(); -} - class RecvOp : public framework::OperatorBase { public: RecvOp(const std::string &type, const framework::VariableNameMap &inputs, @@ -65,10 +59,9 @@ class RecvOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) { if (!rpc_service_) { - rpc_service_.reset(new detail::SendRecvServerImpl()); std::string endpoint = Attr("endpoint"); - server_thread_.reset( - new std::thread(RunServer, &rpc_server_, rpc_service_, endpoint)); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + server_thread_.reset(new std::thread(RunServer, rpc_service_)); } } @@ -76,7 +69,7 @@ class RecvOp : public framework::OperatorBase { detail::MessageWithName term_msg; term_msg.first = LISTEN_TERMINATE_MESSAGE; rpc_service_->Push(term_msg); - rpc_server_->Shutdown(); + rpc_service_->ShutDown(); server_thread_->join(); } @@ -99,10 +92,12 @@ class RecvOp : public framework::OperatorBase { auto grad_list = Attr>("GradList"); auto trainer_count = Attr("Trainers"); size_t param_count = param_list.size(); + rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; while (!exit_flag) { + // TODO(gognwb): simply this loop. // Get from multiple trainers, we don't care about order in which // the gradient arrives, just add suffix 0~n then average the gradient. for (size_t i = 0; i < param_count * trainer_count; ++i) { @@ -110,6 +105,7 @@ class RecvOp : public framework::OperatorBase { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { + VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit"; exit_flag = true; break; } @@ -118,10 +114,12 @@ class RecvOp : public framework::OperatorBase { if (it != grad_list.end()) { param_var_name = param_list[it - grad_list.begin()]; } else { - LOG(ERROR) << "grad have no paired param found!"; + LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name + << "\""; } VLOG(3) << "recved grad: " << grad_var_name << " updating param: " << param_var_name; + auto *merged_grad = recv_scope.FindVar(grad_var_name); if (merged_grad == nullptr) { auto *ptr = recv_scope.Var(grad_var_name); @@ -141,9 +139,11 @@ class RecvOp : public framework::OperatorBase { auto &dev_ctx = *pool.Get(dev_place); detail::DeserializeFromMessage(v.second, dev_ctx, var); } + if (exit_flag) { break; } + rpc_service_->Reset(); std::string program_str = Attr("OptimizeProgram"); @@ -158,17 +158,14 @@ class RecvOp : public framework::OperatorBase { } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } + rpc_service_->Done(); grads_counter_.clear(); } // while(true) } protected: - // grpc server instance to track status and gracefully shutdown. - // borrow an pointer from server thread. - Server *rpc_server_{nullptr}; - // grpc send/recv service implement to register. - std::shared_ptr rpc_service_; + std::shared_ptr rpc_service_; std::shared_ptr server_thread_; mutable std::unordered_map grads_counter_; }; diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 95c207221a..4d145250bd 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -19,59 +19,45 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/detail/send_recv_impl.h" -#include "paddle/operators/detail/simple_block_queue.h" +#include +#include "paddle/operators/detail/grpc_client.h" namespace paddle { namespace operators { -// TODO(typhoonzero): this is a simple implementation which only send -// one tensor class SendOp : public framework::OperatorBase { public: - SendOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) { - // init client when the operator is created at runtime. - std::vector endpoints = - Attr>("endpoints"); - for (auto ep : endpoints) { - client_map_[ep].reset(new detail::RPCClient( - grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()))); - } - } + SendOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { auto ins = Inputs("X"); auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); - // TODO(typhoonzero): use async calls to send multiple variable asyncly. - for (size_t i = 0; i < ins.size(); ++i) { - bool ret = client_map_[epmap[i]]->SendVariable(scope, ins[i]); - if (!ret) { - LOG(ERROR) << "send variable error: " << ins[i]; - } + + // FIXME(gongwb): DeviceContext? + auto ctx = platform::CPUDeviceContext(); + for (size_t i = 0; i < ins.size(); i++) { + client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); } - // TODO(typhoonzero): support async optimization - client_map_[epmap[0]]->Wait(); - for (size_t i = 0; i < outs.size(); ++i) { - bool ret = client_map_[epmap[i]]->GetVariable(scope, outs[i]); - if (!ret) { - LOG(ERROR) << "GetVariable error: " << outs[i]; - } + + for (size_t i = 0; i < outs.size(); i++) { + client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } + + client_.wait(); } - protected: - mutable std::unordered_map> - client_map_; + private: + mutable detail::RPCClient client_; }; class SendOpMaker : public framework::OpProtoAndCheckerMaker { public: - SendOpMaker(OpProto *proto, OpAttrChecker *op_checker) + SendOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "(Tensor) Input tensor to be send").AsDuplicable(); AddOutput("Out", "(Tensor) Output tensor to get from server") diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc index fa94424bf9..ea09169479 100644 --- a/paddle/operators/send_recv_op_test.cc +++ b/paddle/operators/send_recv_op_test.cc @@ -140,7 +140,7 @@ void StartServerNet(bool is_sparse) { TEST(SendRecvOp, CPUDense) { std::thread server_thread(StartServerNet, false); - sleep(3); // wait server to start + sleep(10); // wait server to start // local net f::Scope scope; p::CPUPlace place; From 87f9b5836359d607cc4ef1c9684c067ed7e7b1e0 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Thu, 11 Jan 2018 10:49:48 +0800 Subject: [PATCH 26/30] set stop gradient for mask in dropout layer (#7390) --- python/paddle/v2/fluid/layers/nn.py | 17 ++++++++++++++++- python/paddle/v2/fluid/layers/ops.py | 13 +------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index b1534c5a88..48a6bee558 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -14,7 +14,7 @@ __all__ = [ 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', - 'sequence_first_step', 'sequence_last_step' + 'sequence_first_step', 'sequence_last_step', 'dropout' ] @@ -386,6 +386,21 @@ def cos_sim(X, Y, **kwargs): return out +def dropout(x, dropout_prob, is_test=False, seed=0, **kwargs): + helper = LayerHelper('dropout', **kwargs) + out = helper.create_tmp_variable(dtype=x.dtype) + mask = helper.create_tmp_variable(dtype=x.dtype, stop_gradient=True) + helper.append_op( + type='dropout', + inputs={'X': [x]}, + outputs={'Out': [out], + 'Mask': [mask]}, + attrs={'dropout_prob': dropout_prob, + 'is_test': is_test, + 'seed': seed}) + return out + + def cross_entropy(input, label, **kwargs): """ **Cross Entropy Layer** diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 544623c4bc..d3a5b70785 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -1,23 +1,12 @@ from ..registry import register_layer __activations__ = [ - 'abs', - 'ceil', - 'exp', - 'floor', - 'log', - 'relu', - 'round', - 'sigmoid', - 'sqrt', - 'square', - 'tanh', + 'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round' ] __all__ = [ 'mean', 'mul', - 'dropout', 'reshape', 'scale', 'transpose', From 1797f3db850206cc7e87ee0ff2c06816b1dfcae7 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Thu, 11 Jan 2018 11:15:06 +0800 Subject: [PATCH 27/30] Refine memory optimization transpiler (#7394) * add update graph method for memory optimization transpiler to avoid rebuild graph everytime * clean code * reset var desc if hit cache --- python/paddle/v2/fluid/framework.py | 3 + .../fluid/memory_optimization_transpiler.py | 77 ++++++++++++++----- 2 files changed, 62 insertions(+), 18 deletions(-) diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 2fb388acfc..3ef6b33192 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -236,6 +236,9 @@ class Variable(object): __repr__ = __str__ + def set_desc(self, input): + self.desc = input + @property def persistable(self): return self.desc.persistable() diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py index 571fce7fac..6800d7ddbb 100644 --- a/python/paddle/v2/fluid/memory_optimization_transpiler.py +++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py @@ -3,6 +3,17 @@ import framework from framework import Program, default_main_program, Parameter, Variable import backward from backward import _rename_arg_ +from . import core + +dtype_to_size = { + core.DataType.FP16: 2, + core.DataType.FP32: 4, + core.DataType.FP64: 8, + core.DataType.INT16: 2, + core.DataType.INT32: 4, + core.DataType.INT64: 8, + core.DataType.BOOL: 1 +} class ControlFlowGraph(object): @@ -28,18 +39,33 @@ class ControlFlowGraph(object): block_size = program_desc.num_blocks() # TODO(qijun) handle Program with if/while operators - self.global_block = program_desc.block(0) - self.op_size = self.global_block.op_size() + self.global_block_desc = program_desc.block(0) + self.op_size = self.global_block_desc.op_size() op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)] self._add_connections(op_node_connections) - self.ops = [self.global_block.op(i) for i in range(self.op_size)] + self.ops = [self.global_block_desc.op(i) for i in range(self.op_size)] for i in range(self.op_size): self._uses[i].update(self.ops[i].input_arg_names()) self._defs[i].update(self.ops[i].output_arg_names()) + def _update_graph(self, old_name, new_name, begin_idx=0): + for i in range(begin_idx, self.op_size): + if old_name in self._uses[i]: + self._uses[i].remove(old_name) + self._uses[i].add(new_name) + if old_name in self._defs[i]: + self._defs[i].remove(old_name) + self._defs[i].add(new_name) + if old_name in self._live_in[i]: + self._live_in[i].remove(old_name) + self._live_out[i].add(new_name) + if old_name in self._live_out[i]: + self._live_out[i].remove(old_name) + self._live_out[i].add(new_name) + def _reach_fixed_point(self, live_in, live_out): if len(live_in) != len(self._live_in): return False @@ -79,30 +105,45 @@ class ControlFlowGraph(object): self.pool = [] for i in range(self.op_size): if self.pool: - out_pair = [(x, self.global_block.var(str(x)).shape()) + out_pair = [(x, self.global_block_desc.var(str(x)).shape()) for x in self._defs[i]] for x, x_shape in out_pair: - for index, cache_pair in enumerate(self.pool): - cache_var = cache_pair[0] - cache_shape = cache_pair[1] - if x_shape == cache_shape: - print( - "Hit Cache !!!! cache pool index is %d, var name is %s, cached var name is %s, var shape is %s " - % (index, x, cache_var, str(cache_shape))) - self.pool.pop(index) - _rename_arg_(self.ops, x, cache_var, begin_idx=i) - self._dataflow_analyze() - break + if not self.global_block_desc.var(str(x)).persistable(): + for index, cache_pair in enumerate(self.pool): + cache_var = cache_pair[0] + cache_shape = cache_pair[1] + if x_shape == cache_shape: + x_dtype = self.global_block_desc.var(str( + x)).dtype() + cache_dtype = self.global_block_desc.var( + str(cache_var)).dtype() + # TODO(qijun): actually, we should compare dtype_to_size[x_dtype] + # and dtype_to_size[cache_dtype] + if x_dtype == cache_dtype: + print( + "Hit Cache !!!! cache pool index is %d, var name is %s, cached var name is %s, var shape is %s " + % + (index, x, cache_var, str(cache_shape))) + self.pool.pop(index) + _rename_arg_( + self.ops, x, cache_var, begin_idx=i) + self._program.current_block().var(str( + x)).desc = self.global_block_desc.var( + str(cache_var)) + self._update_graph( + x, cache_var, begin_idx=i) + break in_diff, out_diff = self._get_diff(self._live_in[i], self._live_out[i]) can_optimize = filter( - lambda x: not self.global_block.var(str(x)).persistable(), + lambda x: not self.global_block_desc.var(str(x)).persistable(), in_diff) if can_optimize: for var_name in can_optimize: - self.pool.append(( - var_name, self.global_block.var(str(var_name)).shape())) + self.pool.append( + (var_name, + self.global_block_desc.var(str(var_name)).shape())) def get_program(self): return self._program From 24cde57ca0edd9b734ab7ea9fc0c077bb76567b6 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 11 Jan 2018 11:26:10 +0800 Subject: [PATCH 28/30] Extend return value for layer functions Make users can access parameters of layers and their gradients. --- doc/design/python_api.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/doc/design/python_api.md b/doc/design/python_api.md index cb5fdc765b..73f6d7b90c 100644 --- a/doc/design/python_api.md +++ b/doc/design/python_api.md @@ -279,6 +279,26 @@ class LayerHelper(object): return tmp ``` +### Return value of layer functions + +The layer will return a Variable, which is also the output of an operator. However, outputs of a layer function have more attributes than an operator. There are parameter variables, and their gradient variables need to return. To return them is useful. For example, + +1. Users can debug the network by printing parameter gradients. +2. Users can append attributes to a parameter, such as, `param.stop_gradient=True` will make a parameter stop generate the gradient. We can fix the parameter value during training by using this attribute. + +However, it is good to return a Variable for layers, since all layers and operators use Variables as their parameters. We can just append a `param` field and a `grad` field for layer function since the Python is dynamic typing. + +The sample usage is + +```python +data = fluid.layers.data(...) +hidden = fluid.layers.fc(data, ...) +... + +executor.run(fetch_list=[hidden.param, hidden.param.grad], ...) +``` + + ## Optimizer [Optimizer Design Doc](./optimizer.md) From a795a0d743de776e0008e21ae7266690f989971d Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Wed, 10 Jan 2018 20:48:01 -0800 Subject: [PATCH 29/30] Refine the document for memory optimization (#7420) --- doc/design/memory_optimization.md | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/doc/design/memory_optimization.md b/doc/design/memory_optimization.md index 00f514711a..1f68cef4cc 100644 --- a/doc/design/memory_optimization.md +++ b/doc/design/memory_optimization.md @@ -5,28 +5,28 @@ In a lecture from Andrew Ng, he attributes the recent sucess of AI due to a combination of these: -- availability of Big Data -- supercomputing power to process this Big Data over very large neural networks -- modern algorithms +- Availability of Big Data +- Supercomputing power to process this Big Data over very large neural networks +- Modern algorithms Following graph shows the details: ![](images/deep_learning.png) -Larger model usually brings better performance. However, GPU memory is certain limited. For example, the memory size of a GTX TITAN X is only 12GB. To train complex and large model, we have to take care of memory using. Besides, memory optimization is also necessary in both online/mobile inference. +Larger model usually bring better performance. However, GPU memory is limited. For example, the memory size of a GTX TITAN X is only 12GB. To train complex and large models, we have to take care of memory usage. Besides, memory optimization is also necessary in both online/mobile inference. ## Solution ### Basic Strategy -There are some basic strategies to make memory optimization, including in-place operation and memory sharing. +There are some basic strategies to improve memory usage, including in-place operations and memory sharing. #### In-place Operation In a relu activation operator: $y = \max(x, 0)$ -If the variable x is not used in any other operator, we can make an in-place operation. In other words, the memory block of variable y and variable x are the same. In-place operation will save 50% memory occupancy immediately. +If the variable x is not used in any other operator, we can make an in-place operation. In other words, the memory block of variable y and variable x will be the same. In-place operations will save 50% memory occupancy immediately. #### Memory Sharing @@ -40,18 +40,18 @@ d = op2(a) e = op3(d, f) ``` -In this case, variable a is no longer used, and op2 does not support in-place operation. After op2 finished, we can put the memory of variable a to a memory pool. Then, variable e can share the memory of variable a from the pool. +In this case, variable a is no longer used, and op2 does not support in-place operation. After op2 finishes, we can put the memory of variable a to a memory pool. Then, variable e can share the memory of variable a from the pool. ### Live Variable Analysis -It's not enough to only have some basic strategies. The prerequisite of memory optimization is to know if a variable is still "live" after an operation. +It's not enough to only have some basic strategies. The pre-requisite of memory optimization is to know if a variable is still "live" after an operation. In our design, the neural network topology is defined as a program. Luckily, [live variable analysis](https://en.wikipedia.org/wiki/Live_variable_analysis) is a classic problem in compilers which can be used in many stages, such as register allocation. -In compilers, the front end of the compilers translates programs into an intermediate language with an unbounded number of temporaries. This program must run on a machine with a bounded number of registers. Two temporaries a and b can fit into the same register, if a and b are never "in use" at the same time. Thus, many temporaries can fit in few registers; if they don't all fit, the excess temporaries can be kept in memory. +In compilers, the front end of the compiler translates programs into an intermediate language with an unbounded number of temporary variables. This program must run on a machine with a bounded number of registers. Two temporary variables a and b can fit into the same register, if a and b are never "in use" at the same time. Thus, many temporary variables can fit in few registers; if they don't all fit, the excess tempory variables can be kept in memory. -Therefore, the compiler needs to analyze the intermediate-representation program to determine which temporaries are in use at the same time. We say a variable is "live" if it holds a value that may be needed in the future, so this analysis is called liveness analysis. +Therefore, the compiler needs to analyze the intermediate-representation program to determine which temporary variables are in use at the same time. We say a variable is "live" if it holds a value that may be needed in the future, so this analysis is called liveness analysis. We can leran these techniques from compilers. There are mainly two stages to make live variable analysis: @@ -60,7 +60,7 @@ We can leran these techniques from compilers. There are mainly two stages to mak #### Control Flow Graph -To preform analyses on a program, it is often useful to make a control flow graph. A [control flow graph](https://en.wikipedia.org/wiki/Control_flow_graph) (CFG) in computer science is a representation, using graph notation, of all paths that might be traversed through a program during its execution. Each statement in the program is a node in the flow graph; if statemment x can be followed by statement y, there is an egde from x to y. +To perform analysis on a program, it is often useful to make a control flow graph. A [control flow graph](https://en.wikipedia.org/wiki/Control_flow_graph) (CFG) in computer science is a representation, using graph notation, of all paths that might be traversed through a program during its execution. Each statement in the program is a node in the flow graph; if statemment x can be followed by statement y, there is an egde from x to y. Following is the flow graph for a simple loop. @@ -68,18 +68,18 @@ Following is the flow graph for a simple loop. #### Dataflow Analysis -liveness of variable "flows" around the edges of the control flow graph; determining the live range of each variable is an example of a dataflow problem. [Dataflow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) is a technique for gathering information about the possible set of values calculated at various points in a computer program. +Liveness of variable "flows" around the edges of the control flow graph; determining the live range of each variable is an example of a dataflow problem. [Dataflow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) is a technique for gathering information about the possible set of values calculated at various points in a computer program. A simple way to perform data-flow analysis of programs is to set up dataflow equations for each node of the control flow graph and solve them by repeatedly calculating the output from the input locally at each node until the whole system stabilizes. - Flow Graph Terminology -A flow graph node has out-edges that lead to sucessor nodes, and in-edges that come from presucessor nodes. The set *pred[n]* is all the predecessors of node n, and *succ[n]* is the set of sucessors. +A flow graph node has out-edges that lead to sucessor nodes, and in-edges that come from predecessor nodes. The set *pred[n]* is all the predecessors of node n, and *succ[n]* is the set of sucessors. In former control flow graph, the out-edges of node 5 are 5 --> 6 and 5 --> 2, and *succ[5]* = {2, 6}. The in-edges of 2 are 5 --> 2 and 1 --> 2, and *pred[2]* = {1, 5}. - Uses and Defs -An assignmemt to a variable or temporary defines that variable. An occurence of a variable on the right-hand side of an assginment(or in other expressions) uses the variable. We can speak the *def* of a variable as the set of graph nodes that define it; or the *def* of a graph node as the set of variables that it defines; and the similarly for the *use* of a variable or graph node. In former control flow graph, *def(3)* = {c}, *use(3)* = {b, c}. +An assignmemt to a variable or temporary defines that variable. An occurence of a variable on the right-hand side of an assginment(or in other expressions) uses the variable. We can define the *def* of a variable as the set of graph nodes that define it; or the *def* of a graph node as the set of variables that it defines; and the similarly for the *use* of a variable or graph node. In former control flow graph, *def(3)* = {c}, *use(3)* = {b, c}. - Liveness @@ -168,9 +168,9 @@ class ControlFlowGraph(object): return self._program ``` -#### make dataflow analysis +#### Make dataflow analysis -We follow guide from compilers and try to solve the dataflow equation to get liveness of every variable. If the live-in of an operator node is different from the live-out, then we can make memory sharing. +We follow the guide from compilers and try to solve the dataflow equation to get liveness of every variable. If the live-in of an operator node is different from the live-out, then we can make memory sharing. For example: From e71e4a3d134aee08016ce638c47f47fd332a271b Mon Sep 17 00:00:00 2001 From: "Yang Yang(Tony)" Date: Thu, 11 Jan 2018 14:44:09 +0800 Subject: [PATCH 30/30] Update read_source.md (#7406) --- doc/howto/read_source.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/howto/read_source.md b/doc/howto/read_source.md index e4211abb3b..31987920f3 100644 --- a/doc/howto/read_source.md +++ b/doc/howto/read_source.md @@ -26,16 +26,16 @@ sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) sgd_optimizer.minimize(avg_cost) ``` -- Variables: `x`, `y`, `y_predict`, `cost` and `avg_cost`. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/framework.py#L93) -- Layers: `fluid.layers.data`, `fluid.layers.fc` and `fluid.layers.mean` are layers. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/layers.py) +- Variables: `x`, `y`, `y_predict`, `cost` and `avg_cost`. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/framework.py#) +- Layers: `fluid.layers.data`, `fluid.layers.fc` and `fluid.layers.mean` are layers. [Python](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/v2/fluid/layers) - Every Layer has one or more operators and variables/parameters - All the operators are defined at [`paddle/operators/`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators). Other worth-looking files: - Base class: [`paddle/framework/operator.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h) - Operator Registration: [`paddle/framework/op_registry.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h) - Operator Lookup: [`paddle/framework/op_info.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_info.h) - Optimizer: `fluid.optimizer.SGD`. It does the following - - Add backward operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/backward.py), [C++](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/backward.cc)] - - Add optimizer operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/optimizer.py), [C++](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/optimizer)] + - Add backward operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/backward.py)] + - Add optimizer operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/optimizer.py)] # Run Time