Make NCE_OP more efficient and support SelectedRows (#14469)

* Fix truncated normal.

* Fix.

* Make nce support more distribution.

* Fix API.spec.

* Fix python API.

* Fix.
test=develop

* Fix API.spec
test=develop

* Fix sampler.

* Fix order of arguments in python API.
test=develop

* NCE add selectedrows support

* NCE update weighted sampling

* fix bugs in nce_op, and assign_value_op optimized

* fix bugs in nce_op, revert assign_value_op

* nce_op optimize

* nce_op optimize

* nce_op optimize

* add selectedRows test later

test=develop

* add selectedRows supported

* add selectedRows supported

test=develop

* add selectedRows supported

* add nce selectedRows supported, test=develop

* add nce selectedRows supported

* add nce selectedRows supported, test=develop

* fix height in nce, test=develop

* add ut

* add ut, test=develop

* make AutoGrownIndex inline
test=develop

* fix tinny error, test=develop
local_add_cudnn_lstm
tangwei12 6 years ago committed by GitHub
parent 1c48d61442
commit 56a4912b76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -97,7 +97,7 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))

@ -60,75 +60,30 @@ float LogUniformSampler::Probability(int64_t value) const {
return (log((value + 2.0) / (value + 1.0))) / log_range_;
}
CustomSampler::CustomSampler(int64_t range, const float* probabilities,
CustomSampler::CustomSampler(int64_t range, const float *probabilities,
const int *alias, const float *alias_probabilities,
unsigned int seed)
: Sampler(range, seed) {
random_engine_ = std::make_shared<std::mt19937_64>(seed_);
random_engine_ = std::make_shared<std::mt19937>(seed_);
real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
alias_probs_ = std::make_shared<std::vector<float>>(range + 1);
alias_ = std::make_shared<std::vector<int64_t>>(range + 1);
probs_ = std::make_shared<std::vector<float>>(range + 1);
std::queue<std::pair<int64_t, float>> bigs;
std::queue<std::pair<int64_t, float>> littles;
for (int64_t i = 0; i <= range; ++i) {
(*probs_)[i] = probabilities[i];
float normal_prob = probabilities[i] * (range + 1);
if (normal_prob - 1.0 > 1e-4) {
bigs.emplace(i, normal_prob);
} else if (1.0 - normal_prob > 1e-4) {
littles.emplace(i, normal_prob);
} else {
(*alias_probs_)[i] = normal_prob;
(*alias_)[i] = -1;
}
}
while ((!littles.empty()) && (!bigs.empty())) {
auto big = bigs.front();
auto little = littles.front();
bigs.pop();
littles.pop();
(*alias_probs_)[little.first] = little.second;
(*alias_)[little.first] = big.first;
auto big_left = big.second - (1 - little.second);
if (big_left - 1.0 > 1e-4) {
bigs.emplace(big.first, big_left);
} else if (1.0 - big_left > 1e-4) {
littles.emplace(big.first, big_left);
} else {
(*alias_probs_)[big.first] = big_left;
(*alias_)[big.first] = -1;
}
}
if (!littles.empty()) { // littles.second is close to 1.0
auto little = littles.front();
(*alias_probs_)[little.first] = 1.0;
(*alias_)[little.first] = -1;
}
if (!bigs.empty()) { // bigs.second is close to 1.0
auto big = bigs.front();
(*alias_probs_)[big.first] = 1.0;
(*alias_)[big.first] = -1;
}
alias_probs_ = alias_probabilities;
probs_ = probabilities;
alias_ = alias;
}
int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_);
if (p > (*alias_probs_)[index]) {
return (*alias_)[index];
if (p > alias_probs_[index]) {
return alias_[index];
} else {
return index;
}
}
float CustomSampler::Probability(int64_t value) const {
return (*probs_)[value];
}
float CustomSampler::Probability(int64_t value) const { return probs_[value]; }
} // namespace math
} // namespace operators

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <memory>
#include <random>
@ -38,9 +39,12 @@ class Sampler {
seed_ = seed;
}
}
virtual ~Sampler();
// Sample a single value
virtual int64_t Sample() const = 0;
// The probability that a single call to Sample() returns the given value.
virtual float Probability(int64_t value) const = 0;
@ -99,6 +103,7 @@ class LogUniformSampler : public Sampler {
class CustomSampler : public Sampler {
public:
explicit CustomSampler(int64_t range, const float* probabilities,
const int* alias, const float* alias_probabilities,
unsigned int seed = 0UL);
~CustomSampler() override {}
@ -108,10 +113,10 @@ class CustomSampler : public Sampler {
float Probability(int64_t value) const override;
private:
std::shared_ptr<std::vector<float>> alias_probs_;
std::shared_ptr<std::vector<int64_t>> alias_;
std::shared_ptr<std::vector<float>> probs_;
std::shared_ptr<std::mt19937_64> random_engine_;
const float* alias_probs_;
const int* alias_;
const float* probs_;
std::shared_ptr<std::mt19937> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
};

@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/nce_op.h"
#include <string>
#include <vector>
namespace paddle {
@ -25,7 +26,7 @@ class NCEOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("Label"));
PADDLE_ENFORCE(ctx->HasInput("Weight"));
@ -67,7 +68,7 @@ class NCEOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace());
@ -101,11 +102,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable();
AddInput(
"CustomDistribution",
"CustomDistProbs",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddInput(
"CustomDistAlias",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddInput(
"CustomDistAliasProbs",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddOutput("Cost",
"(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
AddOutput("SampleLogits",
@ -124,21 +138,22 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"kernel to compute grads."
"")
.AsIntermediate();
AddAttr<int>("num_total_classes",
"Total number of classes in all samples.");
AddAttr<int>("num_neg_samples",
"The number of negative classes. The default value is 10.")
.SetDefault(10);
AddAttr<int>("sampler",
"(int) Which sampler to be used to sample negative class."
"0: Uniform; 1: LogUniform; 2: CostumDist.")
.SetDefault(0);
AddAttr<int>("seed",
"(int) The seed used in sampler. If it is 0, "
"the sampler will generate a seed randomly.")
.SetDefault(0);
AddAttr<bool>("is_sparse", "(boolean, default false) Sparse update.")
.SetDefault(false);
AddAttr<std::vector<int>>("custom_neg_classes",
"This attribute only be used in unitest. Classes "
@ -156,11 +171,19 @@ By default this operator uses a uniform distribution for sampling.
}
};
class NCEOpGradDescMaker : public framework::DefaultGradOpDescMaker<true> {
using ::paddle::framework::DefaultGradOpDescMaker<
true>::DefaultGradOpDescMaker;
protected:
virtual std::string GradOpType() const { return "nce_grad"; }
};
class NCEOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("Weight"));
PADDLE_ENFORCE(ctx->HasInput("Cost"));
@ -190,20 +213,45 @@ class NCEOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace());
}
};
class NCEOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front();
auto attr = op_desc.GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
<< " is set to SelectedRows";
block->Var(weight_grad)
->SetType(framework::proto::VarType::SELECTED_ROWS);
block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
<< " is set to LoDTensor";
block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
}
block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad);
REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpGradDescMaker, ops::NCEOpMaker);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,

