commit
23ac845921
@ -0,0 +1,117 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "Layer.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
class KmaxSeqScoreLayer : public Layer {
|
||||
private:
|
||||
MatrixPtr scores_;
|
||||
size_t beamSize_;
|
||||
void kmaxScorePerSeq(const real* score,
|
||||
real* sortedRes,
|
||||
const ICpuGpuVectorPtr seqStartPos);
|
||||
|
||||
public:
|
||||
explicit KmaxSeqScoreLayer(const LayerConfig& config) : Layer(config) {}
|
||||
|
||||
bool init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) override;
|
||||
|
||||
void forward(PassType passType) override;
|
||||
void backward(const UpdateCallback& callback = nullptr) override;
|
||||
};
|
||||
|
||||
REGISTER_LAYER(kmax_seq_score, KmaxSeqScoreLayer);
|
||||
|
||||
bool KmaxSeqScoreLayer::init(const LayerMap& layerMap,
|
||||
const ParameterMap& parameterMap) {
|
||||
bool ret = Layer::init(layerMap, parameterMap);
|
||||
CHECK_EQ(1U, inputLayers_.size());
|
||||
|
||||
beamSize_ = config_.beam_size();
|
||||
CHECK_GE(beamSize_, 1U);
|
||||
|
||||
setNeedSequenceInfo(false);
|
||||
setNeedGradient(false);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void KmaxSeqScoreLayer::kmaxScorePerSeq(const real* scores,
|
||||
real* sortedIds,
|
||||
const ICpuGpuVectorPtr seqStartPos) {
|
||||
int* starts = seqStartPos->getMutableData(false);
|
||||
std::vector<real> indices;
|
||||
for (size_t i = 0; i < seqStartPos->getSize() - 1; ++i) {
|
||||
int seqLen = starts[i + 1] - starts[i];
|
||||
int k = std::min(static_cast<int>(beamSize_), seqLen);
|
||||
|
||||
indices.resize(seqLen, 0);
|
||||
std::iota(begin(indices), end(indices), 0.);
|
||||
std::vector<real> tmpScore(scores + starts[i], scores + starts[i + 1]);
|
||||
std::partial_sort(
|
||||
begin(indices),
|
||||
begin(indices) + k,
|
||||
end(indices),
|
||||
[&](size_t a, size_t b) { return tmpScore[a] > tmpScore[b]; });
|
||||
memcpy(sortedIds + (i * beamSize_), indices.data(), k * sizeof(real));
|
||||
}
|
||||
}
|
||||
|
||||
void KmaxSeqScoreLayer::forward(PassType passType) {
|
||||
Layer::forward(passType);
|
||||
|
||||
const Argument& input = getInput(0);
|
||||
const MatrixPtr inputScore = getInputValue(0);
|
||||
|
||||
CHECK(input.hasSeq() || input.hasSubseq())
|
||||
<< "input of " << getName()
|
||||
<< " must be a sequence or a nested sequence.";
|
||||
CHECK_EQ(input.value->getWidth(), 1UL)
|
||||
<< "input of " << getName()
|
||||
<< " is score over a sequence or a nested sequence, so its width "
|
||||
<< " must be 1.";
|
||||
|
||||
if (useGpu_) {
|
||||
// this Layer runs only in CPU, if the model is runing on GPU,
|
||||
// then copy the input to this layer from GPU to CPU.
|
||||
Matrix::resizeOrCreate(scores_,
|
||||
inputScore->getHeight(),
|
||||
1,
|
||||
false /* trans */,
|
||||
false /* useGpu */);
|
||||
scores_->copyFrom(*inputScore);
|
||||
} else {
|
||||
scores_ = inputScore;
|
||||
}
|
||||
|
||||
Matrix::resizeOrCreate(
|
||||
output_.value,
|
||||
input.hasSubseq() ? input.getNumSubSequences() : input.getNumSequences(),
|
||||
beamSize_,
|
||||
false,
|
||||
false);
|
||||
output_.value->one();
|
||||
output_.value->mulScalar(-1.);
|
||||
|
||||
kmaxScorePerSeq(scores_->getData(),
|
||||
output_.value->getData(),
|
||||
input.hasSubseq() ? input.subSequenceStartPositions
|
||||
: input.sequenceStartPositions);
|
||||
}
|
||||
|
||||
void KmaxSeqScoreLayer::backward(const UpdateCallback& callback) {}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,160 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ModelConfig.pb.h"
|
||||
#include "paddle/gserver/layers/DataLayer.h"
|
||||
#include "paddle/trainer/Trainer.h"
|
||||
#include "paddle/utils/GlobalConstants.h"
|
||||
|
||||
#include "LayerGradUtil.h"
|
||||
#include "paddle/testing/TestUtil.h"
|
||||
|
||||
using namespace paddle; // NOLINT
|
||||
using namespace std; // NOLINT
|
||||
|
||||
DECLARE_bool(use_gpu);
|
||||
DECLARE_int32(gpu_id);
|
||||
DECLARE_bool(thread_local_rand_use_global_seed);
|
||||
|
||||
vector<int> randSampling(int range, int n) {
|
||||
CHECK_GE(range, n);
|
||||
vector<int> num(range);
|
||||
iota(begin(num), end(num), 0);
|
||||
if (range == n) return num;
|
||||
|
||||
random_shuffle(begin(num), end(num));
|
||||
num.resize(n);
|
||||
return num;
|
||||
}
|
||||
|
||||
void genRandomSeqInfo(vector<int>& seqStartPosition,
|
||||
vector<int>& subSeqStartPosition) {
|
||||
const int maxSeqNum = 100;
|
||||
// generate random start position information
|
||||
int seqNum = 1 + (rand() % maxSeqNum);
|
||||
seqStartPosition.resize(seqNum + 1, 0);
|
||||
subSeqStartPosition.resize(1, 0);
|
||||
|
||||
for (int i = 0; i < seqNum; ++i) {
|
||||
int subSeqLen = 1 + (rand() % maxSeqNum);
|
||||
for (int j = 0; j < subSeqLen; ++j)
|
||||
subSeqStartPosition.push_back(subSeqStartPosition.back() + subSeqLen);
|
||||
seqStartPosition[i + 1] = subSeqStartPosition.back();
|
||||
}
|
||||
}
|
||||
|
||||
void genRandomGroundTruth(real* values,
|
||||
vector<vector<int>>& groundTruth,
|
||||
vector<int>& startPos,
|
||||
size_t beamSize) {
|
||||
groundTruth.resize(startPos.size() - 1, vector<int>(beamSize, -1));
|
||||
for (size_t i = 0; i < startPos.size() - 1; ++i) {
|
||||
int seqLen = startPos[i + 1] - startPos[i];
|
||||
vector<int> pos =
|
||||
randSampling(seqLen, min(static_cast<int>(beamSize), seqLen));
|
||||
for (size_t j = 0; j < pos.size(); ++j) {
|
||||
groundTruth[i][j] = pos[j];
|
||||
values[startPos[i] + pos[j]] = 1.;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void checkLayerOut(vector<vector<int>> groundTruth,
|
||||
real* layerOut,
|
||||
size_t beamSize) {
|
||||
for (size_t i = 0; i < groundTruth.size(); ++i) {
|
||||
int begPos = i * beamSize;
|
||||
vector<real> tmp(layerOut + begPos, layerOut + begPos + beamSize);
|
||||
sort(begin(tmp), end(tmp));
|
||||
sort(begin(groundTruth[i]), end(groundTruth[i]));
|
||||
for (size_t j = 0; j < beamSize; ++j) CHECK_EQ(tmp[j], groundTruth[i][j]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Layer, kmaxSeqScoreLayer) {
|
||||
const size_t maxBeamSize = 100;
|
||||
int beamSize = 1 + (rand() % maxBeamSize);
|
||||
|
||||
vector<int> seqStartPosition;
|
||||
vector<int> subSeqStartPosition;
|
||||
genRandomSeqInfo(seqStartPosition, subSeqStartPosition);
|
||||
MatrixPtr inValue =
|
||||
Matrix::create(subSeqStartPosition.back(), 1, false, false);
|
||||
|
||||
for (auto hasSubseq : {false, true}) {
|
||||
vector<vector<int>> groundTruth;
|
||||
inValue->randomizeUniform();
|
||||
genRandomGroundTruth(inValue->getData(),
|
||||
groundTruth,
|
||||
hasSubseq ? subSeqStartPosition : seqStartPosition,
|
||||
beamSize);
|
||||
|
||||
for (auto useGpu : {false, true}) {
|
||||
TestConfig config;
|
||||
config.layerConfig.set_type("kmax_seq_score");
|
||||
config.layerConfig.set_beam_size(beamSize);
|
||||
|
||||
if (hasSubseq) {
|
||||
config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA,
|
||||
"scores",
|
||||
inValue,
|
||||
seqStartPosition,
|
||||
subSeqStartPosition});
|
||||
} else {
|
||||
config.inputDefs.push_back(
|
||||
{INPUT_SELF_DEFINE_DATA, "scores", inValue, seqStartPosition});
|
||||
}
|
||||
config.layerConfig.add_inputs();
|
||||
|
||||
// data layer initialize
|
||||
std::vector<DataLayerPtr> dataLayers;
|
||||
LayerMap layerMap;
|
||||
vector<Argument> datas;
|
||||
initDataLayer(
|
||||
config,
|
||||
&dataLayers,
|
||||
&datas,
|
||||
&layerMap,
|
||||
"kmax_seq_score",
|
||||
100 /* actually this parameter is unused in self-defined input*/,
|
||||
false,
|
||||
useGpu);
|
||||
// test layer initialize
|
||||
std::vector<ParameterPtr> parameters;
|
||||
LayerPtr kmaxSeqScoreLayer;
|
||||
FLAGS_use_gpu = useGpu;
|
||||
initTestLayer(config, &layerMap, ¶meters, &kmaxSeqScoreLayer);
|
||||
kmaxSeqScoreLayer->forward(PASS_TRAIN);
|
||||
|
||||
const MatrixPtr outValue = kmaxSeqScoreLayer->getOutputValue();
|
||||
CHECK_EQ(outValue->getHeight(),
|
||||
hasSubseq ? subSeqStartPosition.size() - 1
|
||||
: seqStartPosition.size() - 1);
|
||||
CHECK_EQ(outValue->getWidth(), beamSize);
|
||||
checkLayerOut(groundTruth, outValue->getData(), beamSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
initMain(argc, argv);
|
||||
FLAGS_thread_local_rand_use_global_seed = true;
|
||||
srand((size_t)(time(NULL)));
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -0,0 +1,84 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
|
||||
// Use std::random and thrust::random(thrust is a std library in CUDA) to
|
||||
// implement uniform random.
|
||||
template <typename T>
|
||||
class CPUUniformRandomKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* tensor = context.Output<framework::Tensor>(0);
|
||||
T* data = tensor->mutable_data<T>(context.GetPlace());
|
||||
unsigned int seed =
|
||||
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
|
||||
std::minstd_rand engine;
|
||||
if (seed == 0) {
|
||||
seed = std::random_device()();
|
||||
}
|
||||
engine.seed(seed);
|
||||
std::uniform_real_distribution<T> dist(
|
||||
static_cast<T>(context.op_.GetAttr<float>("min")),
|
||||
static_cast<T>(context.op_.GetAttr<float>("max")));
|
||||
for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) {
|
||||
data[i] = dist(engine);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class UniformRandomOp : public framework::OperatorWithKernel {
|
||||
protected:
|
||||
void InferShape(const framework::InferShapeContext& ctx) const override {
|
||||
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
|
||||
"uniform_random's min must less then max");
|
||||
auto* tensor = ctx.Output<framework::Tensor>(0);
|
||||
auto dims = GetAttr<std::vector<int>>("dims");
|
||||
tensor->Resize(framework::make_ddim(dims));
|
||||
}
|
||||
};
|
||||
|
||||
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
UniformRandomOpMaker(framework::OpProto* proto,
|
||||
framework::OpAttrChecker* op_checker)
|
||||
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddOutput("Out", "The output tensor of uniform random op");
|
||||
AddComment(R"DOC(Uniform random operator.
|
||||
|
||||
Used to initialize tensor with uniform random generator.
|
||||
)DOC");
|
||||
AddAttr<std::vector<int>>("dims", "the dimension of random tensor");
|
||||
AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f);
|
||||
AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f);
|
||||
AddAttr<int>("seed",
|
||||
"Random seed of uniform random. "
|
||||
"0 means generate a seed by system")
|
||||
.SetDefault(0);
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp,
|
||||
paddle::operators::UniformRandomOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(uniform_random,
|
||||
paddle::operators::CPUUniformRandomKernel<float>);
|
@ -0,0 +1,70 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/random.h>
|
||||
#include <thrust/transform.h>
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
struct UniformGenerator {
|
||||
T min_, max_;
|
||||
unsigned int seed_;
|
||||
|
||||
__host__ __device__ UniformGenerator(T min, T max, int seed)
|
||||
: min_(min), max_(max), seed_(seed) {}
|
||||
|
||||
__host__ __device__ T operator()(const unsigned int n) const {
|
||||
thrust::minstd_rand rng;
|
||||
rng.seed(seed_);
|
||||
thrust::uniform_real_distribution<T> dist(min_, max_);
|
||||
rng.discard(n);
|
||||
return dist(rng);
|
||||
}
|
||||
};
|
||||
|
||||
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
|
||||
// Use std::random and thrust::random(thrust is a std library in CUDA) to
|
||||
// implement uniform random.
|
||||
template <typename T>
|
||||
class GPUUniformRandomKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* tensor = context.Output<framework::Tensor>(0);
|
||||
T* data = tensor->mutable_data<T>(context.GetPlace());
|
||||
unsigned int seed =
|
||||
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
|
||||
if (seed == 0) {
|
||||
seed = std::random_device()();
|
||||
}
|
||||
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
|
||||
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
|
||||
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
|
||||
ssize_t N = framework::product(tensor->dims());
|
||||
thrust::transform(index_sequence_begin, index_sequence_begin + N,
|
||||
thrust::device_ptr<T>(data),
|
||||
UniformGenerator<T>(min, max, seed));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OP_GPU_KERNEL(uniform_random,
|
||||
paddle::operators::GPUUniformRandomKernel<float>);
|
@ -0,0 +1,66 @@
|
||||
type: "nn"
|
||||
layers {
|
||||
name: "input"
|
||||
type: "data"
|
||||
size: 300
|
||||
active_type: ""
|
||||
}
|
||||
layers {
|
||||
name: "data"
|
||||
type: "data"
|
||||
size: 128
|
||||
active_type: ""
|
||||
}
|
||||
layers {
|
||||
name: "__fc_layer_0__"
|
||||
type: "fc"
|
||||
size: 1
|
||||
active_type: "exponential"
|
||||
inputs {
|
||||
input_layer_name: "data"
|
||||
input_parameter_name: "___fc_layer_0__.w0"
|
||||
}
|
||||
bias_parameter_name: "___fc_layer_0__.wbias"
|
||||
}
|
||||
layers {
|
||||
name: "__kmax_sequence_score_layer_0__"
|
||||
type: "kmax_seq_score"
|
||||
active_type: ""
|
||||
inputs {
|
||||
input_layer_name: "__fc_layer_0__"
|
||||
}
|
||||
beam_size: 5
|
||||
}
|
||||
parameters {
|
||||
name: "___fc_layer_0__.w0"
|
||||
size: 128
|
||||
initial_mean: 0.0
|
||||
initial_std: 0.0883883476483
|
||||
dims: 128
|
||||
dims: 1
|
||||
initial_strategy: 0
|
||||
initial_smart: true
|
||||
}
|
||||
parameters {
|
||||
name: "___fc_layer_0__.wbias"
|
||||
size: 1
|
||||
initial_mean: 0.0
|
||||
initial_std: 0.0
|
||||
dims: 1
|
||||
dims: 1
|
||||
initial_strategy: 0
|
||||
initial_smart: false
|
||||
}
|
||||
input_layer_names: "data"
|
||||
output_layer_names: "__kmax_sequence_score_layer_0__"
|
||||
sub_models {
|
||||
name: "root"
|
||||
layer_names: "input"
|
||||
layer_names: "data"
|
||||
layer_names: "__fc_layer_0__"
|
||||
layer_names: "__kmax_sequence_score_layer_0__"
|
||||
input_layer_names: "data"
|
||||
output_layer_names: "__kmax_sequence_score_layer_0__"
|
||||
is_recurrent_layer_group: false
|
||||
}
|
||||
|
@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
#coding=utf-8
|
||||
from paddle.trainer_config_helpers import *
|
||||
|
||||
data = data_layer(name='input', size=300)
|
||||
|
||||
data = data_layer(name="data", size=128)
|
||||
scores = fc_layer(input=data, size=1, act=ExpActivation())
|
||||
kmax_seq_id = kmax_sequence_score_layer(input=scores, beam_size=5)
|
||||
|
||||
outputs(kmax_seq_id)
|
@ -0,0 +1,35 @@
|
||||
import unittest
|
||||
from paddle.v2.framework.op import Operator
|
||||
import paddle.v2.framework.core as core
|
||||
import numpy
|
||||
|
||||
|
||||
class UniformRandomTest(unittest.TestCase):
|
||||
def test_uniform_random_cpu(self):
|
||||
self.uniform_random_test(place=core.CPUPlace())
|
||||
|
||||
def test_uniform_random_gpu(self):
|
||||
if core.is_compile_gpu():
|
||||
self.uniform_random_test(place=core.GPUPlace(0))
|
||||
|
||||
def uniform_random_test(self, place):
|
||||
scope = core.Scope()
|
||||
scope.new_var("X").get_tensor()
|
||||
|
||||
op = Operator(
|
||||
"uniform_random",
|
||||
Out="X",
|
||||
dims=[1000, 784],
|
||||
min=-5.0,
|
||||
max=10.0,
|
||||
seed=10)
|
||||
|
||||
op.infer_shape(scope)
|
||||
ctx = core.DeviceContext.create(place)
|
||||
op.run(scope, ctx)
|
||||
tensor = numpy.array(scope.find_var("X").get_tensor())
|
||||
self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue