commit
10b2534ebc
@@ -0,0 +1,186 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/nce_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class NCEOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"));
    PADDLE_ENFORCE(ctx->HasInput("Label"));
    PADDLE_ENFORCE(ctx->HasInput("Weight"));
    PADDLE_ENFORCE(ctx->HasOutput("Cost"));
    PADDLE_ENFORCE(ctx->HasOutput("SampleLogits"));
    PADDLE_ENFORCE(ctx->HasOutput("SampleLabels"));

    auto x_dims = ctx->GetInputDim("Input");
    auto label_dims = ctx->GetInputDim("Label");
    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
    int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
    if (ctx->HasInput("Bias")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
                        ctx->GetInputDim("Bias")[0]);
    }
    auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
    auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
    std::vector<int> custom_neg_classes =
        ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
    PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
    if (custom_neg_classes.size() > 0) {
      PADDLE_ENFORCE_EQ(custom_neg_classes.size(),
                        static_cast<size_t>(num_neg_samples));
    }
    // set dims of output(Cost)
    std::vector<int64_t> out_dims;
    out_dims.push_back(x_dims[0]);
    out_dims.push_back(1);
    ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));

    // set dims of output(SampleLogits) and output(SampleLabels)
    std::vector<int64_t> sample_out_dims;
    sample_out_dims.push_back(x_dims[0]);
    sample_out_dims.push_back(num_neg_samples + num_true_classes);
    ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
    ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
  }

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
        ctx.device_context());
  }
};

class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
    AddInput(
        "Label",
        "(Tensor) A tensor of shape [batch_size, num_true_class]. "
        "'num_true_class' is the number of target classes in each sample. "
        "The number of target classes per sample should be the same. "
        "If you have a variable number of target classes, "
        "you can pad them out to a constant number by either repeating them "
        "or by padding with an otherwise unused class.");
    AddInput("Weight",
             "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
             "total number of classes.");
    AddInput(
        "Bias",
        "(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total "
        "number of classes. It is a dispensable input.")
        .AsDispensable();
    AddInput("SampleWeight",
             "(Tensor) A tensor of shape [batch_size, 1] storing a weight for "
             "each sample. It is a dispensable input. The default weight of "
             "each sample is 1.")
        .AsDispensable();
    AddOutput("Cost",
              "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
    AddOutput("SampleLogits",
              "An intermediate tensor of shape [batch_size, num_neg_samples + "
              "num_pos_samples]. This tensor is the output of the forward "
              "kernel and is used in the backward kernel to compute "
              "gradients. Given X, the dot product of the input tensor and a "
              "sampled label's weight, 'SampleLogits' is sigmoid(X).")
        .AsIntermediate();
    AddOutput("SampleLabels",
              "An intermediate tensor of shape [batch_size, num_neg_samples + "
              "num_pos_samples]. This tensor is the output of the forward "
              "kernel and is used in the backward kernel to compute "
              "gradients.")
        .AsIntermediate();
    AddAttr<int>("num_total_classes",
                 "Total number of classes in all samples.");
    AddAttr<int>("num_neg_samples",
                 "The number of negative classes. The default value is 10.")
        .SetDefault(10);
    AddAttr<std::vector<int>>("custom_neg_classes",
                              "This attribute is only used in unit tests. "
                              "Classes in this list will be used as negative "
                              "classes for every sample. Under normal "
                              "conditions, users should avoid setting this "
                              "attribute.");
    AddComment(R"DOC(
Compute and return the noise-contrastive estimation training loss.
See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
By default this operator uses a uniform distribution for sampling.
)DOC");
  }
};

class NCEOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"));
    PADDLE_ENFORCE(ctx->HasInput("Weight"));
    PADDLE_ENFORCE(ctx->HasInput("Cost"));
    PADDLE_ENFORCE(ctx->HasInput("SampleLogits"));
    PADDLE_ENFORCE(ctx->HasInput("SampleLabels"));
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")),
                   "The input(Cost@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("Input");
    auto x_grad_name = framework::GradVarName("Input");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }

    auto w_dims = ctx->GetInputDim("Weight");
    auto w_grad_name = framework::GradVarName("Weight");
    if (ctx->HasOutput(w_grad_name)) {
      ctx->SetOutputDim(w_grad_name, w_dims);
    }

    auto bias_grad_name = framework::GradVarName("Bias");
    if (ctx->HasOutput(bias_grad_name)) {
      auto bias_dims = ctx->GetInputDim("Bias");
      ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
  }

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
        ctx.device_context());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(nce, ops::NCEOp, ops::NCEOpMaker, nce_grad, ops::NCEOpGrad);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,
                       ops::NCEGradKernel<paddle::platform::CPUPlace, float>,
                       ops::NCEGradKernel<paddle::platform::CPUPlace, double>);
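For reference, the cost the forward kernel computes for one example can be sketched in a few lines of NumPy, assuming uniform negative sampling; the function and argument names below are illustrative only and are not part of the operator's interface:

import numpy as np

def nce_cost_per_example(sample_logits, num_true_class, num_neg_samples,
                         num_total_classes, sample_weight=1.0):
    # sample_logits: 1-D array of sigmoid outputs for one example, laid out
    # as [true classes..., sampled negative classes...].
    b = float(num_neg_samples) / num_total_classes  # noise prior per class
    cost = 0.0
    for j, o in enumerate(sample_logits):
        if j < num_true_class:
            cost += -np.log(o / (o + b))  # true-class term
        else:
            cost += -np.log(b / (o + b))  # sampled-negative term
    return sample_weight * cost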
@@ -0,0 +1,211 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>
#include <random>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
void PrepareSamples(const framework::ExecutionContext& context) {
  auto label = context.Input<Tensor>("Label");
  const int64_t* label_data = label->data<int64_t>();
  auto label_dims = label->dims();
  int num_total_classes = context.Attr<int>("num_total_classes");
  // only used in unit tests
  std::vector<int> custom_neg_classes =
      context.Attr<std::vector<int>>("custom_neg_classes");
  // random number engine for uniform sampling
  std::random_device rd;
  std::mt19937 rng(rd());
  std::uniform_int_distribution<int> rand(0, num_total_classes - 1);

  auto sample_labels = context.Output<Tensor>("SampleLabels");
  auto sample_labels_dims = sample_labels->dims();
  int64_t* sample_labels_data =
      sample_labels->mutable_data<int64_t>(context.GetPlace());

  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
  int index = 0;
  for (size_t i = 0; i < label_dims[0]; ++i) {
    int j = 0;
    for (; j < num_label; ++j) {
      sample_labels_data[index++] = label_data[i * num_label + j];
    }
    if (custom_neg_classes.size() > 0) {
      for (auto label : custom_neg_classes) {
        sample_labels_data[index++] = label;
      }
    } else {
      for (; j < sample_labels_dims[1]; ++j) {
        // TODO(wanghaoshuang): support more distribution sampling
        sample_labels_data[index++] = rand(rng);
      }
    }
  }
}

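// Each row of SampleLabels produced above holds the example's true labels
// first (num_label entries), followed by the sampled negative classes, so a
// row has num_true_class + num_neg_samples entries. The kernels below rely on
// this layout.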
template <typename Place, typename T>
class NCEKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PrepareSamples<Place, T>(context);
    auto sample_labels = context.Output<Tensor>("SampleLabels");
    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
    auto sample_out = context.Output<Tensor>("SampleLogits");
    T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
    auto label = context.Input<Tensor>("Label");
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T* sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
    auto out = context.Output<Tensor>("Cost");
    T* out_data = out->mutable_data<T>(context.GetPlace());
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
    T b = 1. / num_total_classes * num_neg_samples;
    // forward bias
    auto bias = context.Input<Tensor>("Bias");
    if (bias != nullptr) {
      const T* bias_data = bias->data<T>();
      for (size_t i = 0; i < sample_labels->numel(); ++i) {
        sample_out_data[i] = bias_data[sample_labels_data[i]];
      }
    } else {
      for (size_t i = 0; i < sample_labels->numel(); ++i) {
        sample_out_data[i] = 0;
      }
    }
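    // Each sampled logit i pairs input row (i / samples_per_example) with the
    // weight row indexed by sample_labels_data[i], where samples_per_example
    // is sample_labels->dims()[1] == num_true_class + num_neg_samples.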
    // forward mul
    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
    for (size_t i = 0; i < sample_labels->numel(); ++i) {
      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
          (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) *
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
      sample_out_data[i] += result(0);
      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
    }
    // forward cost
    for (size_t i = 0; i < sample_labels->dims()[0]; ++i) {
      size_t j = 0;
      out_data[i] = 0;
      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
      // for true classes
      for (; j < num_true_class; ++j) {
        T o = sample_out_data[i * sample_out->dims()[1] + j];
        T cost = -log(o / (o + b));
        out_data[i] += w * cost;
      }
      // for sampled neg classes
      for (; j < sample_labels->dims()[1]; ++j) {
        T o = sample_out_data[i * sample_out->dims()[1] + j];
        T cost = -log(b / (o + b));
        out_data[i] += w * cost;
      }
    }
  }
};

template <typename Place, typename T>
class NCEGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
    const T* d_out_data = d_out->data<T>();
    auto label = context.Input<Tensor>("Label");
    auto sample_out = context.Input<Tensor>("SampleLogits");
    const T* sample_out_data = sample_out->data<T>();
    auto sample_labels = context.Input<Tensor>("SampleLabels");
    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
    auto sample_weight = context.Input<Tensor>("SampleWeight");
    const T* sample_weight_data = nullptr;
    if (sample_weight != nullptr) {
      sample_weight_data = sample_weight->data<T>();
    }
    int num_neg_samples = context.Attr<int>("num_neg_samples");
    int num_total_classes = context.Attr<int>("num_total_classes");
    int num_true_class = 1;
    if (label != nullptr) {
      num_true_class = label->dims()[1];
    }
    T b = 1. / num_total_classes * num_neg_samples;
    Tensor sample_grad;  // tmp tensor
    T* sample_grad_data =
        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
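    // Gradient of the per-sample cost w.r.t. the pre-sigmoid logit, with
    // o = sigmoid(logit) and b = num_neg_samples / num_total_classes:
    //   true class:    d(-log(o / (o + b)))/d(logit) = (b / (o + b)) * (o - 1)
    //   sampled class: d(-log(b / (o + b)))/d(logit) = o * (1 - o) / (o + b)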
    // backward cost
    for (size_t i = 0; i < sample_labels->numel(); ++i) {
      T o = sample_out_data[i];
      T w = sample_weight == nullptr
                ? 1
                : sample_weight_data[i / sample_labels->dims()[1]];
      sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
                                ? w * (b / (o + b)) * (o - 1)
                                : w * (o * (1 - o) / (o + b));
      sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]];
    }
    // get d_bias
    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
    if (d_bias != nullptr) {
      T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
      for (size_t i = 0; i < sample_labels->numel(); ++i) {
        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
      }
    }
    // get d_w
    auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
    if (d_w != nullptr) {
      auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
      std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
      auto d_w_matrix = EigenMatrix<T>::From(*d_w);
      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
      for (size_t i = 0; i < sample_labels->numel(); ++i) {
        d_w_matrix.chip(sample_labels_data[i], 0) +=
            x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) *
            sample_grad_data[i];
      }
    }
    // get d_x
    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
    if (d_x != nullptr) {
      d_x->mutable_data<T>(context.GetPlace());
      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
      for (size_t i = 0; i < sample_labels->numel(); ++i) {
        d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) +=
            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,98 @@
import unittest
import numpy as np
from op_test import OpTest


def nce(input, weight, bias, sample_weight, labels, num_classes,
        num_sample_class):
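    """Reference NumPy implementation of the nce op forward pass.

    Negative classes are taken to be 0 .. num_sample_class - 1, which matches
    the 'custom_neg_classes' attribute set by the tests below, so the result
    is deterministic and directly comparable with the operator output.
    Returns a tuple of (Cost, SampleLogits, SampleLabels).
    """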
    samples = []
    sample_labels = []
    batch_size = input.shape[0]
    num_true_class = labels.shape[1]
    for i in range(batch_size):
        w = 1 if sample_weight is None else sample_weight[i]
        for label in labels[i]:
            samples.append((i, label, True, w))
            sample_labels.append(label)
        for num in range(num_sample_class):
            samples.append((i, num, False, w))
            sample_labels.append(num)
    # forward bias
    sample_out = np.zeros(len(samples)).astype(np.float32)
    if bias is not None:
        for i in range(len(samples)):
            sample_out[i] = bias[samples[i][1]]
    # forward weight
    for i in range(len(samples)):
        sample_out[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])

    # forward activation
    sample_out = 1.0 / (1.0 + np.exp(-sample_out))
    # forward cost
    out = np.zeros(batch_size).astype(np.float32)
    b = 1.0 / num_classes * num_sample_class
    for i in range(len(samples)):
        o = sample_out[i]
        cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
        out[samples[i][0]] += cost * samples[i][3]
    return (out[:, np.newaxis], np.array(sample_out).reshape(
        batch_size, num_sample_class + num_true_class),
            np.array(sample_labels).reshape(batch_size,
                                            num_sample_class + num_true_class))


class TestNCE(OpTest):
    def generate_data(self, dim, batch_size, num_classes, num_true_class,
                      num_neg_samples):
        input = np.random.randn(batch_size, dim).astype(np.float32)
        weight = np.random.randn(num_classes, dim).astype(np.float32)
        bias = np.random.randn(num_classes).astype(np.float32)
        sample_weight = np.random.randn(batch_size).astype(np.float32)
        labels = np.random.randint(0, num_classes, (batch_size, num_true_class))
        self.attrs = {
            'num_total_classes': num_classes,
            'num_neg_samples': num_neg_samples,
            'custom_neg_classes': range(num_neg_samples)
        }
        self.inputs = {
            'Input': input,
            'Label': labels,
            'Weight': weight,
            'Bias': bias,
            'SampleWeight': sample_weight
        }

    def set_data(self):
        self.generate_data(5, 5, 4, 1, 2)

    def compute(self):
        out = nce(self.inputs['Input'], self.inputs['Weight'],
                  self.inputs['Bias'], self.inputs['SampleWeight'],
                  self.inputs['Label'], self.attrs['num_total_classes'],
                  self.attrs['num_neg_samples'])
        self.outputs = {
            'Cost': out[0],
            'SampleLogits': out[1],
            'SampleLabels': out[2]
        }

    def setUp(self):
        self.op_type = 'nce'
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)


class TestNCECase1(TestNCE):
    def set_data(self):
        self.generate_data(10, 20, 10, 2, 5)


if __name__ == '__main__':
    unittest.main()