File diff suppressed because it is too large Load Diff

@ -4394,7 +4394,8 @@ def nce(input,
name=None,
sampler="uniform",
custom_dist=None,
seed=0):
seed=0,
is_sparse=False):
"""
${comment}
@ -4420,11 +4421,12 @@ def nce(input,
sampler (str): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'.
custom_dist (Variable): A tensor with shape [num_total_classes].
custom_dist (float[]): A float[] with size=num_total_classes.
It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled.
default: None.
seed (int): The seed used in sampler. default: 0.
is_sparse(bool): The flag indicating whether to use sparse update, the weight@GRAD and bias@GRAD will be changed to SelectedRows.
Returns:
Variable: The output nce loss.
@ -4476,12 +4478,7 @@ def nce(input,
shape=[num_total_classes, dim],
is_bias=False,
dtype=input.dtype)
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'SampleWeight': sample_weight if sample_weight is not None else []
}
inputs = {}
if helper.bias_attr:
b = helper.create_parameter(
attr=helper.bias_attr,
@ -4493,18 +4490,10 @@ def nce(input,
sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
}
inputs['Input'] = input
inputs['Label'] = label
inputs['Weight'] = w
inputs['SampleWeight'] = sample_weight if sample_weight is not None else []
if sampler == "uniform":
sampler = 0
@ -4512,17 +4501,73 @@ def nce(input,
sampler = 1
elif sampler == "custom_dist":
assert custom_dist is not None
assert isinstance(custom_dist, Variable)
inputs['CustomDistribution'] = custom_dist
# assert isinstance(custom_dist, Variable)
custom_dist_len = len(custom_dist)
alias_probs_ = [0] * custom_dist_len
alias_ = [0] * custom_dist_len
bigs = []
littles = []
for i in range(custom_dist_len):
normal_prob = custom_dist[i] * custom_dist_len
if normal_prob - 1.0 > 1e-4:
bigs.append((i, normal_prob))
elif 1.0 - normal_prob > 1e-4:
littles.append((i, normal_prob))
else:
alias_probs_[i] = normal_prob
alias_[i] = -1
while len(bigs) and len(littles):
big = bigs.pop(0)
little = littles.pop(0)
big_idx = big[0]
big_prob = big[1]
alias_probs_[little[0]] = little[1]
alias_[little[0]] = big_idx
big_left = big[1] + little[1] - 1
if big_left - 1.0 > 1e-4:
bigs.append((big_idx, big_left))
elif 1.0 - big_left > 1e-4:
littles.append((big_idx, big_left))
else:
alias_probs_[big_idx] = big_left
alias_[big_idx] = -1
if len(bigs):
big = bigs.pop(0)
alias_probs_[big[0]] = 1.0
alias_[big[0]] = -1
if len(littles):
little = littles.pop(0)
alias_probs_[little[0]] = 1.0
alias_[little[0]] = -1
probs = assign(input=np.array(custom_dist).astype('float32'))
custom_alias = assign(input=np.array(alias_).astype('int32'))
custom_alias_probs = assign(
input=np.array(alias_probs_).astype('float32'))
inputs['CustomDistProbs'] = probs
inputs['CustomDistAlias'] = custom_alias
inputs['CustomDistAliasProbs'] = custom_alias_probs
sampler = 2
else:
raise Exception("Unsupported sampler type.")
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
attrs = {
'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples,
'seed': seed,
'sampler': sampler
'sampler': sampler,
'is_sparse': is_sparse
}
helper.append_op(
@ -6474,7 +6519,7 @@ def crop(x, shape=None, offsets=None, name=None):
helper = LayerHelper('crop', **locals())
if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)):
isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.")
if offsets is None:
@ -6596,7 +6641,7 @@ def affine_grid(theta, out_shape, name=None):
helper = LayerHelper('affine_grid')
if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
isinstance(out_shape, Variable)):
isinstance(out_shape, Variable)):
raise ValueError("The out_shape should be a list, tuple or Variable.")
if not isinstance(theta, Variable):

@ -14,8 +14,12 @@
from __future__ import print_function
import unittest
import numpy as np
import unittest
import paddle.fluid as fluid
import paddle.fluid.initializer as initializer
from op_test import OpTest
@ -59,7 +63,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
class TestNCE(OpTest):
def generate_data(self, dim, batch_size, num_classes, num_true_class,
num_neg_samples):
num_neg_samples, is_sparse):
input = np.random.randn(batch_size, dim).astype(np.float32)
weight = np.random.randn(num_classes, dim).astype(np.float32)
bias = np.random.randn(num_classes).astype(np.float32)
@ -70,7 +74,8 @@ class TestNCE(OpTest):
'num_neg_samples': num_neg_samples,
'custom_neg_classes': list(range(num_neg_samples)),
'seed': 0,
'sampler': 0
'sampler': 0,
'is_sparse': is_sparse
}
self.inputs = {
'Input': input,
@ -81,7 +86,7 @@ class TestNCE(OpTest):
}
def set_data(self):
self.generate_data(5, 5, 4, 1, 2)
self.generate_data(5, 5, 4, 1, 2, False)
def compute(self):
out = nce(self.inputs['Input'], self.inputs['Weight'],
@ -107,9 +112,110 @@ class TestNCE(OpTest):
["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
class TestNCECase1(TestNCE):
class TestNCECase1Tensor(TestNCE):
def set_data(self):
self.generate_data(10, 20, 10, 2, 5)
self.generate_data(10, 20, 10, 2, 5, False)
class TestNCECase1SelectedRows(unittest.TestCase):
def setUp(self):
self.base_lr = 0.0001
self.batch_size = 8
@staticmethod
def get_place():
place = fluid.core.CPUPlace()
return place
@staticmethod
def get_train_data(batch_size):
batchs = []
for i in range(batch_size):
input = np.random.randn(batch_size, 10).astype(np.float32)
labels = np.random.randint(0, 20, (batch_size, 1))
batchs.append([input, labels])
return batchs
def get_optimizer(self):
# SGD optimizer
optimizer = fluid.optimizer.SGD(learning_rate=self.base_lr)
return optimizer
def train_network(self, num_total_classes, num_neg_samples, sampler,
custom_dist, is_sparse):
input = fluid.layers.data(name="input", shape=[10], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
w_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 10],
dtype='float32',
name='nce_w',
initializer=initializer.ConstantInitializer())
b_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 1],
dtype='float32',
name='nce_b',
initializer=initializer.ConstantInitializer())
cost = fluid.layers.nce(input=input,
label=label,
num_total_classes=num_total_classes,
sampler=sampler,
custom_dist=custom_dist,
sample_weight=None,
param_attr='nce_w',
bias_attr='nce_b',
seed=1,
num_neg_samples=num_neg_samples,
is_sparse=is_sparse)
avg_cost = fluid.layers.mean(cost)
# optimizer
optimizer = self.get_optimizer()
optimizer.minimize(avg_cost)
return [avg_cost, [input, label]]
def test_input_is_selected_rows(self):
place = self.get_place()
exe = fluid.Executor(place)
data = self.get_train_data(self.batch_size)
nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
rets = []
# for dense
dense_scope = fluid.core.Scope()
dense_startup_program = fluid.framework.Program()
dense_train_program = fluid.framework.Program()
with fluid.scope_guard(dense_scope):
with fluid.program_guard(dense_train_program,
dense_startup_program):
cost, feeds = self.train_network(20, 5, "custom_dist",
nid_freq_arr.tolist(), False)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
exe.run(dense_startup_program)
loss_val = exe.run(dense_train_program,
feed=feeder.feed(data),
fetch_list=[cost.name])
rets.append(np.mean(loss_val))
# for sparse
sparse_scope = fluid.core.Scope()
sparse_startup_program = fluid.framework.Program()
sparse_train_program = fluid.framework.Program()
with fluid.scope_guard(sparse_scope):
with fluid.program_guard(sparse_train_program,
sparse_startup_program):
cost, feeds = self.train_network(20, 5, "custom_dist",
nid_freq_arr.tolist(), True)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
exe.run(sparse_startup_program)
loss_val = exe.run(sparse_train_program,
feed=feeder.feed(data),
fetch_list=[cost.name])
rets.append(np.mean(loss_val))
self.assertEqual(rets[0], rets[1])
if __name__ == '__main__':

Loading…
Cancel
Save