From c8801e100f04fb6ad4d35a5635cbc316fead80d1 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Sat, 10 Nov 2018 10:55:07 +0000 Subject: [PATCH 01/23] grad diff problem to be fixed and need api spec change to be done --- paddle/fluid/framework/selected_rows.h | 3 +- .../operators/hierarchical_sigmoid_op.cc | 11 +- .../fluid/operators/hierarchical_sigmoid_op.h | 55 ++++++-- .../fluid/operators/math/matrix_bit_code.cc | 49 ++++---- paddle/fluid/operators/math/matrix_bit_code.h | 119 ++++++++++++++++-- python/paddle/fluid/layers/nn.py | 23 +++- .../paddle/fluid/tests/unittests/op_test.py | 7 +- .../fluid/tests/unittests/test_hsigmoid_op.py | 117 +++++++++++++++-- 8 files changed, 324 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index daf5e95304..4d728ae54a 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -133,7 +133,8 @@ class SelectedRows { // SelectedRows are simply concated when adding together. Until a // SelectedRows add a Tensor, will the duplicate rows be handled. Vector rows_; - std::unordered_map id_to_index_; + std::unordered_map + id_to_index_; // should not be used when ids has duplicate member std::unique_ptr value_{nullptr}; int64_t height_; std::unique_ptr rwlock_{nullptr}; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index dadd054b9a..49a17416c8 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -91,10 +91,19 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("W", "(Tensor, required), The parameters of hierarchical " "sigmoid operator, each of them is a 2-D tensor, the shape is" - "[num_classes - 1, D]."); + "[K, D]. Which K is the num of non-leaf node in Path Tree"); AddInput("Label", "(Tensor, required), The labels of training data. It's a" "tensor with shape [N, 1]."); + AddInput("PTable", + "(Tensor, optional), The Path Table from root to current word" + "it should have shape like [N, L], L is the length of the Path") + .AsDispensable(); + AddInput("PCode", + "(Tensor, optional), The Code on each Node of the Path from root " + "to current word" + "it should have shape like [N, L], L is the length of the Path") + .AsDispensable(); AddInput("Bias", "(Tensor, optional), The bias is a tensor with shape" "[1, num_classes - 1]."); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 64096a717b..2d500a03df 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/matrix_bit_code.h" @@ -34,12 +35,21 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); auto* w = ctx.Input("W"); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PCode"); auto* label = ctx.Input("Label"); auto* bias = ctx.Input("Bias"); auto* out = ctx.Output("Out"); auto* pre_out = ctx.Output("PreOut"); size_t num_classes = static_cast(ctx.Attr("num_classes")); - int64_t code_length = math::FindLastSet(num_classes - 1); + bool is_custom = false; + if (path) { + is_custom = true; + } else { + is_custom = false; + } + int64_t code_length = + path ? path->dims()[1] : math::FindLastSet(num_classes - 1); int64_t batch_size = in->dims()[0]; framework::Tensor sum; auto& dev_ctx = ctx.template device_context(); @@ -52,7 +62,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { zero(dev_ctx, pre_out, static_cast(0.0)); auto& place = *ctx.template device_context().eigen_device(); math::RowwiseSum row_sum; - math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, + label->data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor(path, code, + label->data())); + } std::vector sum_dims({batch_size, 1UL}); sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); @@ -60,15 +78,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto out_mat = framework::EigenVector::Flatten(*out); if (bias) { - bit_code.Add(pre_out, *bias); + bit_code->Add(pre_out, *bias); } - bit_code.Mul(pre_out, *w, *in); + bit_code->Mul(pre_out, *w, *in); // clip to [-40, 40] Transform trans; trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out->numel(), pre_out_data, ClipFunctor(static_cast(-40.0), static_cast(40.0))); - bit_code.Sum(*pre_out, out, static_cast(-1)); + bit_code->Sum(*pre_out, out, static_cast(-1)); // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); row_sum(dev_ctx, *pre_out, &sum); @@ -86,6 +104,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); auto* w = ctx.Input("W"); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PCode"); auto* in_grad = ctx.Output(framework::GradVarName("X")); auto* w_grad = ctx.Output(framework::GradVarName("W")); auto* bias_grad = @@ -105,7 +125,22 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { zero(dev_ctx, w_grad, static_cast(0.0)); size_t num_classes = static_cast(ctx.Attr("num_classes")); - math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + bool is_custom = false; + if (path) { + is_custom = true; + } else { + is_custom = false; + } + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, + label->data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor(path, code, + label->data())); + } auto& place = *ctx.template device_context().eigen_device(); auto pre_out_mat = 
EigenMatrix::From(*pre_out); @@ -116,7 +151,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { // softrelu derivative pre_out_grad_mat.device(place) = static_cast(1.0) - static_cast(1.0) / pre_out_mat.exp(); - bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b) + bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) pre_out_grad_mat.device(place) = pre_out_grad_mat * out_grad_mat.broadcast(bcast); // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to @@ -124,10 +159,10 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { if (bias_grad) { bias_grad->mutable_data(ctx.GetPlace()); zero(dev_ctx, bias_grad, static_cast(0.0)); - bit_code.AddGrad(pre_out_grad, bias_grad); + bit_code->AddGrad(pre_out_grad, bias_grad); } - bit_code.MulGradWeight(pre_out_grad, w_grad, *in); - bit_code.MulGradError(pre_out_grad, *w, in_grad); + bit_code->MulGradWeight(pre_out_grad, w_grad, *in); + bit_code->MulGradError(pre_out_grad, *w, in_grad); } }; diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 1e56e29739..88279f8d8a 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -21,14 +21,13 @@ namespace math { template void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, const framework::Tensor& vec) { - SimpleCodeTable code_table(num_classes_); size_t batch_size = tmat->dims()[0]; size_t width = tmat->dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); tmat->data()[i * width + j] += vec.data()[index]; } } @@ -37,14 +36,13 @@ void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, template void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, framework::Tensor* vec) { - SimpleCodeTable code_table(num_classes_); size_t batch_size = tmat.dims()[0]; size_t width = tmat.dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); vec->data()[index] += tmat.data()[i * width + j]; } } @@ -53,15 +51,14 @@ void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, template void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { T sm = static_cast(0.0); - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { + if (code->calc_bit(j)) { // calc_bit starts from right most bit, while data in tmat[i] is in the // reverse order. 
sm += tmat.data()[i * o_width + j]; @@ -75,7 +72,6 @@ template void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, const framework::Tensor& weight, const framework::Tensor& input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat->dims()[0]; size_t tmat_width = tmat->dims()[1]; size_t input_width = input.dims()[1]; @@ -84,10 +80,10 @@ void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, auto weight_value = weight.data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); T sum = static_cast(0.0); for (size_t k = 0; k < input_width; ++k) { sum += weight_value[weight_width * index + k] * @@ -102,7 +98,6 @@ template void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, const framework::Tensor& input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -111,10 +106,10 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, auto weight_value = weight->data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); for (size_t k = 0; k < input_width; ++k) { weight_value[weight_width * index + k] += @@ -128,7 +123,6 @@ template void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, const framework::Tensor& weight, framework::Tensor* input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t tmat_width = tmat.dims()[1]; size_t input_width = input->dims()[1]; @@ -138,10 +132,10 @@ void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, auto input_value = input->data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); for (size_t k = 0; k < input_width; ++k) { input_value[input_width * i + k] += @@ -154,14 +148,13 @@ void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, template void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat->dims()[0]; size_t o_width = tmat->dims()[1]; for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { + if (code->calc_bit(j)) { tmat->data()[i * o_width + j] -= 1; } } diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index 07854c8358..f03c8d3689 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -93,9 +93,27 @@ inline 
int clz(const T& value) { inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); } #endif // !_WIN32 } +// set a code interface to create multiple code +class Code { + public: + virtual ~Code() {} + virtual size_t calc_index(int bit) const = 0; + virtual bool calc_bit(int bit) const = 0; + virtual int get_length() const = 0; +}; +// set a CodeTable interface to create multiple code table +class CodeTable { + public: + virtual std::unique_ptr get_code(int64_t code) const = 0; + virtual size_t size() const = 0; + virtual int get_max_code_length() const = 0; + virtual ~CodeTable() {} +}; -struct SimpleCode { - SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {} +class SimpleCode : public Code { + public: + SimpleCode(size_t code, size_t num_classes, const int64_t* ids) + : c_(static_cast(ids[code]) + num_classes) {} /** * Here the id of root shoud be 1 rather than 0, thus the encoding of class c * is `c + num_classes` and all siblings can get the same weight indice using @@ -105,31 +123,111 @@ struct SimpleCode { * Binary classification path is the suffixes of encoding, thus leave out the * left most bit in calc_bit. */ - inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } - inline bool calc_bit(int bit) const { return c_ & (1 << bit); } - inline int get_length() const { return FindLastSet(c_) - 1; } + size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } + bool calc_bit(int bit) const { return c_ & (1 << bit); } + int get_length() const { return FindLastSet(c_) - 1; } private: size_t c_; }; -struct SimpleCodeTable { - explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {} - SimpleCode operator()(size_t code) const { - return SimpleCode(code, num_classes_); +template +class CustomCode : public Code { + public: + CustomCode(const framework::Tensor* ptable, const framework::Tensor* pcode, + const int64_t* ids, const int index) + : ptable_(ptable), pcode_(pcode), ids_(ids), index_(index) {} + /** + * Here the id of root shoud be 1 rather than 0, thus the encoding of class c + * is `c + num_classes` and all siblings can get the same weight indice using + * prefixes. + * Weight index is the prefixes of encoding, thus leave out the right most + * bit in calc_index. + * Binary classification path is the suffixes of encoding, thus leave out the + * left most bit in calc_bit. 
+ */ + size_t calc_index(int bit) const { + return ptable_ + ->data()[index_ * static_cast(ptable_->dims()[1]) + bit]; + } + bool calc_bit(int bit) const { + return pcode_ + ->data()[index_ * static_cast(ptable_->dims()[1]) + bit]; + } + int get_length() const { + int length = 0; + + for (int i = 0; i < ptable_->dims()[1]; i++) { + if (ptable_->data()[index_ * static_cast(ptable_->dims()[1]) + + i] != -1) { + length++; + } else { + return length; + } + } + return length; + } + + private: + const framework::Tensor* ptable_; + const framework::Tensor* pcode_; + const int64_t* ids_; + const int index_; +}; + +class SimpleCodeTable : public CodeTable { + public: + explicit SimpleCodeTable(size_t num_classes, const int64_t* ids) + : num_classes_(num_classes), ids_(ids) {} + std::unique_ptr get_code(int64_t code) const { + std::unique_ptr coder(new SimpleCode(code, num_classes_, ids_)); + return coder; } size_t size() const { return num_classes_; } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } private: size_t num_classes_; + const int64_t* ids_; +}; + +template +class CustomCodeTable : public CodeTable { + public: + explicit CustomCodeTable(const framework::Tensor* ptable, + const framework::Tensor* pcode, const int64_t* ids) + : ptable_(ptable), pcode_(pcode), ids_(ids) {} + + std::unique_ptr get_code(int64_t code) const { + std::unique_ptr coder(new CustomCode(ptable_, pcode_, ids_, code)); + return coder; + } + + size_t size() const { return static_cast(ptable_->dims()[1]); } + int get_max_code_length() const { + return static_cast(ptable_->dims()[1]); + } + + private: + const framework::Tensor* ptable_; + const framework::Tensor* pcode_; + const int64_t* ids_; }; template class MatrixBitCodeFunctor { public: explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) - : num_classes_(num_classes), ids_(ids) {} + : num_classes_(num_classes), + ids_(ids), + code_table(new SimpleCodeTable(num_classes, ids)) {} + + explicit MatrixBitCodeFunctor(const framework::Tensor* ptable, + const framework::Tensor* pcode, + const int64_t* ids) + : num_classes_(static_cast(ptable->dims()[1])), + ids_(ids), + code_table(new CustomCodeTable(ptable, pcode, ids)) {} /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ @@ -168,6 +266,7 @@ class MatrixBitCodeFunctor { size_t num_classes_; const int64_t* ids_; + std::unique_ptr code_table; }; } // namespace math } // namespace operators diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 110e6d5ab2..d3ee80ad52 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4349,6 +4349,8 @@ def nce(input, def hsigmoid(input, label, num_classes, + ptabl=None, + pcode=None, param_attr=None, bias_attr=None, name=None): @@ -4372,6 +4374,12 @@ def hsigmoid(input, label (Variable): The tensor variable contains labels of training data. It's a tensor with shape is :math:`[N \\times 1]`. num_classes: (int), The number of classes, must not be less than 2. + ptable: (Variable|None) this variable can store each batch of samples' path to root, + it should be in leaf -> root order + ptable should have the same shape with pcode, and for each sample i ptable[i] indicates a np.array like + structure and each element in this array is indexes in parent nodes' Weight Matrix. + pcode: (Variable|None) this variable can store each batch of samples' code, + each code consist with every code of parent nodes. 
it should be in leaf -> root order param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create ParamAttr as param_attr. If the Initializer of the param_attr @@ -4403,12 +4411,25 @@ def hsigmoid(input, dim = input.shape[1] if num_classes < 2: raise ValueError("num_classes must not be less than 2.") + if (ptable is not None) and (pcode is None): + raise ValueError("pcode should not be None when ptable has been set") + elif (ptable is None) and (pcode is not None): + raise ValueError("ptable should not be None when pcode has been set") + else: + pass + weights = helper.create_parameter( attr=helper.param_attr, shape=[num_classes - 1, dim], is_bias=False, dtype=input.dtype) - inputs = {"X": input, "W": weights, "Label": label} + inputs = { + "X": input, + "W": weights, + "PTable": ptable, + "PCode": pcode, + "Label": label + } if helper.bias_attr: bias = helper.create_parameter( attr=helper.bias_attr, diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index e97643cdde..fb521e86a3 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -138,8 +138,11 @@ class OpTest(unittest.TestCase): cls.dtype = "float32" cls.outputs = {} - np.random.seed(123) - random.seed(124) + # np.random.seed(123) + # random.seed(124) + + np.random.seed(190) + random.seed(200) @classmethod def tearDownClass(cls): diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 6948ae3002..4beeed0131 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -40,6 +40,29 @@ class CodeTable(object): return self.c & (1 << bit) +class CodeTableWithCustomTree(object): + def __init__(self, ptable, pcode, index): + self.ptable_ = ptable + self.pcode_ = pcode + self.index_ = index + + def cal_index(self, bit): + return self.ptable_[self.index_][bit] + + def get_length(self): + length = 0 + for ele in self.ptable_[self.index_]: + + if ele >= 0: + length = length + 1 + else: + return length + return length + + def cal_bit(self, bit): + return self.pcode_[self.index_][bit] + + def hsigmoid(x, w, label, bias, num_classes): batch_size = x.shape[0] code_length = find_latest_set(num_classes - 1) @@ -48,10 +71,12 @@ def hsigmoid(x, w, label, bias, num_classes): pre_sum = np.zeros((batch_size, 1)) out = np.zeros((batch_size, 1)).astype("float32") for i in range(batch_size): + #print("\n leaf {leaf}: \n".format(leaf = label[i])) code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() for j in range(length): idx = code_table.cal_index(j) + #print("index {index} ".format(index = j)) pre_output[i][j] += bias[0][idx] for i in range(batch_size): code_table = CodeTable(num_classes, label[i]) @@ -63,10 +88,12 @@ def hsigmoid(x, w, label, bias, num_classes): pre_output = np.clip(pre_output, -40.0, 40.0) # out(i, 0) = \sum_j bit(i, j) * preout(i, j) for i in range(batch_size): + #print("\n leaf {leaf}: \n".format(leaf = label[i])) code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() sum = 0.0 for j in range(length): + #print("bit {bit} ".format(bit = code_table.cal_bit(j))) if code_table.cal_bit(j): sum += pre_output[i][j] out[i] = -1.0 * sum @@ -77,25 +104,101 @@ def hsigmoid(x, w, label, bias, num_classes): return pre_output, out 
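The reference hsigmoid() above relies on the same implicit complete-binary-tree encoding as the C++ SimpleCode/SimpleCodeTable classes earlier in this patch. The sketch below is illustrative only and not part of the patch (simple_code is a made-up helper name); it shows how a label is mapped to the rows of W it touches and to the per-node classification bits, assuming the root is node 1 so that class c is encoded as c + num_classes:

def simple_code(label, num_classes):
    # Encode the leaf: with the root implicitly at node 1, class c becomes c + num_classes.
    c = label + num_classes
    # Path length is FindLastSet(c) - 1, i.e. the number of internal nodes above the leaf.
    length = c.bit_length() - 1
    # calc_index(j): prefix of the encoding -> which row of W / column of Bias is used.
    rows = [(c >> (j + 1)) - 1 for j in range(length)]
    # calc_bit(j): suffix of the encoding -> the 0/1 target of each per-node sigmoid.
    bits = [bool(c & (1 << j)) for j in range(length)]
    return rows, bits

# Example: num_classes = 6, label = 4  ->  c = 10 (0b1010), length = 3,
# rows = [4, 1, 0], bits = [False, True, False].  The reference hsigmoid() above
# accumulates bias[0][rows[j]] + dot(w[rows[j]], x[i]) into pre_output[i][j], sums
# (with a -1 factor) the entries where bits[j] is True into out[i], and then adds
# the soft-relu term.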
-class TestHSigmoidOp(OpTest): +def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): + batch_size = x.shape[0] + code_length = len(ptable[0]) + code_table = [0 for _ in range(code_length)] + pre_output = np.zeros((batch_size, code_length)) + pre_sum = np.zeros((batch_size, 1)) + out = np.zeros((batch_size, 1)).astype("float32") + for i in range(batch_size): + code_table = CodeTableWithCustomTree(ptable, pcode, i) + length = code_table.get_length() + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += bias[0][idx] + for i in range(batch_size): + code_table = CodeTableWithCustomTree(ptable, pcode, i) + length = code_table.get_length() + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += np.dot(w[idx], x[i]) + # clip[-40.0, 40.0] + pre_output = np.clip(pre_output, -40.0, 40.0) + # out(i, 0) = \sum_j bit(i, j) * preout(i, j) + for i in range(batch_size): + code_table = CodeTableWithCustomTree(ptable, pcode, i) + length = code_table.get_length() + sum = 0.0 + for j in range(length): + if code_table.cal_bit(j): + sum += pre_output[i][j] + out[i] = -1.0 * sum + # soft relu + pre_output = np.log(1 + np.exp(pre_output)) + pre_sum = pre_output.sum(1).reshape((batch_size, 1)) + out += pre_sum + return pre_output, out + + +# class TestHSigmoidOp(OpTest): +# def setUp(self): +# self.op_type = "hierarchical_sigmoid" +# num_classes = 6 +# feature_size = 8 +# batch_size = 7 +# x = np.random.random((batch_size, feature_size)).astype("float32") +# w = np.random.random((num_classes - 1, feature_size)).astype("float32") +# label = np.random.randint(0, num_classes, (batch_size, 1)) +# bias = np.random.random((1, num_classes - 1)).astype("float32") +# self.attrs = {'num_classes': num_classes} +# self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} +# pre_output, out = hsigmoid(x, w, label, bias, num_classes) +# self.outputs = {'PreOut': pre_output, 'Out': out} + +# def test_check_output(self): +# self.check_output() + +# def test_check_grad(self): +# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + + +class TestHSigmoidOpWithCostumTree(OpTest): def setUp(self): self.op_type = "hierarchical_sigmoid" - num_classes = 6 + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample feature_size = 8 batch_size = 4 - x = np.random.random((batch_size, feature_size)).astype("float32") - w = np.random.random((num_classes - 1, feature_size)).astype("float32") - label = np.random.randint(0, num_classes, (batch_size, 1)) + x = np.random.random((batch_size, feature_size)).astype("float32") * 10 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 10 + label = np.array([0, 1, 4, 5]) + ptable = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store bias = np.random.random((1, num_classes - 1)).astype("float32") self.attrs = {'num_classes': num_classes} - self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} - pre_output, out = hsigmoid(x, w, label, bias, num_classes) + self.inputs = { + 'X': x, + 'W': w, + 'PTable': ptable, + 'PCode': pcode, + 'Label': label, + 'Bias': bias + } + pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, + bias, num_classes) self.outputs = {'PreOut': pre_output, 'Out': out} def test_check_output(self): + 
print("checking output in CostumTree") self.check_output() def test_check_grad(self): + print("checking outputGrad in CostumTree") self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) From 32e05b01f294b8ea5d742294fc8b4f4e69985f0a Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 12 Nov 2018 11:36:48 +0000 Subject: [PATCH 02/23] test=develop --- .../fluid/operators/hierarchical_sigmoid_op.h | 9 ++++ paddle/fluid/operators/math/matrix_bit_code.h | 2 +- .../paddle/fluid/tests/unittests/op_test.py | 7 +-- .../fluid/tests/unittests/test_hsigmoid_op.py | 53 ++++++++++--------- 4 files changed, 40 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 2d500a03df..90bdb47311 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -86,6 +86,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out->numel(), pre_out_data, ClipFunctor(static_cast(-40.0), static_cast(40.0))); + pre_out_mat = -1 * pre_out_mat; bit_code->Sum(*pre_out, out, static_cast(-1)); // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); @@ -146,6 +147,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto pre_out_mat = EigenMatrix::From(*pre_out); auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); auto out_grad_mat = EigenMatrix::From(*out_grad); + Eigen::array bcast({{1, static_cast(pre_out_grad.dims()[1])}}); // softrelu derivative @@ -160,9 +162,16 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { bias_grad->mutable_data(ctx.GetPlace()); zero(dev_ctx, bias_grad, static_cast(0.0)); bit_code->AddGrad(pre_out_grad, bias_grad); + auto bias_grad_mat = EigenMatrix::From(*bias_grad); + bias_grad_mat = -1 * bias_grad_mat; } bit_code->MulGradWeight(pre_out_grad, w_grad, *in); bit_code->MulGradError(pre_out_grad, *w, in_grad); + auto w_grad_mat = EigenMatrix::From(*w_grad); + auto in_grad_mat = EigenMatrix::From(*in_grad); + + w_grad_mat = -1 * w_grad_mat; + in_grad_mat = -1 * in_grad_mat; } }; diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index f03c8d3689..1e2abd1e69 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -157,7 +157,7 @@ class CustomCode : public Code { int get_length() const { int length = 0; - for (int i = 0; i < ptable_->dims()[1]; i++) { + for (int i = 0; i < static_cast(ptable_->dims()[1]); i++) { if (ptable_->data()[index_ * static_cast(ptable_->dims()[1]) + i] != -1) { length++; diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index fb521e86a3..e97643cdde 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -138,11 +138,8 @@ class OpTest(unittest.TestCase): cls.dtype = "float32" cls.outputs = {} - # np.random.seed(123) - # random.seed(124) - - np.random.seed(190) - random.seed(200) + np.random.seed(123) + random.seed(124) @classmethod def tearDownClass(cls): diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 4beeed0131..0a16f5a39c 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -17,6 +17,9 @@ from __future__ import print_function import unittest import numpy as np import math +# import paddle.fluid as fluid +# import paddle.fluid.core as core +# from op_builder import OpBuilder from op_test import OpTest np.random.seed(100) @@ -51,7 +54,7 @@ class CodeTableWithCustomTree(object): def get_length(self): length = 0 - for ele in self.ptable_[self.index_]: + for ele in self.ptable_[self.index_]: # find the first -1 to stop trace if ele >= 0: length = length + 1 @@ -71,12 +74,10 @@ def hsigmoid(x, w, label, bias, num_classes): pre_sum = np.zeros((batch_size, 1)) out = np.zeros((batch_size, 1)).astype("float32") for i in range(batch_size): - #print("\n leaf {leaf}: \n".format(leaf = label[i])) code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() for j in range(length): idx = code_table.cal_index(j) - #print("index {index} ".format(index = j)) pre_output[i][j] += bias[0][idx] for i in range(batch_size): code_table = CodeTable(num_classes, label[i]) @@ -87,13 +88,12 @@ def hsigmoid(x, w, label, bias, num_classes): # clip[-40.0, 40.0] pre_output = np.clip(pre_output, -40.0, 40.0) # out(i, 0) = \sum_j bit(i, j) * preout(i, j) + pre_output = -1 * pre_output for i in range(batch_size): - #print("\n leaf {leaf}: \n".format(leaf = label[i])) code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() sum = 0.0 for j in range(length): - #print("bit {bit} ".format(bit = code_table.cal_bit(j))) if code_table.cal_bit(j): sum += pre_output[i][j] out[i] = -1.0 * sum @@ -108,6 +108,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): batch_size = x.shape[0] code_length = len(ptable[0]) code_table = [0 for _ in range(code_length)] + # init pre_out with shape [N, code_length] pre_output = np.zeros((batch_size, code_length)) pre_sum = np.zeros((batch_size, 1)) out = np.zeros((batch_size, 1)).astype("float32") @@ -125,6 +126,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): pre_output[i][j] += np.dot(w[idx], x[i]) # clip[-40.0, 40.0] pre_output = np.clip(pre_output, -40.0, 40.0) + pre_output = -1 * pre_output # out(i, 0) = \sum_j bit(i, j) * preout(i, j) for i in range(batch_size): code_table = CodeTableWithCustomTree(ptable, pcode, i) @@ -141,26 +143,27 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): return pre_output, out -# class TestHSigmoidOp(OpTest): -# def setUp(self): -# self.op_type = "hierarchical_sigmoid" -# num_classes = 6 -# feature_size = 8 -# batch_size = 7 -# x = np.random.random((batch_size, feature_size)).astype("float32") -# w = np.random.random((num_classes - 1, feature_size)).astype("float32") -# label = np.random.randint(0, num_classes, (batch_size, 1)) -# bias = np.random.random((1, num_classes - 1)).astype("float32") -# self.attrs = {'num_classes': num_classes} -# self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} -# pre_output, out = hsigmoid(x, w, label, bias, num_classes) -# self.outputs = {'PreOut': pre_output, 'Out': out} +class TestHSigmoidOp(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.random.randint(0, num_classes, (batch_size, 1)) + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = 
{'num_classes': num_classes} + self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} + pre_output, out = hsigmoid(x, w, label, bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} -# def test_check_output(self): -# self.check_output() + def test_check_output(self): + self.check_output() -# def test_check_grad(self): -# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + def test_check_grad(self): + self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) class TestHSigmoidOpWithCostumTree(OpTest): @@ -169,9 +172,9 @@ class TestHSigmoidOpWithCostumTree(OpTest): num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample feature_size = 8 batch_size = 4 - x = np.random.random((batch_size, feature_size)).astype("float32") * 10 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 w = np.random.random( - (num_classes - 1, feature_size)).astype("float32") * 10 + (num_classes - 1, feature_size)).astype("float32") * 2 label = np.array([0, 1, 4, 5]) ptable = np.array( [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), From b8ff0972b63238dbc0fb853615967f8e339a30b7 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 12 Nov 2018 12:05:31 +0000 Subject: [PATCH 03/23] test=develop --- paddle/fluid/operators/hierarchical_sigmoid_op.h | 8 -------- python/paddle/fluid/tests/unittests/test_hsigmoid_op.py | 2 -- 2 files changed, 10 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 90bdb47311..df4f5f561a 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -86,7 +86,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out->numel(), pre_out_data, ClipFunctor(static_cast(-40.0), static_cast(40.0))); - pre_out_mat = -1 * pre_out_mat; bit_code->Sum(*pre_out, out, static_cast(-1)); // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); @@ -162,16 +161,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { bias_grad->mutable_data(ctx.GetPlace()); zero(dev_ctx, bias_grad, static_cast(0.0)); bit_code->AddGrad(pre_out_grad, bias_grad); - auto bias_grad_mat = EigenMatrix::From(*bias_grad); - bias_grad_mat = -1 * bias_grad_mat; } bit_code->MulGradWeight(pre_out_grad, w_grad, *in); bit_code->MulGradError(pre_out_grad, *w, in_grad); - auto w_grad_mat = EigenMatrix::From(*w_grad); - auto in_grad_mat = EigenMatrix::From(*in_grad); - - w_grad_mat = -1 * w_grad_mat; - in_grad_mat = -1 * in_grad_mat; } }; diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 0a16f5a39c..6152b96912 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -88,7 +88,6 @@ def hsigmoid(x, w, label, bias, num_classes): # clip[-40.0, 40.0] pre_output = np.clip(pre_output, -40.0, 40.0) # out(i, 0) = \sum_j bit(i, j) * preout(i, j) - pre_output = -1 * pre_output for i in range(batch_size): code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() @@ -126,7 +125,6 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): pre_output[i][j] += np.dot(w[idx], x[i]) # clip[-40.0, 40.0] pre_output = np.clip(pre_output, -40.0, 40.0) - 
pre_output = -1 * pre_output # out(i, 0) = \sum_j bit(i, j) * preout(i, j) for i in range(batch_size): code_table = CodeTableWithCustomTree(ptable, pcode, i) From f4be1d99d0a9c334d6b4ee8d6c557ea0d936f58a Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 13 Nov 2018 06:19:26 +0000 Subject: [PATCH 04/23] polish code and test --- .../operators/hierarchical_sigmoid_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 66 +++++++++++++------ .../fluid/tests/unittests/test_layers.py | 17 +++++ 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 49a17416c8..8d4e0556dd 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -115,7 +115,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { "[batch_size, code_length], where code_length represents the " "maximum path length from root to leaf nodes.") .AsIntermediate(); - AddAttr("num_classes", "(int, required), The number of classes") + AddAttr("num_classes", "(int, optional), The number of classes") .SetDefault(2); AddComment(R"DOC( The hierarchical sigmoid operator organize the classes into a binary tree. diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d3ee80ad52..835ec4506a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4348,12 +4348,14 @@ def nce(input, def hsigmoid(input, label, - num_classes, - ptabl=None, + num_classes=None, + non_leaf_num=None, + ptable=None, pcode=None, param_attr=None, bias_attr=None, - name=None): + name=None, + is_costum=False): """ The hierarchical sigmoid operator is used to accelerate the training process of language model. This operator organizes the classes into a @@ -4373,7 +4375,8 @@ def hsigmoid(input, and :math:`D` is the feature size. label (Variable): The tensor variable contains labels of training data. It's a tensor with shape is :math:`[N \\times 1]`. - num_classes: (int), The number of classes, must not be less than 2. + num_classes: (int), The number of classes, must not be less than 2. 
with default tree this has to be set + non_leaf_num: this defines the number of non-leaf nodes in costumed tree ptable: (Variable|None) this variable can store each batch of samples' path to root, it should be in leaf -> root order ptable should have the same shape with pcode, and for each sample i ptable[i] indicates a np.array like @@ -4409,20 +4412,33 @@ def hsigmoid(input, out = helper.create_variable_for_type_inference(dtype) pre_out = helper.create_variable_for_type_inference(dtype) dim = input.shape[1] - if num_classes < 2: - raise ValueError("num_classes must not be less than 2.") - if (ptable is not None) and (pcode is None): - raise ValueError("pcode should not be None when ptable has been set") - elif (ptable is None) and (pcode is not None): - raise ValueError("ptable should not be None when pcode has been set") + if ((num_classes < 2) or (num_classes is None)) and (not is_costum): + raise ValueError( + "num_classes must not be less than 2 with default tree") + + if (is_costum) and (pcode is None): + raise ValueError("pcode should not be None with costum tree") + elif (is_costum) and (ptable is None): + raise ValueError("ptable should not be None with costum tree") + elif (is_costum) and (non_leaf_num is None): + raise ValueError("non_leaf_num should not be None with costum tree") else: pass - weights = helper.create_parameter( - attr=helper.param_attr, - shape=[num_classes - 1, dim], - is_bias=False, - dtype=input.dtype) + weights = None + + if not is_costum: + weights = helper.create_parameter( + attr=helper.param_attr, + shape=[num_classes - 1, dim], + is_bias=False, + dtype=input.dtype) + else: + weights = helper.create_parameter( + attr=helper.param_attr, + shape=[non_leaf_num, dim], + is_bias=False, + dtype=input.dtype) inputs = { "X": input, "W": weights, @@ -4431,12 +4447,20 @@ def hsigmoid(input, "Label": label } if helper.bias_attr: - bias = helper.create_parameter( - attr=helper.bias_attr, - shape=[1, num_classes - 1], - is_bias=True, - dtype=input.dtype) - inputs['Bias'] = bias + if not is_costum: + bias = helper.create_parameter( + attr=helper.bias_attr, + shape=[1, num_classes - 1], + is_bias=True, + dtype=input.dtype) + inputs['Bias'] = bias + else: + bias = helper.create_parameter( + attr=helper.bias_attr, + shape=[1, non_leaf_num], + is_bias=True, + dtype=input.dtype) + inputs['Bias'] = bias helper.append_op( type="hierarchical_sigmoid", inputs=inputs, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 50de468dba..b067e6213c 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -185,6 +185,23 @@ class TestBook(unittest.TestCase): input=x, label=y, num_classes=2)) print(str(program)) + program2 = Program() + + with program_guard(program2): + x2 = layers.data(name='x2', shape=[4, 8], dtype='float32') + y2 = layers.data(name='y2', shape=[4], dtype='int64') + ptable = layers.data(name='ptable', shape=[4, 6], dtype='int64') + pcode = layers.data(name='pcode', shape=[4, 6], dtype='int64') + self.assertIsNotNone( + layers.hsigmoid( + input=x2, + label=y2, + non_leaf_num=6, + ptable=ptable, + pcode=pcode, + is_costum=True)) + print(str(program2)) + def test_sequence_expand(self): program = Program() with program_guard(program): From 30332ad91d6c69b841d7ead0bb000b5964287a7b Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 13 Nov 2018 06:34:56 +0000 Subject: [PATCH 05/23] test=develop --- 
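For context, a minimal sketch (illustrative, not part of the patch) of how the custom-tree interface added above is wired together. The layer call mirrors the new test_layers case; the path data follows the convention of the new op test: row i of ptable lists the indices of the non-leaf nodes on sample i's path, row i of pcode gives the 0/1 branch taken at each of those nodes, and both rows are padded with -1 once the path ends.

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

prog = fluid.Program()
with fluid.program_guard(prog):
    x = layers.data(name='x', shape=[4, 8], dtype='float32')
    y = layers.data(name='y', shape=[4], dtype='int64')
    ptable = layers.data(name='ptable', shape=[4, 6], dtype='int64')
    pcode = layers.data(name='pcode', shape=[4, 6], dtype='int64')
    # With is_costum=True the weight is sized [non_leaf_num, dim] instead of
    # [num_classes - 1, dim], so num_classes is not required here.
    out = layers.hsigmoid(input=x, label=y, non_leaf_num=6,
                          ptable=ptable, pcode=pcode, is_costum=True)

# Path data in the format used by the new op test (6 classes, samples labelled 0, 1, 4, 5):
ptable_data = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1),
                        (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)]).astype('int64')
pcode_data = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1),
                       (1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]).astype('int64')
# -1 marks the end of a path; entries of ptable index rows of W (and columns of
# Bias), entries of pcode are the targets of the per-node binary classifiers.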
python/paddle/fluid/tests/unittests/test_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index b067e6213c..4379aeb993 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -185,8 +185,8 @@ class TestBook(unittest.TestCase): input=x, label=y, num_classes=2)) print(str(program)) + # test hsigmod with custom tree structure program2 = Program() - with program_guard(program2): x2 = layers.data(name='x2', shape=[4, 8], dtype='float32') y2 = layers.data(name='y2', shape=[4], dtype='int64') From db06568e693a724b5578ab6c77d9db833d253f18 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 13 Nov 2018 08:26:13 +0000 Subject: [PATCH 06/23] test=develop --- paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 3bbe7c2b8c..d64939413b 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -98,7 +98,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs= paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None)) -paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'non_leaf_num', 'ptable', 'pcode', 'param_attr', 'bias_attr', 'name', 'is_costum'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, False)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) From 99d1446a8ba3bddf899026a030ed6ab2f44a6531 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Wed, 14 Nov 2018 05:49:51 +0000 Subject: [PATCH 07/23] test=develop --- python/paddle/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 835ec4506a..4472f20409 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4412,7 +4412,7 @@ def hsigmoid(input, out = helper.create_variable_for_type_inference(dtype) pre_out = helper.create_variable_for_type_inference(dtype) dim = input.shape[1] - if ((num_classes < 2) or (num_classes is None)) and (not is_costum): + if ((num_classes is None) or (num_classes < 2)) and (not is_costum): raise ValueError( "num_classes must not be less than 2 with default tree") From a507845a7735af6552f035f27902d2758bd36bcb Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Wed, 14 Nov 2018 06:13:41 +0000 Subject: [PATCH 
08/23] test=develop --- paddle/fluid/operators/math/matrix_bit_code.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index 1e2abd1e69..39c3b1520b 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -159,7 +159,7 @@ class CustomCode : public Code { for (int i = 0; i < static_cast(ptable_->dims()[1]); i++) { if (ptable_->data()[index_ * static_cast(ptable_->dims()[1]) + - i] != -1) { + i] >= 0) { length++; } else { return length; From ba9ff508e8339319c926b105e9ffb32f7332977a Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 15 Nov 2018 08:43:36 +0000 Subject: [PATCH 09/23] temp fix --- .../fluid/operators/math/matrix_bit_code.cc | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 88279f8d8a..090c0cca36 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -119,6 +119,33 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, } } +// template +// void MatrixBitCodeFunctor::MulGradSparseWeight(const framework::Tensor& +// tmat, +// framework::SelectedRows* weight, +// const framework::Tensor& input) { +// size_t num_samples = tmat.dims()[0]; +// size_t input_width = input.dims()[1]; +// size_t tmat_width = tmat.dims()[1]; +// size_t weight_width = weight->dims()[1]; +// auto tmat_value = tmat.data(); +// auto weight_value = weight->data(); +// auto input_value = input.data(); +// for (size_t i = 0; i < num_samples; ++i) { +// auto code = code_table->get_code(i); +// int code_length = code->get_length(); +// for (int j = 0; j < code_length; ++j) { +// // size_t index = code->calc_index(j); + +// for (size_t k = 0; k < input_width; ++k) { +// weight_value[j * weight_width + k] += +// tmat_value[i * tmat_width + j] * input_value[input_width * i + +// k]; +// } +// } +// } +// } + template void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, const framework::Tensor& weight, From 014e50c284eb9698cc02d0457f8eb3b566687e70 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Wed, 21 Nov 2018 07:53:15 +0000 Subject: [PATCH 10/23] test=develop --- paddle/fluid/framework/mixed_vector.h | 6 + .../operators/hierarchical_sigmoid_op.cc | 68 ++++-- .../fluid/operators/hierarchical_sigmoid_op.h | 92 +++++--- .../fluid/operators/math/matrix_bit_code.cc | 85 ++++---- paddle/fluid/operators/math/matrix_bit_code.h | 53 +++-- python/paddle/fluid/layers/nn.py | 10 +- .../fluid/tests/unittests/test_hsigmoid_op.py | 206 ++++++++++++------ 7 files changed, 349 insertions(+), 171 deletions(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index e1aac6dc5a..cd06da9d05 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -533,6 +533,12 @@ class CPUVector : public std::vector> { return os; } + size_t size() const noexcept { + size_t size = + static_cast(std::vector>::size()); + return size; + } + T &operator[](size_t id) { return this->at(id); } const T &operator[](size_t id) const { return this->at(id); } diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 8d4e0556dd..b2f4616441 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ 
b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { const int64_t batch_size = ctx->GetInputDim("X")[0]; std::vector output_shape({batch_size, 1}); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->ShareLoD("X", /*->*/ "Out"); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); } }; @@ -86,32 +87,34 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor, required) The input tensor with shape [N, D], " + "(LoDTensor, required) The input tensor with shape [N, D], " "where N is the size of mini-batch, and D is the feature size."); AddInput("W", - "(Tensor, required), The parameters of hierarchical " + "(LoDTensor, required), The parameters of hierarchical " "sigmoid operator, each of them is a 2-D tensor, the shape is" "[K, D]. Which K is the num of non-leaf node in Path Tree"); AddInput("Label", - "(Tensor, required), The labels of training data. It's a" + "(LoDTensor, required), The labels of training data. It's a" "tensor with shape [N, 1]."); AddInput("PTable", - "(Tensor, optional), The Path Table from root to current word" + "(LoDTensor, optional), The Path Table from root to current word" "it should have shape like [N, L], L is the length of the Path") .AsDispensable(); - AddInput("PCode", - "(Tensor, optional), The Code on each Node of the Path from root " - "to current word" - "it should have shape like [N, L], L is the length of the Path") + AddInput( + "PCode", + "(LoDTensor, optional), The Code on each Node of the Path from root " + "to current word" + "it should have shape like [N, L], L is the length of the Path") .AsDispensable(); AddInput("Bias", - "(Tensor, optional), The bias is a tensor with shape" + "(LoDTensor, optional), The bias is a tensor with shape" "[1, num_classes - 1]."); - AddOutput("Out", - "(Tensor, required) The output of hierarchical sigmoid operator." - "The shape is [N, 1]."); + AddOutput( + "Out", + "(LoDTensor, required) The output of hierarchical sigmoid operator." + "The shape is [N, 1]."); AddOutput("PreOut", - "(Tensor, required) A intermedia 2-D tensor with shape " + "(LoDTensor, required) A intermedia 2-D tensor with shape " "[batch_size, code_length], where code_length represents the " "maximum path length from root to leaf nodes.") .AsIntermediate(); @@ -124,6 +127,10 @@ belonging to the right branch. This idea is from "F. Morin, Y. Bengio (AISTATS 05): Hierarchical Probabilistic Neural Network Language Model." 
)DOC"); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); } }; @@ -133,6 +140,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) should not be null"); PADDLE_ENFORCE(ctx->HasInput("PreOut"), "Input(Preout) should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), @@ -142,7 +151,9 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); } - ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); + if (!ctx->Attrs().Get("is_sparse")) { + ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); + } ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } @@ -150,11 +161,33 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); } }; +class HierarchicalSigmoidGradOpGradVarTypeInference + : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; + block->Var(out_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; + block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); + } + block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); + } +}; + } // namespace operators } // namespace paddle @@ -162,7 +195,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, ops::HierarchicalSigmoidOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp); +REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp, + ops::HierarchicalSigmoidGradOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL( hierarchical_sigmoid, ops::HierarchicalSigmoidOpKernel, diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index df4f5f561a..3e2fbafa26 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -14,9 +14,10 @@ limitations under the License. 
*/ #pragma once #include +#include #include +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/matrix_bit_code.h" @@ -29,18 +30,37 @@ template ; using platform::Transform; +std::vector cal_rows(const framework::LoDTensor* path) { + std::set tmp; + std::vector rows; + rows.clear(); + for (size_t i = 0; i < static_cast(path->dims()[0]); i++) { + for (size_t j = 0; j < static_cast(path->dims()[1]); j++) { + int64_t temp = + path->data()[i * static_cast(path->dims()[1]) + j]; + if (temp >= 0) { + tmp.insert(temp); + } + } + } + for (std::set::iterator it = tmp.begin(); it != tmp.end(); ++it) { + rows.push_back(*it); + } + return rows; +} + template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); - auto* path = ctx.Input("PTable"); - auto* code = ctx.Input("PCode"); - auto* label = ctx.Input("Label"); - auto* bias = ctx.Input("Bias"); - auto* out = ctx.Output("Out"); - auto* pre_out = ctx.Output("PreOut"); + auto* in = ctx.Input("X"); + auto* w = ctx.Input("W"); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PCode"); + auto* label = ctx.Input("Label"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + auto* pre_out = ctx.Output("PreOut"); size_t num_classes = static_cast(ctx.Attr("num_classes")); bool is_custom = false; if (path) { @@ -51,7 +71,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { int64_t code_length = path ? 
path->dims()[1] : math::FindLastSet(num_classes - 1); int64_t batch_size = in->dims()[0]; - framework::Tensor sum; + framework::LoDTensor sum; auto& dev_ctx = ctx.template device_context(); auto* pre_out_data = pre_out->mutable_data( framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); @@ -102,27 +122,26 @@ template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); - auto* path = ctx.Input("PTable"); - auto* code = ctx.Input("PCode"); - auto* in_grad = ctx.Output(framework::GradVarName("X")); - auto* w_grad = ctx.Output(framework::GradVarName("W")); + auto* in = ctx.Input("X"); + auto* w = ctx.Input("W"); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PCode"); + auto* in_grad = + ctx.Output(framework::GradVarName("X")); + bool is_sparse = ctx.Attr("is_sparse"); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant zero; auto* bias_grad = - ctx.Output(framework::GradVarName("Bias")); - auto* label = ctx.Input("Label"); - auto* pre_out = ctx.Input("PreOut"); + ctx.Output(framework::GradVarName("Bias")); + auto* label = ctx.Input("Label"); + auto* pre_out = ctx.Input("PreOut"); auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - framework::Tensor pre_out_grad; + ctx.Input(framework::GradVarName("Out")); + framework::LoDTensor pre_out_grad; pre_out_grad.mutable_data(pre_out->dims(), ctx.GetPlace()); in_grad->mutable_data(ctx.GetPlace()); - w_grad->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - math::SetConstant zero; zero(dev_ctx, in_grad, static_cast(0.0)); - zero(dev_ctx, w_grad, static_cast(0.0)); size_t num_classes = static_cast(ctx.Attr("num_classes")); @@ -162,7 +181,28 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { zero(dev_ctx, bias_grad, static_cast(0.0)); bit_code->AddGrad(pre_out_grad, bias_grad); } - bit_code->MulGradWeight(pre_out_grad, w_grad, *in); + if (!is_sparse) { + auto* w_grad = + ctx.Output(framework::GradVarName("W")); + w_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, w_grad, static_cast(0.0)); + bit_code->MulGradWeight(pre_out_grad, w_grad, *in); + } else { + framework::Vector real_rows = cal_rows(path); + auto* w_grad = + ctx.Output(framework::GradVarName("W")); + + w_grad->set_rows(real_rows); + // build ids -> rows index map + w_grad->SyncIndex(); + auto* w_grad_value = w_grad->mutable_value(); + framework::DDim temp_dim(w->dims()); + set(temp_dim, 0, real_rows.size()); + + w_grad_value->mutable_data(temp_dim, ctx.GetPlace()); + zero(dev_ctx, w_grad_value, static_cast(0.0)); + bit_code->MulGradWeight(pre_out_grad, w_grad, *in); + } bit_code->MulGradError(pre_out_grad, *w, in_grad); } }; diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 090c0cca36..8baffe1ba1 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -19,8 +19,8 @@ namespace operators { namespace math { template -void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, - const framework::Tensor& vec) { +void MatrixBitCodeFunctor::Add(framework::LoDTensor* tmat, + const framework::LoDTensor& vec) { size_t batch_size = tmat->dims()[0]; size_t width = tmat->dims()[1]; for (size_t i = 0; i < batch_size; ++i) { @@ -34,8 +34,8 @@ void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, } template -void 
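For reference, a minimal NumPy sketch (not part of this patch) of how the two code lengths chosen by the forward kernel above relate: the default complete binary tree uses the bit length of (num_classes - 1), i.e. math::FindLastSet(num_classes - 1), while a custom tree simply uses the padded width L of the path table.

```python
import numpy as np

def default_code_length(num_classes):
    # counterpart of math::FindLastSet(num_classes - 1): bit length of K - 1
    return int(num_classes - 1).bit_length()

ptable = np.array([[0, 2, -1, -1, -1],
                   [0, 1, 3, -1, -1]])        # rows padded with -1
custom_code_length = ptable.shape[1]          # L, the padded path length

print(default_code_length(6), custom_code_length)   # -> 3 5
```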
MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, - framework::Tensor* vec) { +void MatrixBitCodeFunctor::AddGrad(const framework::LoDTensor& tmat, + framework::LoDTensor* vec) { size_t batch_size = tmat.dims()[0]; size_t width = tmat.dims()[1]; for (size_t i = 0; i < batch_size; ++i) { @@ -49,8 +49,8 @@ void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, } template -void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, - framework::Tensor* sum, T scale_sum) { +void MatrixBitCodeFunctor::Sum(const framework::LoDTensor& tmat, + framework::LoDTensor* sum, T scale_sum) { size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { @@ -69,9 +69,9 @@ void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, } template -void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, - const framework::Tensor& weight, - const framework::Tensor& input) { +void MatrixBitCodeFunctor::Mul(framework::LoDTensor* tmat, + const framework::LoDTensor& weight, + const framework::LoDTensor& input) { size_t num_samples = tmat->dims()[0]; size_t tmat_width = tmat->dims()[1]; size_t input_width = input.dims()[1]; @@ -95,9 +95,9 @@ void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, } template -void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, - framework::Tensor* weight, - const framework::Tensor& input) { +void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, + framework::LoDTensor* weight, + const framework::LoDTensor& input) { size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -119,37 +119,38 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, } } -// template -// void MatrixBitCodeFunctor::MulGradSparseWeight(const framework::Tensor& -// tmat, -// framework::SelectedRows* weight, -// const framework::Tensor& input) { -// size_t num_samples = tmat.dims()[0]; -// size_t input_width = input.dims()[1]; -// size_t tmat_width = tmat.dims()[1]; -// size_t weight_width = weight->dims()[1]; -// auto tmat_value = tmat.data(); -// auto weight_value = weight->data(); -// auto input_value = input.data(); -// for (size_t i = 0; i < num_samples; ++i) { -// auto code = code_table->get_code(i); -// int code_length = code->get_length(); -// for (int j = 0; j < code_length; ++j) { -// // size_t index = code->calc_index(j); - -// for (size_t k = 0; k < input_width; ++k) { -// weight_value[j * weight_width + k] += -// tmat_value[i * tmat_width + j] * input_value[input_width * i + -// k]; -// } -// } -// } -// } +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, + framework::SelectedRows* weight, + const framework::LoDTensor& input) { + size_t num_samples = tmat.dims()[0]; + size_t input_width = input.dims()[1]; + size_t tmat_width = tmat.dims()[1]; + size_t weight_width = weight->value().dims()[1]; + auto tmat_value = tmat.data(); + auto weight_value = weight->mutable_value()->data(); + auto input_value = input.data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table->get_code(i); + int code_length = code->get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code->calc_index(j); + + for (size_t k = 0; k < input_width; ++k) { + int64_t row_index = + weight->AutoGrownIndex(static_cast(index), false); + + weight_value[row_index * weight_width + k] += + tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; + } + } + } +} template 
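A minimal NumPy sketch (illustrative assumptions only, not the operator code) of what the sparse branch above builds: with is_sparse=True the W gradient keeps only the rows named by the path table (what cal_rows gathers), an id-to-row map (the SyncIndex analogue), and a value tensor of shape [len(rows), D] that MulGradWeight scatters into, instead of a dense [K, D] gradient.

```python
import numpy as np

D = 8                                          # feature (embedding) width
ptable = np.array([[0, 2, -1], [0, 1, 3]])     # batch of padded paths
rows = np.unique(ptable[ptable >= 0])          # what cal_rows gathers: [0 1 2 3]
value = np.zeros((len(rows), D))               # W@Grad value, one row per id
row_index = {int(r): i for i, r in enumerate(rows)}   # analogue of SyncIndex()

# MulGradWeight-style scatter: accumulate the update for non-leaf node id 2
value[row_index[2]] += 0.1 * np.ones(D)
print(value[row_index[2]][0])                  # 0.1
```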
-void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, - const framework::Tensor& weight, - framework::Tensor* input) { +void MatrixBitCodeFunctor::MulGradError(const framework::LoDTensor& tmat, + const framework::LoDTensor& weight, + framework::LoDTensor* input) { size_t num_samples = tmat.dims()[0]; size_t tmat_width = tmat.dims()[1]; size_t input_width = input->dims()[1]; @@ -174,7 +175,7 @@ void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, } template -void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { +void MatrixBitCodeFunctor::Sub(framework::LoDTensor* tmat) { size_t num_samples = tmat->dims()[0]; size_t o_width = tmat->dims()[1]; for (size_t i = 0; i < num_samples; ++i) { diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index 39c3b1520b..e4fe43ce98 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -134,8 +136,9 @@ class SimpleCode : public Code { template class CustomCode : public Code { public: - CustomCode(const framework::Tensor* ptable, const framework::Tensor* pcode, - const int64_t* ids, const int index) + CustomCode(const framework::LoDTensor* ptable, + const framework::LoDTensor* pcode, const int64_t* ids, + const int index) : ptable_(ptable), pcode_(pcode), ids_(ids), index_(index) {} /** * Here the id of root shoud be 1 rather than 0, thus the encoding of class c @@ -169,8 +172,8 @@ class CustomCode : public Code { } private: - const framework::Tensor* ptable_; - const framework::Tensor* pcode_; + const framework::LoDTensor* ptable_; + const framework::LoDTensor* pcode_; const int64_t* ids_; const int index_; }; @@ -194,8 +197,9 @@ class SimpleCodeTable : public CodeTable { template class CustomCodeTable : public CodeTable { public: - explicit CustomCodeTable(const framework::Tensor* ptable, - const framework::Tensor* pcode, const int64_t* ids) + explicit CustomCodeTable(const framework::LoDTensor* ptable, + const framework::LoDTensor* pcode, + const int64_t* ids) : ptable_(ptable), pcode_(pcode), ids_(ids) {} std::unique_ptr get_code(int64_t code) const { @@ -209,8 +213,8 @@ class CustomCodeTable : public CodeTable { } private: - const framework::Tensor* ptable_; - const framework::Tensor* pcode_; + const framework::LoDTensor* ptable_; + const framework::LoDTensor* pcode_; const int64_t* ids_; }; @@ -222,8 +226,8 @@ class MatrixBitCodeFunctor { ids_(ids), code_table(new SimpleCodeTable(num_classes, ids)) {} - explicit MatrixBitCodeFunctor(const framework::Tensor* ptable, - const framework::Tensor* pcode, + explicit MatrixBitCodeFunctor(const framework::LoDTensor* ptable, + const framework::LoDTensor* pcode, const int64_t* ids) : num_classes_(static_cast(ptable->dims()[1])), ids_(ids), @@ -231,38 +235,47 @@ class MatrixBitCodeFunctor { /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ - void Add(framework::Tensor* tmat, const framework::Tensor& vec); + void Add(framework::LoDTensor* tmat, const framework::LoDTensor& vec); /* For j < code_length vec(0, index(i, j)) += tmat(i, j) */ - void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec); + void AddGrad(const framework::LoDTensor& tmat, 
framework::LoDTensor* vec); /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) */ - void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum); + void Sum(const framework::LoDTensor& tmat, framework::LoDTensor* sum, + T scale_sum); /* For j < code_length tmat(i, j) -= bit(i, j) */ - void Sub(framework::Tensor* tmat); + void Sub(framework::LoDTensor* tmat); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ - void Mul(framework::Tensor* tmat, const framework::Tensor& weight, - const framework::Tensor& input); + void Mul(framework::LoDTensor* tmat, const framework::LoDTensor& weight, + const framework::LoDTensor& input); /* For index(i, j) >= 0: weight.row(index(i, j)) += tmat(i, j) * input.row(i) */ - void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, - const framework::Tensor& input); + void MulGradWeight(const framework::LoDTensor& tmat, + framework::LoDTensor* weight, + const framework::LoDTensor& input); + /* For SelectedRows Weight, For index(i, j) >= 0: + weight.row(index(i, j)) += tmat(i, j) * input.row(i) + */ + void MulGradWeight(const framework::LoDTensor& tmat, + framework::SelectedRows* weight, + const framework::LoDTensor& input); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ - void MulGradError(const framework::Tensor& tmat, - const framework::Tensor& weight, framework::Tensor* input); + void MulGradError(const framework::LoDTensor& tmat, + const framework::LoDTensor& weight, + framework::LoDTensor* input); size_t num_classes_; const int64_t* ids_; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4472f20409..7c92bdd882 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4355,7 +4355,8 @@ def hsigmoid(input, param_attr=None, bias_attr=None, name=None, - is_costum=False): + is_costum=False, + is_sparse=False): """ The hierarchical sigmoid operator is used to accelerate the training process of language model. This operator organizes the classes into a @@ -4394,9 +4395,11 @@ def hsigmoid(input, is not set, the bias is initialized zero. Default: None. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None. + is_costum: (bool|False)using user defined binary tree instead of default complete binary tree + is_sparse: (bool|False)using sparse update instead of dense update Returns: - Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1] + Out: (LodTensor) The cost of hierarchical sigmoid operator. 
the shape is [N, 1] Examples: @@ -4466,7 +4469,8 @@ def hsigmoid(input, inputs=inputs, outputs={"Out": out, "PreOut": pre_out}, - attrs={"num_classes": num_classes}) + attrs={"num_classes": num_classes, + "is_sparse": is_sparse}) return out diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 6152b96912..50dfaee76f 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -16,10 +16,9 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid.core as core +import paddle.fluid as fluid import math -# import paddle.fluid as fluid -# import paddle.fluid.core as core -# from op_builder import OpBuilder from op_test import OpTest np.random.seed(100) @@ -141,67 +140,148 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): return pre_output, out -class TestHSigmoidOp(OpTest): - def setUp(self): - self.op_type = "hierarchical_sigmoid" - num_classes = 6 - feature_size = 8 - batch_size = 4 - x = np.random.random((batch_size, feature_size)).astype("float32") * 2 - w = np.random.random( - (num_classes - 1, feature_size)).astype("float32") * 2 - label = np.random.randint(0, num_classes, (batch_size, 1)) - bias = np.random.random((1, num_classes - 1)).astype("float32") - self.attrs = {'num_classes': num_classes} - self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} - pre_output, out = hsigmoid(x, w, label, bias, num_classes) - self.outputs = {'PreOut': pre_output, 'Out': out} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) - - -class TestHSigmoidOpWithCostumTree(OpTest): - def setUp(self): - self.op_type = "hierarchical_sigmoid" - num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample - feature_size = 8 - batch_size = 4 - x = np.random.random((batch_size, feature_size)).astype("float32") * 2 - w = np.random.random( - (num_classes - 1, feature_size)).astype("float32") * 2 - label = np.array([0, 1, 4, 5]) - ptable = np.array( - [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), - (0, 2, -1, -1, - -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) - pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( - 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store - bias = np.random.random((1, num_classes - 1)).astype("float32") - self.attrs = {'num_classes': num_classes} - self.inputs = { - 'X': x, - 'W': w, - 'PTable': ptable, - 'PCode': pcode, - 'Label': label, - 'Bias': bias - } - pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, - bias, num_classes) - self.outputs = {'PreOut': pre_output, 'Out': out} - - def test_check_output(self): - print("checking output in CostumTree") - self.check_output() - - def test_check_grad(self): - print("checking outputGrad in CostumTree") - self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) +# class TestHSigmoidOp(OpTest): +# def setUp(self): +# self.op_type = "hierarchical_sigmoid" +# num_classes = 6 +# feature_size = 8 +# batch_size = 4 +# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 +# w = np.random.random( +# (num_classes - 1, feature_size)).astype("float32") * 2 +# label = np.random.randint(0, num_classes, (batch_size, 1)) +# bias = np.random.random((1, num_classes - 1)).astype("float32") +# self.attrs = 
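A hedged usage sketch of the layer described in the docstring above. The keyword names non_leaf_num, ptable, pcode and is_costum follow the test code later in this series and may still change before the API is finalized.

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[8], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# default complete binary tree, dense gradients
cost = fluid.layers.hsigmoid(input=x, label=label, num_classes=6)

# custom tree with sparse (SelectedRows) gradients; ptable/pcode are int64
# [N, L] tensors as described by the op maker in this series
ptable = fluid.layers.data(name='ptable', shape=[5], dtype='int64')
pcode = fluid.layers.data(name='pcode', shape=[5], dtype='int64')
cost_custom = fluid.layers.hsigmoid(
    input=x, label=label, non_leaf_num=5, ptable=ptable, pcode=pcode,
    is_costum=True, is_sparse=True)
```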
{'num_classes': num_classes, 'is_sparse': False} +# self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} +# pre_output, out = hsigmoid(x, w, label, bias, num_classes) +# self.outputs = {'PreOut': pre_output, 'Out': out} +# def test_check_output(self): +# self.check_output() + +# def test_check_grad(self): +# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + +# class TestHSigmoidOpSparse(OpTest): +# def setUp(self): +# self.op_type = "hierarchical_sigmoid" +# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample +# feature_size = 8 +# batch_size = 4 +# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 +# w = np.random.random( +# (num_classes - 1, feature_size)).astype("float32") * 2 +# label = np.array([0, 1, 4, 5]) +# ptable = np.array( +# [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), +# (0, 2, -1, -1, +# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) +# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( +# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store +# bias = np.random.random((1, num_classes - 1)).astype("float32") +# self.attrs = {'num_classes': num_classes, 'is_sparse': True} +# self.inputs = { +# 'X': x, +# 'W': w, +# 'PTable': ptable, +# 'PCode': pcode, +# 'Label': label, +# 'Bias': bias +# } +# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, +# bias, num_classes) +# self.outputs = {'PreOut': pre_output, 'Out': out} + +# def test_check_output(self): +# print("checking output in CostumTree") +# self.check_output() + + +class TestHSigmoidOpWithSparseGrad(): + def hs_net_conf(self): + emb = fluid.layers.data(name="x", shape=[3], dtype='int64') + ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64') + pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + data_list = [emb, ptable, pcode, label] + cost = fluid.layers.hsigmoid( + input=emb, + label=predict_word, + non_leaf_num=4, + ptable=ptable, + pcode=pcode, + is_costum=True, + is_sparse=True) + + avg_cost = fluid.layers.reduce_mean(cost) + + return avg_cost, data_list + + def test_training_test(self): + print("im here") + w = np.arange(12).reshape(4, 3) + x = np.ones((2, 3)) + ptable = np.array([(1, 2, -1), (1, 2, -1)]) + pcode = np.array([(1, 0, -1), (0, 0, -1)]) + label = np.array([(1, 4)]) + + loss, data_list = hs_net_conf() + optimizer = fluid.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + + main_program = fluid.default_main_program() + + place = fluid.CPUPlace() + feeder = fluid.DataFeeder(feed_list=data_list, place=place) + data_name_list = [var.name for var in data_list] + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + for pass_id in range(args.num_passes): + for i in range(10): + data = [w, x[i % 2], ptable[i % 2], pcode[i % 2], label[i % 2]] + loss_val = exe.run(main_program, + feed=feeder.feed(data), + fetch_list=[loss]) + print("loss is: {loss}".format(loss=loss)) + + +# class TestHSigmoidOpWithCostumTree(OpTest): +# def setUp(self): +# self.op_type = "hierarchical_sigmoid" +# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample +# feature_size = 8 +# batch_size = 4 +# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 +# w = np.random.random( +# (num_classes - 1, feature_size)).astype("float32") * 2 +# label = np.array([0, 1, 4, 5]) +# ptable = np.array( +# [(0, 2, -1, -1, -1), (0, 
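For reference, a small sketch (an assumption based on the CodeTableWithCustomTree helper in this test file) of how one row of PTable/PCode is decoded: entry j gives the non-leaf node visited at depth j and the branch bit chosen there, with -1 padding marking the end of the path.

```python
import numpy as np

ptable = np.array([0, 2, -1, -1, -1])   # non-leaf nodes visited, -1 padded
pcode = np.array([0, 0, -1, -1, -1])    # branch (0/1) taken at each node

path = [(int(n), int(b)) for n, b in zip(ptable, pcode) if n >= 0]
print(path)   # [(0, 0), (2, 0)] -> two sigmoid terms for this sample
```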
1, 3, -1, -1), (0, 1, 4, -1, -1), +# (0, 2, -1, -1, +# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) +# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( +# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store +# bias = np.random.random((1, num_classes - 1)).astype("float32") +# self.attrs = {'num_classes': num_classes, 'is_sparse': False} +# self.inputs = { +# 'X': x, +# 'W': w, +# 'PTable': ptable, +# 'PCode': pcode, +# 'Label': label, +# 'Bias': bias +# } +# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, +# bias, num_classes) +# self.outputs = {'PreOut': pre_output, 'Out': out} + +# def test_check_output(self): +# print("checking output in CostumTree") +# self.check_output() + +# def test_check_grad(self): +# print("checking outputGrad in CostumTree") +# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) if __name__ == '__main__': unittest.main() From af9a3301dab9ab291d3cdd278734ae129de8a0f0 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Wed, 21 Nov 2018 12:35:21 +0000 Subject: [PATCH 11/23] test=develop --- paddle/fluid/framework/selected_rows.h | 6 +- .../operators/hierarchical_sigmoid_op.cc | 5 +- .../fluid/operators/hierarchical_sigmoid_op.h | 2 +- .../fluid/tests/unittests/test_hsigmoid_op.py | 269 ++++++++++-------- 4 files changed, 152 insertions(+), 130 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 4d728ae54a..9d87c3eac7 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -121,7 +121,9 @@ class SelectedRows { int64_t AutoGrownIndex(int64_t key, bool auto_grown); void SyncIndex(); - + /* + * @brief Get complete Dims before + */ DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); dims[0] = height_; @@ -136,7 +138,7 @@ class SelectedRows { std::unordered_map id_to_index_; // should not be used when ids has duplicate member std::unique_ptr value_{nullptr}; - int64_t height_; + int64_t height_; // height indicates the underline tensor's height std::unique_ptr rwlock_{nullptr}; }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index b2f4616441..c350e6489d 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -145,8 +145,9 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("PreOut"), "Input(Preout) should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), - "Output(W@Grad should not be null.)"); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); + "Output(W@Grad should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad should not be null."); if (ctx->HasOutput(framework::GradVarName("Bias"))) { ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 3e2fbafa26..35a1de3e19 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -191,10 +191,10 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { framework::Vector real_rows = cal_rows(path); auto* w_grad = ctx.Output(framework::GradVarName("W")); - w_grad->set_rows(real_rows); // build ids -> rows index map w_grad->SyncIndex(); + 
w_grad->set_height(w->dims()[0]); auto* w_grad_value = w_grad->mutable_value(); framework::DDim temp_dim(w->dims()); set(temp_dim, 0, real_rows.size()); diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 50dfaee76f..2f4225f912 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -140,148 +140,167 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): return pre_output, out -# class TestHSigmoidOp(OpTest): -# def setUp(self): -# self.op_type = "hierarchical_sigmoid" -# num_classes = 6 -# feature_size = 8 -# batch_size = 4 -# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 -# w = np.random.random( -# (num_classes - 1, feature_size)).astype("float32") * 2 -# label = np.random.randint(0, num_classes, (batch_size, 1)) -# bias = np.random.random((1, num_classes - 1)).astype("float32") -# self.attrs = {'num_classes': num_classes, 'is_sparse': False} -# self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} -# pre_output, out = hsigmoid(x, w, label, bias, num_classes) -# self.outputs = {'PreOut': pre_output, 'Out': out} - -# def test_check_output(self): -# self.check_output() - -# def test_check_grad(self): -# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) - -# class TestHSigmoidOpSparse(OpTest): -# def setUp(self): -# self.op_type = "hierarchical_sigmoid" -# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample -# feature_size = 8 -# batch_size = 4 -# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 -# w = np.random.random( -# (num_classes - 1, feature_size)).astype("float32") * 2 -# label = np.array([0, 1, 4, 5]) -# ptable = np.array( -# [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), -# (0, 2, -1, -1, -# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( -# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store -# bias = np.random.random((1, num_classes - 1)).astype("float32") -# self.attrs = {'num_classes': num_classes, 'is_sparse': True} -# self.inputs = { -# 'X': x, -# 'W': w, -# 'PTable': ptable, -# 'PCode': pcode, -# 'Label': label, -# 'Bias': bias -# } -# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, -# bias, num_classes) -# self.outputs = {'PreOut': pre_output, 'Out': out} - -# def test_check_output(self): -# print("checking output in CostumTree") -# self.check_output() - - -class TestHSigmoidOpWithSparseGrad(): - def hs_net_conf(self): - emb = fluid.layers.data(name="x", shape=[3], dtype='int64') +class TestHSigmoidOp(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.random.randint(0, num_classes, (batch_size, 1)) + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} + self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} + pre_output, out = hsigmoid(x, w, label, bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Bias', 'X', 'W'], ['Out'], 
no_grad_set=set('Label')) + + +class TestHSigmoidOpSparse(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") + w = np.random.random((num_classes - 1, feature_size)).astype("float32") + label = np.array([0, 1, 4, 5]) + ptable = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': True} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': ptable, + 'PCode': pcode, + 'Label': label, + 'Bias': bias + } + pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, + bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + print("checking output in CostumTree") + self.check_output() + + +class TestHSigmoidOpWithSparseGrad(unittest.TestCase): + def hs_net_conf(self, is_sparse): + input_word = fluid.layers.data(name="x", shape=[1], dtype='int64') ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64') pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64') - data_list = [emb, ptable, pcode, label] + + data_list = [input_word, ptable, pcode, label] + + emb = fluid.layers.embedding( + input=input_word, + is_sparse=False, + size=[3, 3], + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal( + scale=1 / math.sqrt(3)))) + cost = fluid.layers.hsigmoid( input=emb, - label=predict_word, - non_leaf_num=4, + label=label, + non_leaf_num=3, ptable=ptable, pcode=pcode, is_costum=True, - is_sparse=True) + is_sparse=is_sparse) avg_cost = fluid.layers.reduce_mean(cost) return avg_cost, data_list - def test_training_test(self): - print("im here") - w = np.arange(12).reshape(4, 3) - x = np.ones((2, 3)) - ptable = np.array([(1, 2, -1), (1, 2, -1)]) - pcode = np.array([(1, 0, -1), (0, 0, -1)]) - label = np.array([(1, 4)]) - - loss, data_list = hs_net_conf() - optimizer = fluid.optimizer.SGD(learning_rate=1e-3) - optimizer.minimize(loss) - - main_program = fluid.default_main_program() - - place = fluid.CPUPlace() - feeder = fluid.DataFeeder(feed_list=data_list, place=place) - data_name_list = [var.name for var in data_list] - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - for pass_id in range(args.num_passes): + def training_test(self, is_sparse): + with fluid.program_guard(fluid.Program(), fluid.Program()): + start_up = fluid.default_startup_program() + start_up.random_seed = 1 # Fix random seed + x = np.arange(6).reshape(6) + ptable = np.array([(1, 2, -1), (1, 2, -1)]) + pcode = np.array([(1, 0, -1), (0, 0, -1)]) + label = np.array([1, 4]) + + loss, data_list = self.hs_net_conf(is_sparse) + optimizer = fluid.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + + main_program = fluid.default_main_program() + # print("main program: {program}".format{program=str(main_program)}) + place = fluid.CPUPlace() + feeder = fluid.DataFeeder(feed_list=data_list, place=place) + exe = fluid.Executor(place) + + exe.run(start_up) + result = list() for i in range(10): - data = [w, x[i % 2], ptable[i % 
2], pcode[i % 2], label[i % 2]] + data = [([[x[i % 2]]], [list(ptable[i % 2])], + [list(pcode[i % 2])], [label[i % 2]])] + loss_val = exe.run(main_program, feed=feeder.feed(data), fetch_list=[loss]) - print("loss is: {loss}".format(loss=loss)) - - -# class TestHSigmoidOpWithCostumTree(OpTest): -# def setUp(self): -# self.op_type = "hierarchical_sigmoid" -# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample -# feature_size = 8 -# batch_size = 4 -# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 -# w = np.random.random( -# (num_classes - 1, feature_size)).astype("float32") * 2 -# label = np.array([0, 1, 4, 5]) -# ptable = np.array( -# [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), -# (0, 2, -1, -1, -# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( -# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store -# bias = np.random.random((1, num_classes - 1)).astype("float32") -# self.attrs = {'num_classes': num_classes, 'is_sparse': False} -# self.inputs = { -# 'X': x, -# 'W': w, -# 'PTable': ptable, -# 'PCode': pcode, -# 'Label': label, -# 'Bias': bias -# } -# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, -# bias, num_classes) -# self.outputs = {'PreOut': pre_output, 'Out': out} - -# def test_check_output(self): -# print("checking output in CostumTree") -# self.check_output() - -# def test_check_grad(self): -# print("checking outputGrad in CostumTree") -# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + result.append(loss_val) + return result + + def test_hs_grad_with_sparse(self): + dense_result = self.training_test(is_sparse=False) + sparse_result = self.training_test(is_sparse=True) + assert (dense_result == sparse_result) + + +class TestHSigmoidOpWithCostumTree(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.array([0, 1, 4, 5]) + ptable = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': ptable, + 'PCode': pcode, + 'Label': label, + 'Bias': bias + } + pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, + bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + print("checking output in CostumTree") + self.check_output() + + def test_check_grad(self): + print("checking outputGrad in CostumTree") + self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + if __name__ == '__main__': unittest.main() From 57a18e32a18232b65920a8ecb0ea014453bbdf7a Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 22 Nov 2018 04:26:13 +0000 Subject: [PATCH 12/23] test=develop --- paddle/fluid/operators/hierarchical_sigmoid_op.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git 
a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 35a1de3e19..418fe86f69 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -43,9 +43,7 @@ std::vector cal_rows(const framework::LoDTensor* path) { } } } - for (std::set::iterator it = tmp.begin(); it != tmp.end(); ++it) { - rows.push_back(*it); - } + rows.assign(tmp.begin(), tmp.end()); return rows; } From e9be3366a9cde661293e92306b036aea0ee772c1 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Fri, 23 Nov 2018 02:49:06 +0000 Subject: [PATCH 13/23] test=develop --- paddle/fluid/operators/hierarchical_sigmoid_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 418fe86f69..b4a5fe8309 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -164,7 +164,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); auto out_grad_mat = EigenMatrix::From(*out_grad); - Eigen::array bcast({{1, static_cast(pre_out_grad.dims()[1])}}); + Eigen::array bcast{1, static_cast(pre_out_grad.dims()[1])}; // softrelu derivative pre_out_grad_mat.device(place) = From 0fca16847c89d1018c32da0e7bbc0b6396d5e104 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Fri, 23 Nov 2018 02:52:35 +0000 Subject: [PATCH 14/23] temp --- paddle/fluid/operators/math/matrix_bit_code.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 8baffe1ba1..2967586949 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -102,6 +102,8 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; size_t weight_width = weight->dims()[1]; + VLOG(30) << "sparse w_grad dims is [" << weight->dims()[0] << " ," + << weight->dims()[1] << " ]"; auto tmat_value = tmat.data(); auto weight_value = weight->data(); auto input_value = input.data(); @@ -127,6 +129,8 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; size_t weight_width = weight->value().dims()[1]; + VLOG(30) << "sparse w_grad dims is: [" << weight->value().dims()[0] << " ," + << weight->value().dims()[1] << " ]"; auto tmat_value = tmat.data(); auto weight_value = weight->mutable_value()->data(); auto input_value = input.data(); From 42470f14b77e71a53c25cf318c69c4ca019bb593 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Fri, 23 Nov 2018 06:43:42 +0000 Subject: [PATCH 15/23] test=develop --- paddle/fluid/framework/selected_rows.cc | 52 ------------------- paddle/fluid/framework/selected_rows.h | 50 +++++++++++++++++- .../fluid/operators/math/matrix_bit_code.cc | 2 +- 3 files changed, 50 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index f4f2b769d5..7262f8cc05 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -140,58 +140,6 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } -int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown, - bool is_test) { - if (is_test) { - auto 
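A NumPy sketch (not from the patch) of the backward step that the Eigen bcast change above expresses: PreOut holds softrelu(z) = log(1 + exp(z)), so 1 - exp(-PreOut) recovers sigmoid(z), and Out@Grad of shape [N, 1] is broadcast over the code_length columns.

```python
import numpy as np

z = np.random.randn(4, 5)                  # clipped pre-activations
pre_out = np.log(1.0 + np.exp(z))          # what the forward stores in PreOut
out_grad = np.random.randn(4, 1)           # Out@Grad, shape [N, 1]

pre_out_grad = (1.0 - np.exp(-pre_out)) * out_grad   # broadcast over columns
assert np.allclose(1.0 - np.exp(-pre_out), 1.0 / (1.0 + np.exp(-z)))
print(pre_out_grad.shape)                  # (4, 5)
```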
iter = id_to_index_.find(key); - if (iter == id_to_index_.end()) { - return -1; - } else { - return iter->second; - } - } - - rwlock_->RDLock(); - auto iter = id_to_index_.find(key); - if (iter == id_to_index_.end()) { - rwlock_->UNLock(); - if (!auto_grown) { - PADDLE_THROW("key %d not found", key); - } - rwlock_->WRLock(); - auto map_size = id_to_index_.size(); - auto vector_size = rows_.size(); - if (map_size != vector_size) { - rwlock_->UNLock(); - PADDLE_THROW( - "id_to_index_ size %d should have the same size with rows_ %d", - map_size, vector_size); - } - auto write_iter = id_to_index_.find(key); - if (write_iter == id_to_index_.end()) { - int row_num = rows_.size(); - if (row_num == value_->dims()[0]) { - rwlock_->UNLock(); - PADDLE_THROW("selected rows is full, then length exceed %d", row_num); - } - // key logic to put a key into id_to_index_ - rows_.push_back(key); - auto index = static_cast(rows_.size() - 1); - id_to_index_[key] = index; - rwlock_->UNLock(); - return index; - } else { - auto index = write_iter->second; - rwlock_->UNLock(); - return index; - } - } else { - auto index = iter->second; - rwlock_->UNLock(); - return index; - } -} - void SelectedRows::SyncIndex() { rwlock_->WRLock(); id_to_index_.clear(); diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index d3e0f2168b..6c31dada68 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -118,7 +118,55 @@ class SelectedRows { * * @return index of the key. */ - int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false); + int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false) { + if (is_test) { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + if (!auto_grown) { + PADDLE_THROW("key %d not found", key); + } + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW( + "id_to_index_ size %d should have the same size with rows_ %d", + map_size, vector_size); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + int row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW("selected rows is full, then length exceed %d", row_num); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } + } void SyncIndex(); /* diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 2967586949..9a0cf8701f 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -142,7 +142,7 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, for (size_t k = 0; k < input_width; ++k) { int64_t row_index = - weight->AutoGrownIndex(static_cast(index), false); + weight->AutoGrownIndex(static_cast(index), false, true); weight_value[row_index * weight_width + k] += tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; From 
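A minimal Python sketch (an assumption, not the framework code) of the is_test=True path added to AutoGrownIndex above: a read-only lookup that returns the stored row index for a known id and -1 otherwise, without taking locks or growing rows_, which is what the bit-code gradient kernels rely on.

```python
# read-only lookup: known id -> row index, unknown id -> -1
id_to_index = {0: 0, 1: 1, 2: 2, 3: 3}

def auto_grown_index_test(key):
    return id_to_index.get(key, -1)

print(auto_grown_index_test(2), auto_grown_index_test(7))   # 2 -1
```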
02d68051db17e43f7b0c6785fa9f31384263a741 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 26 Nov 2018 03:09:12 +0000 Subject: [PATCH 16/23] add sparsed bias grad, test=develop --- .../operators/hierarchical_sigmoid_op.cc | 32 +++++++++++++------ .../fluid/operators/hierarchical_sigmoid_op.h | 31 ++++++++++++++---- .../fluid/operators/math/matrix_bit_code.cc | 18 +++++++++++ paddle/fluid/operators/math/matrix_bit_code.h | 5 +++ python/paddle/fluid/layers/nn.py | 4 +-- .../fluid/tests/unittests/test_hsigmoid_op.py | 17 ++++------ 6 files changed, 78 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index c350e6489d..042d90e72f 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -107,8 +107,9 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { "it should have shape like [N, L], L is the length of the Path") .AsDispensable(); AddInput("Bias", - "(LoDTensor, optional), The bias is a tensor with shape" - "[1, num_classes - 1]."); + "(LoDTensor, optional), The bias is a tensor with shape or " + "[non_leaf_num, 1]" + "[num_classes - 1, 1]."); AddOutput( "Out", "(LoDTensor, required) The output of hierarchical sigmoid operator." @@ -148,11 +149,11 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { "Output(W@Grad should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "Output(X@Grad should not be null."); - if (ctx->HasOutput(framework::GradVarName("Bias"))) { - ctx->SetOutputDim(framework::GradVarName("Bias"), - ctx->GetInputDim("Bias")); - } if (!ctx->Attrs().Get("is_sparse")) { + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Bias")); + } ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); } ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); @@ -172,20 +173,31 @@ class HierarchicalSigmoidGradOpGradVarTypeInference public: void operator()(const framework::OpDesc& op_desc, framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto out_W_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto out_Bias_var_name = + op_desc.Output(framework::GradVarName("Bias")).front(); auto attr = op_desc.GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") << " is set to SelectedRows"; - block->Var(out_var_name) + block->Var(out_W_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + VLOG(3) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to SelectedRows"; + block->Var(out_Bias_var_name) ->SetType(framework::proto::VarType::SELECTED_ROWS); } else { VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") << " is set to LoDTensor"; - block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); + block->Var(out_W_var_name) + ->SetType(framework::proto::VarType::LOD_TENSOR); + VLOG(3) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to SelectedRows"; + block->Var(out_Bias_var_name) + ->SetType(framework::proto::VarType::LOD_TENSOR); } - block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); + 
block->Var(out_W_var_name)->SetDataType(block->Var("W")->GetDataType()); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index b4a5fe8309..44853dafe9 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -124,13 +124,12 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto* w = ctx.Input("W"); auto* path = ctx.Input("PTable"); auto* code = ctx.Input("PCode"); + auto* bias = ctx.Input("Bias"); auto* in_grad = ctx.Output(framework::GradVarName("X")); bool is_sparse = ctx.Attr("is_sparse"); auto& dev_ctx = ctx.template device_context(); math::SetConstant zero; - auto* bias_grad = - ctx.Output(framework::GradVarName("Bias")); auto* label = ctx.Input("Label"); auto* pre_out = ctx.Input("PreOut"); auto* out_grad = @@ -174,12 +173,15 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { pre_out_grad_mat * out_grad_mat.broadcast(bcast); // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // be consistent with the clipping in forward. - if (bias_grad) { - bias_grad->mutable_data(ctx.GetPlace()); - zero(dev_ctx, bias_grad, static_cast(0.0)); - bit_code->AddGrad(pre_out_grad, bias_grad); - } + if (!is_sparse) { + auto* bias_grad = + ctx.Output(framework::GradVarName("Bias")); + if (bias_grad) { + bias_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, bias_grad, static_cast(0.0)); + bit_code->AddGrad(pre_out_grad, bias_grad); + } auto* w_grad = ctx.Output(framework::GradVarName("W")); w_grad->mutable_data(ctx.GetPlace()); @@ -199,6 +201,21 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { w_grad_value->mutable_data(temp_dim, ctx.GetPlace()); zero(dev_ctx, w_grad_value, static_cast(0.0)); + auto* bias_grad = + ctx.Output(framework::GradVarName("Bias")); + if (bias_grad) { + bias_grad->set_rows(real_rows); + // build ids -> rows index map + bias_grad->SyncIndex(); + bias_grad->set_height(bias->dims()[0]); + auto* bias_grad_value = bias_grad->mutable_value(); + std::vector dims = {static_cast(real_rows.size()), + bias->dims()[1]}; + bias_grad_value->mutable_data(framework::make_ddim(dims), + ctx.GetPlace()); + zero(dev_ctx, bias_grad_value, static_cast(0.0)); + bit_code->AddGrad(pre_out_grad, bias_grad); + } bit_code->MulGradWeight(pre_out_grad, w_grad, *in); } bit_code->MulGradError(pre_out_grad, *w, in_grad); diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 9a0cf8701f..0c1aa29a18 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -48,6 +48,24 @@ void MatrixBitCodeFunctor::AddGrad(const framework::LoDTensor& tmat, } } +template +void MatrixBitCodeFunctor::AddGrad(const framework::LoDTensor& tmat, + framework::SelectedRows* vec) { + size_t batch_size = tmat.dims()[0]; + size_t width = tmat.dims()[1]; + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table->get_code(i); + int code_length = code->get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code->calc_index(j); + int64_t row_index = + vec->AutoGrownIndex(static_cast(index), false, true); + vec->mutable_value()->data()[row_index] += + tmat.data()[i * width + j]; + } + } +} + template void MatrixBitCodeFunctor::Sum(const framework::LoDTensor& tmat, framework::LoDTensor* sum, T scale_sum) { diff --git a/paddle/fluid/operators/math/matrix_bit_code.h 
b/paddle/fluid/operators/math/matrix_bit_code.h index c8d21ba686..673fcb65c8 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -241,6 +241,11 @@ class MatrixBitCodeFunctor { */ void AddGrad(const framework::LoDTensor& tmat, framework::LoDTensor* vec); + /* For selected rows For j < code_length + vec(0, index(i, j)) += tmat(i, j) + */ + void AddGrad(const framework::LoDTensor& tmat, framework::SelectedRows* vec); + /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) */ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b02d75e55b..8170ccf082 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4639,14 +4639,14 @@ def hsigmoid(input, if not is_costum: bias = helper.create_parameter( attr=helper.bias_attr, - shape=[1, num_classes - 1], + shape=[num_classes - 1, 1], is_bias=True, dtype=input.dtype) inputs['Bias'] = bias else: bias = helper.create_parameter( attr=helper.bias_attr, - shape=[1, non_leaf_num], + shape=[non_leaf_num, 1], is_bias=True, dtype=input.dtype) inputs['Bias'] = bias diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 2f4225f912..a3024dded6 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -77,7 +77,7 @@ def hsigmoid(x, w, label, bias, num_classes): length = code_table.get_length() for j in range(length): idx = code_table.cal_index(j) - pre_output[i][j] += bias[0][idx] + pre_output[i][j] += bias[idx][0] for i in range(batch_size): code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() @@ -115,7 +115,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): length = code_table.get_length() for j in range(length): idx = code_table.cal_index(j) - pre_output[i][j] += bias[0][idx] + pre_output[i][j] += bias[idx][0] for i in range(batch_size): code_table = CodeTableWithCustomTree(ptable, pcode, i) length = code_table.get_length() @@ -150,7 +150,7 @@ class TestHSigmoidOp(OpTest): w = np.random.random( (num_classes - 1, feature_size)).astype("float32") * 2 label = np.random.randint(0, num_classes, (batch_size, 1)) - bias = np.random.random((1, num_classes - 1)).astype("float32") + bias = np.random.random((num_classes - 1, 1)).astype("float32") self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} pre_output, out = hsigmoid(x, w, label, bias, num_classes) @@ -178,7 +178,7 @@ class TestHSigmoidOpSparse(OpTest): -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store - bias = np.random.random((1, num_classes - 1)).astype("float32") + bias = np.random.random((num_classes - 1, 1)).astype("float32") self.attrs = {'num_classes': num_classes, 'is_sparse': True} self.inputs = { 'X': x, @@ -193,7 +193,6 @@ class TestHSigmoidOpSparse(OpTest): self.outputs = {'PreOut': pre_output, 'Out': out} def test_check_output(self): - print("checking output in CostumTree") self.check_output() @@ -208,7 +207,7 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): emb = fluid.layers.embedding( input=input_word, - is_sparse=False, + is_sparse=is_sparse, size=[3, 3], param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal( scale=1 / math.sqrt(3)))) @@ 
-220,6 +219,7 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): ptable=ptable, pcode=pcode, is_costum=True, + bias_attr=True, is_sparse=is_sparse) avg_cost = fluid.layers.reduce_mean(cost) @@ -240,7 +240,6 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): optimizer.minimize(loss) main_program = fluid.default_main_program() - # print("main program: {program}".format{program=str(main_program)}) place = fluid.CPUPlace() feeder = fluid.DataFeeder(feed_list=data_list, place=place) exe = fluid.Executor(place) @@ -279,7 +278,7 @@ class TestHSigmoidOpWithCostumTree(OpTest): -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store - bias = np.random.random((1, num_classes - 1)).astype("float32") + bias = np.random.random((num_classes - 1, 1)).astype("float32") self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.inputs = { 'X': x, @@ -294,11 +293,9 @@ class TestHSigmoidOpWithCostumTree(OpTest): self.outputs = {'PreOut': pre_output, 'Out': out} def test_check_output(self): - print("checking output in CostumTree") self.check_output() def test_check_grad(self): - print("checking outputGrad in CostumTree") self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) From 2f6b529aff0f5a90dec89a05a5bf5ce6bc40555d Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 26 Nov 2018 12:20:15 +0000 Subject: [PATCH 17/23] refine code and comments, test=develop --- paddle/fluid/operators/hierarchical_sigmoid_op.cc | 2 +- paddle/fluid/operators/math/matrix_bit_code.cc | 6 ------ python/paddle/fluid/layers/nn.py | 3 ++- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 042d90e72f..6d1fb29236 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -193,7 +193,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference block->Var(out_W_var_name) ->SetType(framework::proto::VarType::LOD_TENSOR); VLOG(3) << "hierarchical_sigmoid_grad op " - << framework::GradVarName("Bias") << " is set to SelectedRows"; + << framework::GradVarName("Bias") << " is set to LoDTensor"; block->Var(out_Bias_var_name) ->SetType(framework::proto::VarType::LOD_TENSOR); } diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 0c1aa29a18..e283320bcc 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -120,8 +120,6 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; size_t weight_width = weight->dims()[1]; - VLOG(30) << "sparse w_grad dims is [" << weight->dims()[0] << " ," - << weight->dims()[1] << " ]"; auto tmat_value = tmat.data(); auto weight_value = weight->data(); auto input_value = input.data(); @@ -147,8 +145,6 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; size_t weight_width = weight->value().dims()[1]; - VLOG(30) << "sparse w_grad dims is: [" << weight->value().dims()[0] << " ," - << weight->value().dims()[1] << " ]"; auto tmat_value = tmat.data(); auto weight_value = weight->mutable_value()->data(); auto input_value = input.data(); @@ -157,11 +153,9 @@ void 
MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); - for (size_t k = 0; k < input_width; ++k) { int64_t row_index = weight->AutoGrownIndex(static_cast(index), false, true); - weight_value[row_index * weight_width + k] += tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8170ccf082..e98989f5bd 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4581,7 +4581,8 @@ def hsigmoid(input, is not set, the bias is initialized zero. Default: None. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None. - is_costum: (bool|False)using user defined binary tree instead of default complete binary tree + is_costum: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is + set you need to set ptable/pcode/non_leaf_num, otherwise num_classes should be set is_sparse: (bool|False)using sparse update instead of dense update Returns: From 81e145764d870ed1a408e3fec4fc0d4d17e1bbec Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 26 Nov 2018 12:58:52 +0000 Subject: [PATCH 18/23] refine code and comments, test=develop --- paddle/fluid/operators/hierarchical_sigmoid_op.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 44853dafe9..f046fba7fc 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -30,14 +30,14 @@ template ; using platform::Transform; -std::vector cal_rows(const framework::LoDTensor* path) { +std::vector cal_rows(const framework::LoDTensor& path) { std::set tmp; std::vector rows; rows.clear(); - for (size_t i = 0; i < static_cast(path->dims()[0]); i++) { - for (size_t j = 0; j < static_cast(path->dims()[1]); j++) { + for (size_t i = 0; i < static_cast(path.dims()[0]); i++) { + for (size_t j = 0; j < static_cast(path.dims()[1]); j++) { int64_t temp = - path->data()[i * static_cast(path->dims()[1]) + j]; + path.data()[i * static_cast(path.dims()[1]) + j]; if (temp >= 0) { tmp.insert(temp); } @@ -188,7 +188,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { zero(dev_ctx, w_grad, static_cast(0.0)); bit_code->MulGradWeight(pre_out_grad, w_grad, *in); } else { - framework::Vector real_rows = cal_rows(path); + framework::Vector real_rows = cal_rows(*path); auto* w_grad = ctx.Output(framework::GradVarName("W")); w_grad->set_rows(real_rows); From b10df8bcfae243c76d56469cf0eb5a25f9b8a043 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 27 Nov 2018 04:37:48 +0000 Subject: [PATCH 19/23] refine code and add none bias ut, test=develop --- paddle/fluid/framework/selected_rows.h | 3 +- .../operators/hierarchical_sigmoid_op.cc | 52 ++++++++++------- .../fluid/operators/hierarchical_sigmoid_op.h | 9 +-- .../fluid/operators/math/matrix_bit_code.cc | 4 +- paddle/fluid/operators/math/matrix_bit_code.h | 2 +- python/paddle/fluid/layers/nn.py | 9 +-- .../fluid/tests/unittests/test_hsigmoid_op.py | 57 ++++++++++++++++--- 7 files changed, 94 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 6c31dada68..bc5726382f 100644 --- 
a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -118,7 +118,8 @@ class SelectedRows { * * @return index of the key. */ - int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false) { + inline int64_t AutoGrownIndex(int64_t key, bool auto_grown, + bool is_test = false) { if (is_test) { auto iter = id_to_index_.find(key); if (iter == id_to_index_.end()) { diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 6d1fb29236..f3329c4855 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/hierarchical_sigmoid_op.h" +#include #include - namespace paddle { namespace operators { @@ -109,7 +109,8 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Bias", "(LoDTensor, optional), The bias is a tensor with shape or " "[non_leaf_num, 1]" - "[num_classes - 1, 1]."); + "[num_classes - 1, 1].") + .AsDispensable(); AddOutput( "Out", "(LoDTensor, required) The output of hierarchical sigmoid operator." @@ -173,31 +174,42 @@ class HierarchicalSigmoidGradOpGradVarTypeInference public: void operator()(const framework::OpDesc& op_desc, framework::BlockDesc* block) const override { - auto out_W_var_name = op_desc.Output(framework::GradVarName("W")).front(); - auto out_Bias_var_name = - op_desc.Output(framework::GradVarName("Bias")).front(); + auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto bias_grad_var_name_vec = + op_desc.Output(framework::GradVarName("Bias")); + std::string bias_grad_var_name; + bool hasBias = false; + if (bias_grad_var_name_vec.size()) { + hasBias = true; + bias_grad_var_name = + op_desc.Output(framework::GradVarName("Bias")).front(); + } auto attr = op_desc.GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { - VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") - << " is set to SelectedRows"; - block->Var(out_W_var_name) - ->SetType(framework::proto::VarType::SELECTED_ROWS); - VLOG(3) << "hierarchical_sigmoid_grad op " - << framework::GradVarName("Bias") << " is set to SelectedRows"; - block->Var(out_Bias_var_name) + VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; + block->Var(w_grad_var_name) ->SetType(framework::proto::VarType::SELECTED_ROWS); + if (hasBias) { + VLOG(30) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to SelectedRows"; + block->Var(bias_grad_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + } } else { - VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") - << " is set to LoDTensor"; - block->Var(out_W_var_name) - ->SetType(framework::proto::VarType::LOD_TENSOR); - VLOG(3) << "hierarchical_sigmoid_grad op " - << framework::GradVarName("Bias") << " is set to LoDTensor"; - block->Var(out_Bias_var_name) + VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; + block->Var(w_grad_var_name) ->SetType(framework::proto::VarType::LOD_TENSOR); + if (hasBias) { + VLOG(30) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to LoDTensor"; + block->Var(bias_grad_var_name) + ->SetType(framework::proto::VarType::LOD_TENSOR); + } } 
- block->Var(out_W_var_name)->SetDataType(block->Var("W")->GetDataType()); + block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType()); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index f046fba7fc..de219bacdd 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -33,7 +33,6 @@ using platform::Transform; std::vector cal_rows(const framework::LoDTensor& path) { std::set tmp; std::vector rows; - rows.clear(); for (size_t i = 0; i < static_cast(path.dims()[0]); i++) { for (size_t j = 0; j < static_cast(path.dims()[1]); j++) { int64_t temp = @@ -63,8 +62,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { bool is_custom = false; if (path) { is_custom = true; - } else { - is_custom = false; } int64_t code_length = path ? path->dims()[1] : math::FindLastSet(num_classes - 1); @@ -96,7 +93,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto out_mat = framework::EigenVector::Flatten(*out); if (bias) { - bit_code->Add(pre_out, *bias); + bit_code->Add(*bias, pre_out); } bit_code->Mul(pre_out, *w, *in); // clip to [-40, 40] @@ -145,8 +142,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { bool is_custom = false; if (path) { is_custom = true; - } else { - is_custom = false; } std::unique_ptr> bit_code; @@ -192,7 +187,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto* w_grad = ctx.Output(framework::GradVarName("W")); w_grad->set_rows(real_rows); - // build ids -> rows index map + // Build a map of id -> row_index to speed up finding the index of one id w_grad->SyncIndex(); w_grad->set_height(w->dims()[0]); auto* w_grad_value = w_grad->mutable_value(); diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index e283320bcc..297e8d850b 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -19,8 +19,8 @@ namespace operators { namespace math { template -void MatrixBitCodeFunctor::Add(framework::LoDTensor* tmat, - const framework::LoDTensor& vec) { +void MatrixBitCodeFunctor::Add(const framework::LoDTensor& vec, + framework::LoDTensor* tmat) { size_t batch_size = tmat->dims()[0]; size_t width = tmat->dims()[1]; for (size_t i = 0; i < batch_size; ++i) { diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index 673fcb65c8..3add06cb63 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -234,7 +234,7 @@ class MatrixBitCodeFunctor { /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ - void Add(framework::LoDTensor* tmat, const framework::LoDTensor& vec); + void Add(const framework::LoDTensor& vec, framework::LoDTensor* tmat); /* For j < code_length vec(0, index(i, j)) += tmat(i, j) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e98989f5bd..7da3a9b4fb 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4535,12 +4535,12 @@ def nce(input, def hsigmoid(input, label, num_classes=None, - non_leaf_num=None, - ptable=None, - pcode=None, param_attr=None, bias_attr=None, name=None, + non_leaf_num=None, + ptable=None, + pcode=None, is_costum=False, is_sparse=False): """ @@ -4583,7 +4583,8 @@ def hsigmoid(input, will be named 
automatically. Default: None. is_costum: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is set you need to set ptable/pcode/non_leaf_num, otherwise num_classes should be set - is_sparse: (bool|False)using sparse update instead of dense update + is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient + of W and input will be sparse. Returns: Out: (LodTensor) The cost of hierarchical sigmoid operator. the shape is [N, 1] diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index a3024dded6..955fc51d57 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -110,12 +110,13 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): pre_output = np.zeros((batch_size, code_length)) pre_sum = np.zeros((batch_size, 1)) out = np.zeros((batch_size, 1)).astype("float32") - for i in range(batch_size): - code_table = CodeTableWithCustomTree(ptable, pcode, i) - length = code_table.get_length() - for j in range(length): - idx = code_table.cal_index(j) - pre_output[i][j] += bias[idx][0] + if isinstance(bias, np.ndarray): + for i in range(batch_size): + code_table = CodeTableWithCustomTree(ptable, pcode, i) + length = code_table.get_length() + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += bias[idx][0] for i in range(batch_size): code_table = CodeTableWithCustomTree(ptable, pcode, i) length = code_table.get_length() @@ -215,11 +216,11 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): cost = fluid.layers.hsigmoid( input=emb, label=label, + bias_attr=True, non_leaf_num=3, ptable=ptable, pcode=pcode, is_costum=True, - bias_attr=True, is_sparse=is_sparse) avg_cost = fluid.layers.reduce_mean(cost) @@ -299,5 +300,47 @@ class TestHSigmoidOpWithCostumTree(OpTest): self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) +class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.array([0, 1, 4, 5]) + ptable = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + # bias = np.random.random((num_classes - 1, 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': ptable, + 'PCode': pcode, + 'Label': label, + } + pre_output, out = hsigmoidWithCustomTree( + x=x, + w=w, + ptable=ptable, + pcode=pcode, + label=label, + bias=None, + num_classes=num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label')) + + if __name__ == '__main__': unittest.main() From 7389597ce239cd7925268a31f451da314429b078 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 27 Nov 2018 05:16:24 +0000 Subject: [PATCH 20/23] Update API.spec, test=develop --- 
paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e06d3459da..26ecc1071f 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -98,7 +98,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs= paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0)) -paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'non_leaf_num', 'ptable', 'pcode', 'param_attr', 'bias_attr', 'name', 'is_costum', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, False, False)) +paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'non_leaf_num', 'ptable', 'pcode', 'is_costum', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, False, False)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) From c3c3c0b33cf9100dd8f90f039ef0f130f53bafef Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 27 Nov 2018 09:24:53 +0000 Subject: [PATCH 21/23] polish code, test=develop --- paddle/fluid/framework/mixed_vector.h | 6 -- paddle/fluid/framework/selected_rows.cc | 52 ++++++++++ paddle/fluid/framework/selected_rows.h | 55 ++--------- .../operators/hierarchical_sigmoid_op.cc | 2 +- .../fluid/operators/hierarchical_sigmoid_op.h | 79 ++++++++-------- .../fluid/operators/math/matrix_bit_code.cc | 62 ++++++------ paddle/fluid/operators/math/matrix_bit_code.h | 94 +++++++++---------- python/paddle/fluid/layers/nn.py | 2 +- .../fluid/tests/unittests/test_hsigmoid_op.py | 6 +- 9 files changed, 176 insertions(+), 182 deletions(-) diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index 21118c4fc9..6940250c3f 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -488,12 +488,6 @@ class CPUVector : public std::vector> { return os; } - size_t size() const noexcept { - size_t size = - static_cast(std::vector>::size()); - return size; - } - T &operator[](size_t id) { return this->at(id); } const T &operator[](size_t id) const { return this->at(id); } diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 7262f8cc05..f4f2b769d5 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -140,6 +140,58 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } +int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown, + bool is_test) { + if (is_test) { + auto iter = 
id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + if (!auto_grown) { + PADDLE_THROW("key %d not found", key); + } + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW( + "id_to_index_ size %d should have the same size with rows_ %d", + map_size, vector_size); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + int row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW("selected rows is full, then length exceed %d", row_num); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } +} + void SelectedRows::SyncIndex() { rwlock_->WRLock(); id_to_index_.clear(); diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index bc5726382f..44384082db 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -118,54 +118,17 @@ class SelectedRows { * * @return index of the key. */ - inline int64_t AutoGrownIndex(int64_t key, bool auto_grown, - bool is_test = false) { - if (is_test) { - auto iter = id_to_index_.find(key); - if (iter == id_to_index_.end()) { - return -1; - } else { - return iter->second; - } - } - rwlock_->RDLock(); + int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false); + + /* + * @brief Get the index of the key from id_to_index_ map. + */ + inline int64_t GetIndexFromId(int64_t key) { auto iter = id_to_index_.find(key); if (iter == id_to_index_.end()) { - rwlock_->UNLock(); - if (!auto_grown) { - PADDLE_THROW("key %d not found", key); - } - rwlock_->WRLock(); - auto map_size = id_to_index_.size(); - auto vector_size = rows_.size(); - if (map_size != vector_size) { - rwlock_->UNLock(); - PADDLE_THROW( - "id_to_index_ size %d should have the same size with rows_ %d", - map_size, vector_size); - } - auto write_iter = id_to_index_.find(key); - if (write_iter == id_to_index_.end()) { - int row_num = rows_.size(); - if (row_num == value_->dims()[0]) { - rwlock_->UNLock(); - PADDLE_THROW("selected rows is full, then length exceed %d", row_num); - } - // key logic to put a key into id_to_index_ - rows_.push_back(key); - auto index = static_cast(rows_.size() - 1); - id_to_index_[key] = index; - rwlock_->UNLock(); - return index; - } else { - auto index = write_iter->second; - rwlock_->UNLock(); - return index; - } + return -1; } else { - auto index = iter->second; - rwlock_->UNLock(); - return index; + return iter->second; } } @@ -185,7 +148,7 @@ class SelectedRows { // SelectedRows add a Tensor, will the duplicate rows be handled. 
Vector rows_; std::unordered_map - id_to_index_; // should not be used when ids has duplicate member + id_to_index_; // should not be used when rows_ has duplicate member std::unique_ptr value_{nullptr}; int64_t height_; // height indicates the underline tensor's height std::unique_ptr rwlock_{nullptr}; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index f3329c4855..5b09958e73 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -101,7 +101,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { "it should have shape like [N, L], L is the length of the Path") .AsDispensable(); AddInput( - "PCode", + "PathCode", "(LoDTensor, optional), The Code on each Node of the Path from root " "to current word" "it should have shape like [N, L], L is the length of the Path") diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index de219bacdd..6cb011611d 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -19,9 +19,11 @@ limitations under the License. */ #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/platform/transform.h" + namespace paddle { namespace operators { @@ -30,31 +32,26 @@ template ; using platform::Transform; -std::vector cal_rows(const framework::LoDTensor& path) { - std::set tmp; - std::vector rows; - for (size_t i = 0; i < static_cast(path.dims()[0]); i++) { - for (size_t j = 0; j < static_cast(path.dims()[1]); j++) { - int64_t temp = - path.data()[i * static_cast(path.dims()[1]) + j]; - if (temp >= 0) { - tmp.insert(temp); - } +static std::vector PathToRows(const framework::LoDTensor& path) { + std::set rows; + for (int64_t i = 0; i < path.numel(); ++i) { + int64_t row = path.data()[i]; + if (row < 0) { + continue; } + rows.emplace(row); } - rows.assign(tmp.begin(), tmp.end()); - return rows; + return std::vector(rows.begin(), rows.end()); } - template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); + auto in = detail::Ref(ctx.Input("X")); + auto w = detail::Ref(ctx.Input("W")); auto* path = ctx.Input("PTable"); - auto* code = ctx.Input("PCode"); - auto* label = ctx.Input("Label"); + auto* code = ctx.Input("PathCode"); + auto label = detail::Ref(ctx.Input("Label")); auto* bias = ctx.Input("Bias"); auto* out = ctx.Output("Out"); auto* pre_out = ctx.Output("PreOut"); @@ -65,7 +62,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { } int64_t code_length = path ? 
path->dims()[1] : math::FindLastSet(num_classes - 1); - int64_t batch_size = in->dims()[0]; + int64_t batch_size = in.dims()[0]; framework::LoDTensor sum; auto& dev_ctx = ctx.template device_context(); auto* pre_out_data = pre_out->mutable_data( @@ -81,10 +78,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { std::unique_ptr> bit_code; if (!is_custom) { bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, - label->data())); + label.data())); } else { - bit_code.reset(new math::MatrixBitCodeFunctor(path, code, - label->data())); + bit_code.reset(new math::MatrixBitCodeFunctor(*path, *code, + label.data())); } std::vector sum_dims({batch_size, 1UL}); @@ -95,7 +92,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { if (bias) { bit_code->Add(*bias, pre_out); } - bit_code->Mul(pre_out, *w, *in); + bit_code->Mul(pre_out, w, in); // clip to [-40, 40] Transform trans; trans(ctx.template device_context(), pre_out_data, @@ -117,23 +114,23 @@ template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); + auto in = detail::Ref(ctx.Input("X")); + auto w = detail::Ref(ctx.Input("W")); auto* path = ctx.Input("PTable"); - auto* code = ctx.Input("PCode"); + auto* code = ctx.Input("PathCode"); auto* bias = ctx.Input("Bias"); auto* in_grad = ctx.Output(framework::GradVarName("X")); bool is_sparse = ctx.Attr("is_sparse"); auto& dev_ctx = ctx.template device_context(); math::SetConstant zero; - auto* label = ctx.Input("Label"); - auto* pre_out = ctx.Input("PreOut"); - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); + auto label = detail::Ref(ctx.Input("Label")); + auto pre_out = detail::Ref(ctx.Input("PreOut")); + auto out_grad = detail::Ref( + ctx.Input(framework::GradVarName("Out"))); framework::LoDTensor pre_out_grad; - pre_out_grad.mutable_data(pre_out->dims(), ctx.GetPlace()); + pre_out_grad.mutable_data(pre_out.dims(), ctx.GetPlace()); in_grad->mutable_data(ctx.GetPlace()); zero(dev_ctx, in_grad, static_cast(0.0)); @@ -147,16 +144,16 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { std::unique_ptr> bit_code; if (!is_custom) { bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, - label->data())); + label.data())); } else { - bit_code.reset(new math::MatrixBitCodeFunctor(path, code, - label->data())); + bit_code.reset(new math::MatrixBitCodeFunctor(*path, *code, + label.data())); } auto& place = *ctx.template device_context().eigen_device(); - auto pre_out_mat = EigenMatrix::From(*pre_out); + auto pre_out_mat = EigenMatrix::From(pre_out); auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); - auto out_grad_mat = EigenMatrix::From(*out_grad); + auto out_grad_mat = EigenMatrix::From(out_grad); Eigen::array bcast{1, static_cast(pre_out_grad.dims()[1])}; @@ -181,17 +178,17 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { ctx.Output(framework::GradVarName("W")); w_grad->mutable_data(ctx.GetPlace()); zero(dev_ctx, w_grad, static_cast(0.0)); - bit_code->MulGradWeight(pre_out_grad, w_grad, *in); + bit_code->MulGradWeight(pre_out_grad, w_grad, in); } else { - framework::Vector real_rows = cal_rows(*path); + framework::Vector real_rows = PathToRows(*path); auto* w_grad = ctx.Output(framework::GradVarName("W")); w_grad->set_rows(real_rows); // Build a map of id -> row_index to speed up finding the index of one id w_grad->SyncIndex(); - 
w_grad->set_height(w->dims()[0]); + w_grad->set_height(w.dims()[0]); auto* w_grad_value = w_grad->mutable_value(); - framework::DDim temp_dim(w->dims()); + framework::DDim temp_dim(w.dims()); set(temp_dim, 0, real_rows.size()); w_grad_value->mutable_data(temp_dim, ctx.GetPlace()); @@ -211,9 +208,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { zero(dev_ctx, bias_grad_value, static_cast(0.0)); bit_code->AddGrad(pre_out_grad, bias_grad); } - bit_code->MulGradWeight(pre_out_grad, w_grad, *in); + bit_code->MulGradWeight(pre_out_grad, w_grad, in); } - bit_code->MulGradError(pre_out_grad, *w, in_grad); + bit_code->MulGradError(pre_out_grad, w, in_grad); } }; diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 297e8d850b..71b9293eed 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -19,12 +19,12 @@ namespace operators { namespace math { template -void MatrixBitCodeFunctor::Add(const framework::LoDTensor& vec, - framework::LoDTensor* tmat) { +void MatrixBitCodeFunctor::Add(const framework::Tensor& vec, + framework::Tensor* tmat) { size_t batch_size = tmat->dims()[0]; size_t width = tmat->dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); @@ -34,12 +34,12 @@ void MatrixBitCodeFunctor::Add(const framework::LoDTensor& vec, } template -void MatrixBitCodeFunctor::AddGrad(const framework::LoDTensor& tmat, - framework::LoDTensor* vec) { +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, + framework::Tensor* vec) { size_t batch_size = tmat.dims()[0]; size_t width = tmat.dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); @@ -49,17 +49,16 @@ void MatrixBitCodeFunctor::AddGrad(const framework::LoDTensor& tmat, } template -void MatrixBitCodeFunctor::AddGrad(const framework::LoDTensor& tmat, +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, framework::SelectedRows* vec) { size_t batch_size = tmat.dims()[0]; size_t width = tmat.dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); - int64_t row_index = - vec->AutoGrownIndex(static_cast(index), false, true); + int64_t row_index = vec->GetIndexFromId(static_cast(index)); vec->mutable_value()->data()[row_index] += tmat.data()[i * width + j]; } @@ -67,13 +66,13 @@ void MatrixBitCodeFunctor::AddGrad(const framework::LoDTensor& tmat, } template -void MatrixBitCodeFunctor::Sum(const framework::LoDTensor& tmat, - framework::LoDTensor* sum, T scale_sum) { +void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, + framework::Tensor* sum, T scale_sum) { size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { T sm = static_cast(0.0); - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { if (code->calc_bit(j)) { @@ -87,9 +86,9 @@ void 
MatrixBitCodeFunctor::Sum(const framework::LoDTensor& tmat, } template -void MatrixBitCodeFunctor::Mul(framework::LoDTensor* tmat, - const framework::LoDTensor& weight, - const framework::LoDTensor& input) { +void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, + const framework::Tensor& weight, + const framework::Tensor& input) { size_t num_samples = tmat->dims()[0]; size_t tmat_width = tmat->dims()[1]; size_t input_width = input.dims()[1]; @@ -98,7 +97,7 @@ void MatrixBitCodeFunctor::Mul(framework::LoDTensor* tmat, auto weight_value = weight.data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); @@ -113,9 +112,9 @@ void MatrixBitCodeFunctor::Mul(framework::LoDTensor* tmat, } template -void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, - framework::LoDTensor* weight, - const framework::LoDTensor& input) { +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, + framework::Tensor* weight, + const framework::Tensor& input) { size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -124,7 +123,7 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, auto weight_value = weight->data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); @@ -138,9 +137,9 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, } template -void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, framework::SelectedRows* weight, - const framework::LoDTensor& input) { + const framework::Tensor& input) { size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -149,13 +148,12 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, auto weight_value = weight->mutable_value()->data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); for (size_t k = 0; k < input_width; ++k) { - int64_t row_index = - weight->AutoGrownIndex(static_cast(index), false, true); + int64_t row_index = weight->GetIndexFromId(static_cast(index)); weight_value[row_index * weight_width + k] += tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; } @@ -164,9 +162,9 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, } template -void MatrixBitCodeFunctor::MulGradError(const framework::LoDTensor& tmat, - const framework::LoDTensor& weight, - framework::LoDTensor* input) { +void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, + const framework::Tensor& weight, + framework::Tensor* input) { size_t num_samples = tmat.dims()[0]; size_t tmat_width = tmat.dims()[1]; size_t input_width = input->dims()[1]; @@ -176,7 +174,7 @@ void MatrixBitCodeFunctor::MulGradError(const framework::LoDTensor& tmat, auto input_value = input->data(); for (size_t 
i = 0; i < num_samples; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); @@ -191,11 +189,11 @@ void MatrixBitCodeFunctor::MulGradError(const framework::LoDTensor& tmat, } template -void MatrixBitCodeFunctor::Sub(framework::LoDTensor* tmat) { +void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { size_t num_samples = tmat->dims()[0]; size_t o_width = tmat->dims()[1]; for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table->get_code(i); + auto code = code_table_->get_code(i); int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { if (code->calc_bit(j)) { diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index 3add06cb63..c30bb52641 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -132,13 +132,15 @@ class SimpleCode : public Code { size_t c_; }; -template +template class CustomCode : public Code { public: - CustomCode(const framework::LoDTensor* ptable, - const framework::LoDTensor* pcode, const int64_t* ids, - const int index) - : ptable_(ptable), pcode_(pcode), ids_(ids), index_(index) {} + CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, + const int64_t* ids, int index) + : ids_(ids), index_(index) { + ptable_ = ptable.Slice(index, index + 1); + pcode_ = pcode.Slice(index, index + 1); + } /** * Here the id of root shoud be 1 rather than 0, thus the encoding of class c * is `c + num_classes` and all siblings can get the same weight indice using @@ -148,20 +150,13 @@ class CustomCode : public Code { * Binary classification path is the suffixes of encoding, thus leave out the * left most bit in calc_bit. 
*/ - size_t calc_index(int bit) const { - return ptable_ - ->data()[index_ * static_cast(ptable_->dims()[1]) + bit]; - } - bool calc_bit(int bit) const { - return pcode_ - ->data()[index_ * static_cast(ptable_->dims()[1]) + bit]; - } + size_t calc_index(int bit) const { return ptable_.data()[bit]; } + bool calc_bit(int bit) const { return pcode_.data()[bit]; } int get_length() const { int length = 0; - for (int i = 0; i < static_cast(ptable_->dims()[1]); i++) { - if (ptable_->data()[index_ * static_cast(ptable_->dims()[1]) + - i] >= 0) { + for (int i = 0; i < static_cast(ptable_.dims()[1]); i++) { + if (ptable_.data()[i] >= 0) { length++; } else { return length; @@ -171,15 +166,15 @@ class CustomCode : public Code { } private: - const framework::LoDTensor* ptable_; - const framework::LoDTensor* pcode_; + framework::Tensor ptable_; + framework::Tensor pcode_; const int64_t* ids_; const int index_; }; class SimpleCodeTable : public CodeTable { public: - explicit SimpleCodeTable(size_t num_classes, const int64_t* ids) + SimpleCodeTable(size_t num_classes, const int64_t* ids) : num_classes_(num_classes), ids_(ids) {} std::unique_ptr get_code(int64_t code) const { std::unique_ptr coder(new SimpleCode(code, num_classes_, ids_)); @@ -193,97 +188,92 @@ class SimpleCodeTable : public CodeTable { const int64_t* ids_; }; -template +template class CustomCodeTable : public CodeTable { public: - explicit CustomCodeTable(const framework::LoDTensor* ptable, - const framework::LoDTensor* pcode, - const int64_t* ids) + CustomCodeTable(const framework::Tensor& ptable, + const framework::Tensor& pcode, const int64_t* ids) : ptable_(ptable), pcode_(pcode), ids_(ids) {} std::unique_ptr get_code(int64_t code) const { - std::unique_ptr coder(new CustomCode(ptable_, pcode_, ids_, code)); + std::unique_ptr coder(new CustomCode(ptable_, pcode_, ids_, code)); return coder; } - size_t size() const { return static_cast(ptable_->dims()[1]); } + size_t size() const { return static_cast(ptable_.dims()[1]); } int get_max_code_length() const { - return static_cast(ptable_->dims()[1]); + return static_cast(ptable_.dims()[1]); } private: - const framework::LoDTensor* ptable_; - const framework::LoDTensor* pcode_; + const framework::Tensor& ptable_; + const framework::Tensor& pcode_; const int64_t* ids_; }; template class MatrixBitCodeFunctor { public: - explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) + MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) : num_classes_(num_classes), ids_(ids), - code_table(new SimpleCodeTable(num_classes, ids)) {} + code_table_(new SimpleCodeTable(num_classes, ids)) {} - explicit MatrixBitCodeFunctor(const framework::LoDTensor* ptable, - const framework::LoDTensor* pcode, - const int64_t* ids) - : num_classes_(static_cast(ptable->dims()[1])), + MatrixBitCodeFunctor(const framework::Tensor& ptable, + const framework::Tensor& pcode, const int64_t* ids) + : num_classes_(static_cast(ptable.dims()[1])), ids_(ids), - code_table(new CustomCodeTable(ptable, pcode, ids)) {} + code_table_(new CustomCodeTable(ptable, pcode, ids)) {} /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ - void Add(const framework::LoDTensor& vec, framework::LoDTensor* tmat); + void Add(const framework::Tensor& vec, framework::Tensor* tmat); /* For j < code_length vec(0, index(i, j)) += tmat(i, j) */ - void AddGrad(const framework::LoDTensor& tmat, framework::LoDTensor* vec); + void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec); /* For selected rows For j < code_length 
vec(0, index(i, j)) += tmat(i, j) */ - void AddGrad(const framework::LoDTensor& tmat, framework::SelectedRows* vec); + void AddGrad(const framework::Tensor& tmat, framework::SelectedRows* vec); /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) */ - void Sum(const framework::LoDTensor& tmat, framework::LoDTensor* sum, - T scale_sum); + void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum); /* For j < code_length tmat(i, j) -= bit(i, j) */ - void Sub(framework::LoDTensor* tmat); + void Sub(framework::Tensor* tmat); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ - void Mul(framework::LoDTensor* tmat, const framework::LoDTensor& weight, - const framework::LoDTensor& input); + void Mul(framework::Tensor* tmat, const framework::Tensor& weight, + const framework::Tensor& input); /* For index(i, j) >= 0: weight.row(index(i, j)) += tmat(i, j) * input.row(i) */ - void MulGradWeight(const framework::LoDTensor& tmat, - framework::LoDTensor* weight, - const framework::LoDTensor& input); + void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, + const framework::Tensor& input); /* For SelectedRows Weight, For index(i, j) >= 0: weight.row(index(i, j)) += tmat(i, j) * input.row(i) */ - void MulGradWeight(const framework::LoDTensor& tmat, + void MulGradWeight(const framework::Tensor& tmat, framework::SelectedRows* weight, - const framework::LoDTensor& input); + const framework::Tensor& input); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ - void MulGradError(const framework::LoDTensor& tmat, - const framework::LoDTensor& weight, - framework::LoDTensor* input); + void MulGradError(const framework::Tensor& tmat, + const framework::Tensor& weight, framework::Tensor* input); size_t num_classes_; const int64_t* ids_; - std::unique_ptr code_table; + std::unique_ptr code_table_; }; } // namespace math } // namespace operators diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8e7cff8056..fd02b445e7 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4639,7 +4639,7 @@ def hsigmoid(input, "X": input, "W": weights, "PTable": ptable, - "PCode": pcode, + "PathCode": pcode, "Label": label } if helper.bias_attr: diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 955fc51d57..8152ce9b78 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -185,7 +185,7 @@ class TestHSigmoidOpSparse(OpTest): 'X': x, 'W': w, 'PTable': ptable, - 'PCode': pcode, + 'PathCode': pcode, 'Label': label, 'Bias': bias } @@ -285,7 +285,7 @@ class TestHSigmoidOpWithCostumTree(OpTest): 'X': x, 'W': w, 'PTable': ptable, - 'PCode': pcode, + 'PathCode': pcode, 'Label': label, 'Bias': bias } @@ -322,7 +322,7 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest): 'X': x, 'W': w, 'PTable': ptable, - 'PCode': pcode, + 'PathCode': pcode, 'Label': label, } pre_output, out = hsigmoidWithCustomTree( From c469334cfb38db3b1bd95bb2c735e44549a3a015 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 27 Nov 2018 11:30:44 +0000 Subject: [PATCH 22/23] polish python code and comment, test=develop --- .../fluid/operators/hierarchical_sigmoid_op.h | 16 ++-- python/paddle/fluid/layers/nn.py | 58 ++++++++------ .../fluid/tests/unittests/test_hsigmoid_op.py | 75 ++++++++++--------- 
.../fluid/tests/unittests/test_layers.py | 12 +-- 4 files changed, 88 insertions(+), 73 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 6cb011611d..07ff8f947e 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -47,11 +47,11 @@ template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto in = detail::Ref(ctx.Input("X")); - auto w = detail::Ref(ctx.Input("W")); + auto& in = detail::Ref(ctx.Input("X")); + auto& w = detail::Ref(ctx.Input("W")); auto* path = ctx.Input("PTable"); auto* code = ctx.Input("PathCode"); - auto label = detail::Ref(ctx.Input("Label")); + auto& label = detail::Ref(ctx.Input("Label")); auto* bias = ctx.Input("Bias"); auto* out = ctx.Output("Out"); auto* pre_out = ctx.Output("PreOut"); @@ -114,8 +114,8 @@ template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto in = detail::Ref(ctx.Input("X")); - auto w = detail::Ref(ctx.Input("W")); + auto& in = detail::Ref(ctx.Input("X")); + auto& w = detail::Ref(ctx.Input("W")); auto* path = ctx.Input("PTable"); auto* code = ctx.Input("PathCode"); auto* bias = ctx.Input("Bias"); @@ -124,9 +124,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { bool is_sparse = ctx.Attr("is_sparse"); auto& dev_ctx = ctx.template device_context(); math::SetConstant zero; - auto label = detail::Ref(ctx.Input("Label")); - auto pre_out = detail::Ref(ctx.Input("PreOut")); - auto out_grad = detail::Ref( + auto& label = detail::Ref(ctx.Input("Label")); + auto& pre_out = detail::Ref(ctx.Input("PreOut")); + auto& out_grad = detail::Ref( ctx.Input(framework::GradVarName("Out"))); framework::LoDTensor pre_out_grad; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 44116a262c..b22e9715b8 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4589,23 +4589,33 @@ def hsigmoid(input, bias_attr=None, name=None, non_leaf_num=None, - ptable=None, - pcode=None, - is_costum=False, + path_table=None, + path_code=None, + is_custom=False, is_sparse=False): """ The hierarchical sigmoid operator is used to accelerate the training process of language model. This operator organizes the classes into a - complete binary tree, each leaf node represents a class(a word) and each + complete binary tree, or you can use is_custom to pass your own tree to + implement hierarchical. Each leaf node represents a class(a word) and each internal node acts as a binary classifier. For each word there's a unique path from root to it's leaf node, hsigmoid calculate the cost for each internal node on the path, and sum them to get a total cost. hsigmoid can achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N` represents the size of word dict. - Refer to `Hierarchical Probabilistic Neural Network Language Model + Using default tree you can Refer to `Hierarchical Probabilistic Neural Network Language Model `_ + And if you want to use the costumed tree by set 'is_custom' as true you may need to do following things first: + 1. using your word dict to build a binary tree, each leaf node should be an word of your word dict + 2. build a dict to store word_id -> word's leaf to root path, we call it path_table. + 3. 
build a dict to store word_id -> code of word's leaf to root path, we call it path_code. Code + means label of each binary classification, using 1 indicate true, 0 indicate false. + 4. now, each word should has its path and code along the path, you can pass a batch of path and code + related to the same batch of inputs. + + Args: input (Variable): The input tensor variable with shape :math:`[N \\times D]`, where :math:`N` is the size of mini-batch, @@ -4613,13 +4623,6 @@ def hsigmoid(input, label (Variable): The tensor variable contains labels of training data. It's a tensor with shape is :math:`[N \\times 1]`. num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set - non_leaf_num: this defines the number of non-leaf nodes in costumed tree - ptable: (Variable|None) this variable can store each batch of samples' path to root, - it should be in leaf -> root order - ptable should have the same shape with pcode, and for each sample i ptable[i] indicates a np.array like - structure and each element in this array is indexes in parent nodes' Weight Matrix. - pcode: (Variable|None) this variable can store each batch of samples' code, - each code consist with every code of parent nodes. it should be in leaf -> root order param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create ParamAttr as param_attr. If the Initializer of the param_attr @@ -4631,8 +4634,15 @@ def hsigmoid(input, is not set, the bias is initialized zero. Default: None. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None. - is_costum: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is - set you need to set ptable/pcode/non_leaf_num, otherwise num_classes should be set + non_leaf_num: this defines the number of non-leaf nodes in costumed tree + path_table: (Variable|None) this variable can store each batch of samples' path to root, + it should be in leaf -> root order + path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like + structure and each element in this array is indexes in parent nodes' Weight Matrix. + path_code: (Variable|None) this variable can store each batch of samples' code, + each code consist with every code of parent nodes. it should be in leaf -> root order + is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is + set you need to set path_table/path_code/non_leaf_num, otherwise num_classes should be set is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient of W and input will be sparse. 
@@ -4653,22 +4663,22 @@ def hsigmoid(input, out = helper.create_variable_for_type_inference(dtype) pre_out = helper.create_variable_for_type_inference(dtype) dim = input.shape[1] - if ((num_classes is None) or (num_classes < 2)) and (not is_costum): + if ((num_classes is None) or (num_classes < 2)) and (not is_custom): raise ValueError( "num_classes must not be less than 2 with default tree") - if (is_costum) and (pcode is None): - raise ValueError("pcode should not be None with costum tree") - elif (is_costum) and (ptable is None): - raise ValueError("ptable should not be None with costum tree") - elif (is_costum) and (non_leaf_num is None): + if (is_custom) and (path_code is None): + raise ValueError("path_code should not be None with costum tree") + elif (is_custom) and (path_table is None): + raise ValueError("path_table should not be None with costum tree") + elif (is_custom) and (non_leaf_num is None): raise ValueError("non_leaf_num should not be None with costum tree") else: pass weights = None - if not is_costum: + if not is_custom: weights = helper.create_parameter( attr=helper.param_attr, shape=[num_classes - 1, dim], @@ -4683,12 +4693,12 @@ def hsigmoid(input, inputs = { "X": input, "W": weights, - "PTable": ptable, - "PathCode": pcode, + "PTable": path_table, + "PathCode": path_code, "Label": label } if helper.bias_attr: - if not is_costum: + if not is_custom: bias = helper.create_parameter( attr=helper.bias_attr, shape=[num_classes - 1, 1], diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 8152ce9b78..4254c3bb25 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -43,9 +43,9 @@ class CodeTable(object): class CodeTableWithCustomTree(object): - def __init__(self, ptable, pcode, index): - self.ptable_ = ptable - self.pcode_ = pcode + def __init__(self, path_table, path_code, index): + self.ptable_ = path_table + self.pcode_ = path_code self.index_ = index def cal_index(self, bit): @@ -102,9 +102,10 @@ def hsigmoid(x, w, label, bias, num_classes): return pre_output, out -def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): +def hsigmoidWithCustomTree(x, w, path_table, path_code, label, bias, + num_classes): batch_size = x.shape[0] - code_length = len(ptable[0]) + code_length = len(path_table[0]) code_table = [0 for _ in range(code_length)] # init pre_out with shape [N, code_length] pre_output = np.zeros((batch_size, code_length)) @@ -112,13 +113,13 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): out = np.zeros((batch_size, 1)).astype("float32") if isinstance(bias, np.ndarray): for i in range(batch_size): - code_table = CodeTableWithCustomTree(ptable, pcode, i) + code_table = CodeTableWithCustomTree(path_table, path_code, i) length = code_table.get_length() for j in range(length): idx = code_table.cal_index(j) pre_output[i][j] += bias[idx][0] for i in range(batch_size): - code_table = CodeTableWithCustomTree(ptable, pcode, i) + code_table = CodeTableWithCustomTree(path_table, path_code, i) length = code_table.get_length() for j in range(length): idx = code_table.cal_index(j) @@ -127,7 +128,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): pre_output = np.clip(pre_output, -40.0, 40.0) # out(i, 0) = \sum_j bit(i, j) * preout(i, j) for i in range(batch_size): - code_table = CodeTableWithCustomTree(ptable, pcode, i) + code_table = 
CodeTableWithCustomTree(path_table, path_code, i) length = code_table.get_length() sum = 0.0 for j in range(length): @@ -173,24 +174,24 @@ class TestHSigmoidOpSparse(OpTest): x = np.random.random((batch_size, feature_size)).astype("float32") w = np.random.random((num_classes - 1, feature_size)).astype("float32") label = np.array([0, 1, 4, 5]) - ptable = np.array( + path_table = np.array( [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) - pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store bias = np.random.random((num_classes - 1, 1)).astype("float32") self.attrs = {'num_classes': num_classes, 'is_sparse': True} self.inputs = { 'X': x, 'W': w, - 'PTable': ptable, - 'PathCode': pcode, + 'PTable': path_table, + 'PathCode': path_code, 'Label': label, 'Bias': bias } - pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, - bias, num_classes) + pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code, + label, bias, num_classes) self.outputs = {'PreOut': pre_output, 'Out': out} def test_check_output(self): @@ -200,11 +201,13 @@ class TestHSigmoidOpSparse(OpTest): class TestHSigmoidOpWithSparseGrad(unittest.TestCase): def hs_net_conf(self, is_sparse): input_word = fluid.layers.data(name="x", shape=[1], dtype='int64') - ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64') - pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64') + path_table = fluid.layers.data( + name='path_table', shape=[3], dtype='int64') + path_code = fluid.layers.data( + name='path_code', shape=[3], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64') - data_list = [input_word, ptable, pcode, label] + data_list = [input_word, path_table, path_code, label] emb = fluid.layers.embedding( input=input_word, @@ -218,9 +221,9 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): label=label, bias_attr=True, non_leaf_num=3, - ptable=ptable, - pcode=pcode, - is_costum=True, + path_table=path_table, + path_code=path_code, + is_custom=True, is_sparse=is_sparse) avg_cost = fluid.layers.reduce_mean(cost) @@ -232,8 +235,8 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): start_up = fluid.default_startup_program() start_up.random_seed = 1 # Fix random seed x = np.arange(6).reshape(6) - ptable = np.array([(1, 2, -1), (1, 2, -1)]) - pcode = np.array([(1, 0, -1), (0, 0, -1)]) + path_table = np.array([(1, 2, -1), (1, 2, -1)]) + path_code = np.array([(1, 0, -1), (0, 0, -1)]) label = np.array([1, 4]) loss, data_list = self.hs_net_conf(is_sparse) @@ -248,8 +251,8 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): exe.run(start_up) result = list() for i in range(10): - data = [([[x[i % 2]]], [list(ptable[i % 2])], - [list(pcode[i % 2])], [label[i % 2]])] + data = [([[x[i % 2]]], [list(path_table[i % 2])], + [list(path_code[i % 2])], [label[i % 2]])] loss_val = exe.run(main_program, feed=feeder.feed(data), @@ -273,24 +276,24 @@ class TestHSigmoidOpWithCostumTree(OpTest): w = np.random.random( (num_classes - 1, feature_size)).astype("float32") * 2 label = np.array([0, 1, 4, 5]) - ptable = np.array( + path_table = np.array( [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) - pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + path_code = 
np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store bias = np.random.random((num_classes - 1, 1)).astype("float32") self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.inputs = { 'X': x, 'W': w, - 'PTable': ptable, - 'PathCode': pcode, + 'PTable': path_table, + 'PathCode': path_code, 'Label': label, 'Bias': bias } - pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, - bias, num_classes) + pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code, + label, bias, num_classes) self.outputs = {'PreOut': pre_output, 'Out': out} def test_check_output(self): @@ -310,26 +313,26 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest): w = np.random.random( (num_classes - 1, feature_size)).astype("float32") * 2 label = np.array([0, 1, 4, 5]) - ptable = np.array( + path_table = np.array( [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) - pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store # bias = np.random.random((num_classes - 1, 1)).astype("float32") self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.inputs = { 'X': x, 'W': w, - 'PTable': ptable, - 'PathCode': pcode, + 'PTable': path_table, + 'PathCode': path_code, 'Label': label, } pre_output, out = hsigmoidWithCustomTree( x=x, w=w, - ptable=ptable, - pcode=pcode, + path_table=path_table, + path_code=path_code, label=label, bias=None, num_classes=num_classes) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 0dc3388b94..b8477820ee 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -190,16 +190,18 @@ class TestBook(unittest.TestCase): with program_guard(program2): x2 = layers.data(name='x2', shape=[4, 8], dtype='float32') y2 = layers.data(name='y2', shape=[4], dtype='int64') - ptable = layers.data(name='ptable', shape=[4, 6], dtype='int64') - pcode = layers.data(name='pcode', shape=[4, 6], dtype='int64') + path_table = layers.data( + name='path_table', shape=[4, 6], dtype='int64') + path_code = layers.data( + name='path_code', shape=[4, 6], dtype='int64') self.assertIsNotNone( layers.hsigmoid( input=x2, label=y2, non_leaf_num=6, - ptable=ptable, - pcode=pcode, - is_costum=True)) + path_table=path_table, + path_code=path_code, + is_custom=True)) print(str(program2)) def test_sequence_expand(self): From a08dc83eb0d729f87040690e0ea6a1fabc70e228 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 27 Nov 2018 12:04:14 +0000 Subject: [PATCH 23/23] remove arg 'non_leaf_num', test=develop --- paddle/fluid/API.spec | 2 +- .../fluid/operators/hierarchical_sigmoid_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 18 +++++++++--------- .../fluid/tests/unittests/test_hsigmoid_op.py | 2 +- .../fluid/tests/unittests/test_layers.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 894d8dda3d..c40f603341 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -97,8 +97,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.transpose 
ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) -paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'non_leaf_num', 'ptable', 'pcode', 'is_costum', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, False, False)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)) +paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 5b09958e73..972dcf5494 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -108,7 +108,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddInput("Bias", "(LoDTensor, optional), The bias is a tensor with shape or " - "[non_leaf_num, 1]" + "[num_classes, 1]" "[num_classes - 1, 1].") .AsDispensable(); AddOutput( diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b22e9715b8..4df74edfce 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4584,11 +4584,10 @@ def nce(input, def hsigmoid(input, label, - num_classes=None, + num_classes, param_attr=None, bias_attr=None, name=None, - non_leaf_num=None, path_table=None, path_code=None, is_custom=False, @@ -4622,7 +4621,9 @@ def hsigmoid(input, and :math:`D` is the feature size. label (Variable): The tensor variable contains labels of training data. It's a tensor with shape is :math:`[N \\times 1]`. - num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set + num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set, + it should never be None under is_custom=False, but while is_custom is true, it should be non leaf num + which indicates the num of classes using by binary classify. param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create ParamAttr as param_attr. If the Initializer of the param_attr @@ -4634,7 +4635,6 @@ def hsigmoid(input, is not set, the bias is initialized zero. Default: None. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None. 
- non_leaf_num: this defines the number of non-leaf nodes in costumed tree path_table: (Variable|None) this variable can store each batch of samples' path to root, it should be in leaf -> root order path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like @@ -4642,7 +4642,7 @@ def hsigmoid(input, path_code: (Variable|None) this variable can store each batch of samples' code, each code consist with every code of parent nodes. it should be in leaf -> root order is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is - set you need to set path_table/path_code/non_leaf_num, otherwise num_classes should be set + set you need to set path_table/path_code/num_classes, otherwise num_classes should be set is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient of W and input will be sparse. @@ -4671,8 +4671,8 @@ def hsigmoid(input, raise ValueError("path_code should not be None with costum tree") elif (is_custom) and (path_table is None): raise ValueError("path_table should not be None with costum tree") - elif (is_custom) and (non_leaf_num is None): - raise ValueError("non_leaf_num should not be None with costum tree") + elif (is_custom) and (num_classes is None): + raise ValueError("num_classes should not be None with costum tree") else: pass @@ -4687,7 +4687,7 @@ def hsigmoid(input, else: weights = helper.create_parameter( attr=helper.param_attr, - shape=[non_leaf_num, dim], + shape=[num_classes, dim], is_bias=False, dtype=input.dtype) inputs = { @@ -4708,7 +4708,7 @@ def hsigmoid(input, else: bias = helper.create_parameter( attr=helper.bias_attr, - shape=[non_leaf_num, 1], + shape=[num_classes, 1], is_bias=True, dtype=input.dtype) inputs['Bias'] = bias diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 4254c3bb25..2a6c93f75f 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -220,7 +220,7 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): input=emb, label=label, bias_attr=True, - non_leaf_num=3, + num_classes=3, path_table=path_table, path_code=path_code, is_custom=True, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index b8477820ee..5411607711 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -198,7 +198,7 @@ class TestBook(unittest.TestCase): layers.hsigmoid( input=x2, label=y2, - non_leaf_num=6, + num_classes=6, path_table=path_table, path_code=path_code, is_custom=True))
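
Editorial note for readers tracing the custom-tree tests above: the following is a minimal NumPy sketch of how the PTable/PathCode inputs drive the forward pass, written against the conventions used in test_hsigmoid_op.py. It is not the hsigmoidWithCustomTree helper itself, and the name hsigmoid_custom_ref is illustrative. Each row of path_table lists the non-leaf nodes visited from the root to a sample's leaf, path_code gives the 0/1 target at each of those nodes, and -1 marks padding.

import numpy as np

def hsigmoid_custom_ref(x, w, path_table, path_code, bias=None):
    """Illustrative forward pass of hsigmoid with a user-defined tree.

    x:          [N, D] input features
    w:          [K, D] one weight row per non-leaf node
    path_table: [N, L] non-leaf node indices on each sample's path, padded with -1
    path_code:  [N, L] 0/1 code at each node, padded with -1
    bias:       [K, 1] optional per-node bias
    """
    batch_size, code_length = path_table.shape
    cost = np.zeros((batch_size, 1))
    for i in range(batch_size):
        for j in range(code_length):
            idx = path_table[i, j]
            if idx < 0:                          # -1 marks the end of this path
                break
            logit = np.dot(w[idx], x[i])
            if bias is not None:
                logit += bias[idx][0]
            logit = np.clip(logit, -40.0, 40.0)  # same clip range as the kernel
            # binary cross entropy at this node: softrelu(logit) - code * logit,
            # accumulated over every non-leaf node on the path
            cost[i, 0] += np.log(1.0 + np.exp(logit))
            if path_code[i, j] == 1:
                cost[i, 0] -= logit
    return cost

Run on the arrays built in TestHSigmoidOpWithCostumTree, the returned cost should agree with the operator's Out tensor up to floating-point error; the PreOut tensor (the per-node softrelu values) is not reproduced by this sketch.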
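
Because the Python surface changes across these patches (ptable/pcode/is_costum/non_leaf_num become path_table/path_code/is_custom/num_classes), a usage sketch of the final signature may help. It mirrors the network in TestHSigmoidOpWithSparseGrad; the embedding size and data shapes are illustrative rather than prescriptive.

import math
import paddle.fluid as fluid

# Custom-tree mode with the renamed arguments: num_classes carries the
# non-leaf node count (what the removed non_leaf_num used to hold), and the
# extra inputs are path_table / path_code.
input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
path_table = fluid.layers.data(name='path_table', shape=[3], dtype='int64')
path_code = fluid.layers.data(name='path_code', shape=[3], dtype='int64')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

emb = fluid.layers.embedding(
    input=input_word,
    size=[3, 3],
    param_attr=fluid.ParamAttr(
        name='emb_w',
        initializer=fluid.initializer.Normal(scale=1 / math.sqrt(3))))

cost = fluid.layers.hsigmoid(
    input=emb,
    label=label,
    bias_attr=True,
    num_classes=3,            # non-leaf node count under is_custom=True
    path_table=path_table,
    path_code=path_code,
    is_custom=True,
    is_sparse=False)
avg_cost = fluid.layers.reduce_mean(cost)

For the default complete-binary-tree mode only input, label and num_classes are needed, e.g. fluid.layers.hsigmoid(input=emb, label=label, num_classes=6).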
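
Since num_classes now doubles as the non-leaf-node count when is_custom=True (it sizes both the W parameter [num_classes, dim] and the Bias parameter [num_classes, 1]), a small consistency check against the path table can catch mismatched feeds early. infer_non_leaf_num below is a hypothetical helper for illustration only and is not part of the patch.

import numpy as np

def infer_non_leaf_num(path_table):
    # W (and Bias) need one row per non-leaf node referenced by path_table;
    # -1 entries are padding and are ignored.
    valid = path_table[path_table >= 0]
    return int(valid.max()) + 1          # node indices are 0-based rows of W

path_table = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1),
                       (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)])
assert infer_non_leaf_num(path_table) == 5   # the OpTest cases above use a 5-row W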