commit
46ccfc0171
@@ -0,0 +1,225 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "RowConvOp.h"
#include <iostream>
#include "paddle/math/Vector.h"

namespace paddle {

template <>
void RowConv<DEVICE_TYPE_CPU>(CpuMatrix& out,
                              const CpuMatrix& in,
                              const CpuMatrix& filter,
                              const CpuIVector& seq) {
  const int* starts = seq.getData();
  const size_t numSeq = seq.getSize() - 1;
  const size_t contextLength = filter.getHeight();
  for (size_t i = 0; i < numSeq; ++i) {
    size_t begin = starts[i];
    size_t end = starts[i + 1];
    for (size_t j = begin; j < end; ++j) {
      MatrixPtr x;
      MatrixPtr w;
      if ((j + contextLength) < end) {
        x = (const_cast<CpuMatrix&>(in)).subMatrix(j, contextLength);
        w = (const_cast<CpuMatrix&>(filter)).subMatrix(0, contextLength);
      } else {
        x = (const_cast<CpuMatrix&>(in)).subMatrix(j, end - j);
        w = (const_cast<CpuMatrix&>(filter)).subMatrix(0, end - j);
      }
      MatrixPtr y = out.subMatrix(j, 1);
      y->addDotMulVMM(*x, *w);
    }
  }
}

template <>
void RowConvGrad<DEVICE_TYPE_CPU>(const CpuMatrix& outG,
                                  const CpuMatrix& in,
                                  const CpuMatrix& filter,
                                  CpuMatrix& inG,
                                  CpuMatrix& filterG,
                                  const CpuIVector& seq) {
  // gradient w.r.t filter
  const int* starts = seq.getData();
  const size_t numSeq = seq.getSize() - 1;
  const size_t contextLength = filter.getHeight();
  if (filterG) {
    for (size_t i = 0; i < numSeq; ++i) {
      size_t begin = starts[i];
      size_t end = starts[i + 1];
      size_t steps = end - begin;
      for (size_t j = 0; j < contextLength && (begin + j) < end; ++j) {
        MatrixPtr x =
            (const_cast<CpuMatrix&>(in)).subMatrix(begin + j, steps - j);
        MatrixPtr dy =
            (const_cast<CpuMatrix&>(outG)).subMatrix(begin, steps - j);
        MatrixPtr dw = filterG.subMatrix(j, 1);
        dw->addDotMulVMM(*dy, *x);
      }
    }
  }

  // gradient w.r.t input feature
  if (inG) {
    for (size_t i = 0; i < numSeq; ++i) {
      size_t begin = starts[i];
      size_t end = starts[i + 1];
      size_t steps = end - begin;
      for (size_t j = 0; j < steps; ++j) {
        MatrixPtr dx = inG.subMatrix(begin + j, 1);
        for (size_t t = 0; t < contextLength; ++t) {
          if (int(j - t) >= 0) {
            MatrixPtr dy =
                (const_cast<CpuMatrix&>(outG)).subMatrix(begin + j - t, 1);
            MatrixPtr w = (const_cast<CpuMatrix&>(filter)).subMatrix(t, 1);
            dx->addDotMul(*dy, *w, 1.0, 1.0);
          }
        }
      }
    }
  }
}

/**
 * \brief The row convolution is also called lookahead convolution. It was
 * first introduced in the Deep Speech 2 system. A bidirectional RNN learns
 * a representation for a sequence by performing a forward and a backward
 * pass through the entire sequence. However, unlike unidirectional RNNs,
 * bidirectional RNNs are challenging to deploy in an online, low-latency
 * setting. The lookahead convolution incorporates information from future
 * subsequences in a computationally efficient manner to improve
 * unidirectional recurrent neural networks.
 *
 * The connectivity of the row convolution differs from 1D sequence
 * convolution. Suppose the future context length is k; the output at time
 * step t is computed from the input features at time steps t through (t+k).
 * With input activations of hidden dimension d, the activation r(t, i) of
 * the new layer at time step t is:
 *
 *             -- k + 1
 *   r(t,i) =  >        W(j, i) * h(t+j-1, i),   for (1 <= i <= d)
 *             -- j = 1
 *
 * The weight shape is: (k + 1) x d
 *
 * Function Arguments:
 *
 * \param inputs[0]  The input activations.
 * \param inputs[1]  The filter (or weight), with shape (k+1) x d.
 * \param outputs[0] The output activations.
 *
 * [1] Dario Amodei et al. Deep Speech 2: End-to-End Speech Recognition in
 *     English and Mandarin. https://arxiv.org/abs/1512.02595
 */

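// ---------------------------------------------------------------------------
// Editor's illustrative sketch -- not part of this commit. A minimal,
// self-contained version of the row convolution above on plain row-major
// float buffers, so the per-element formula is concrete without Paddle's
// CpuMatrix / addDotMulVMM machinery. All names here (rowConvRef, ctx, ...)
// are hypothetical.
#include <algorithm>
#include <cstddef>
#include <vector>

inline void rowConvRef(std::vector<float>& out,           // h x d, accumulated
                       const std::vector<float>& in,      // h x d
                       const std::vector<float>& filter,  // k x d
                       const std::vector<int>& seq,       // numSeq + 1 starts
                       size_t d,
                       size_t k) {
  for (size_t i = 0; i + 1 < seq.size(); ++i) {
    size_t begin = seq[i], end = seq[i + 1];
    for (size_t t = begin; t < end; ++t) {
      // The context is truncated near the end of each sequence, exactly as
      // the (end - j) branch of the CPU kernel above does.
      size_t ctx = std::min(k, end - t);
      for (size_t c = 0; c < d; ++c) {
        float r = 0.f;
        for (size_t j = 0; j < ctx; ++j) {
          r += filter[j * d + c] * in[(t + j) * d + c];
        }
        out[t * d + c] += r;  // ADD_TO semantics, matching the Function.
      }
    }
  }
}
// ---------------------------------------------------------------------------
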
template <DeviceType Device>
class RowConvFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    // check
    CHECK_EQ(2UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    // TODO(qingqing): support ASSIGN_TO.
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
    CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
        << "SequenceArg required here.";
    const auto in = dynamic_cast<const SequenceArg&>(inputs[0]);
    auto out = dynamic_cast<const SequenceArg&>(outputs[0]);
    auto w = inputs[1];
    CHECK(in.data() && out.data() && in.getSequenceId().data());
    CHECK_EQ(in.shape().ndims(), 2UL);
    CHECK(in.shape() == out.shape());
    CHECK_EQ(w.shape()[1], in.shape()[1]);

    auto outMat = out.matrix<Device>();
    const auto inMat = in.matrix<Device>();
    const auto wMat = w.matrix<Device>();
    const auto seqId = in.getSequenceId().vector<int, Device>();

    RowConv<Device>(outMat, inMat, wMat, seqId);
  }
};

/**
 * \brief The backward of the row convolution function. This function
 * computes the gradient w.r.t the filter and the gradient w.r.t the input
 * activations (or data).
 *
 * Arguments in this Function:
 *
 * \param inputs[0]  The gradient w.r.t the output activations.
 * \param inputs[1]  The input activations.
 * \param inputs[2]  The filter (or weight), with shape (k+1) x d.
 * \param outputs[0] The gradient w.r.t the input activations.
 * \param outputs[1] The gradient w.r.t the filter.
 *
 * Abbreviation:
 * w.r.t: with respect to.
 */

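// ---------------------------------------------------------------------------
// Editor's illustrative sketch -- not part of this commit. The same gradients
// the CPU kernel above computes, written as one fused loop over plain
// row-major buffers: since out(t) = sum_j W(j) .* in(t+j), each outG(t)
// contributes to filterG(j) against in(t+j), and to inG(t+j) through W(j).
// Names are hypothetical; layout matches the forward sketch.
#include <algorithm>
#include <cstddef>
#include <vector>

inline void rowConvGradRef(const std::vector<float>& outG,    // h x d
                           const std::vector<float>& in,      // h x d
                           const std::vector<float>& filter,  // k x d
                           std::vector<float>& inG,           // h x d
                           std::vector<float>& filterG,       // k x d
                           const std::vector<int>& seq,
                           size_t d,
                           size_t k) {
  for (size_t i = 0; i + 1 < seq.size(); ++i) {
    size_t begin = seq[i], end = seq[i + 1];
    for (size_t t = begin; t < end; ++t) {
      size_t ctx = std::min(k, end - t);  // same truncation as the forward
      for (size_t j = 0; j < ctx; ++j) {
        for (size_t c = 0; c < d; ++c) {
          // dL/dW(j,c) accumulates outG(t,c) against the input it multiplied.
          filterG[j * d + c] += outG[t * d + c] * in[(t + j) * d + c];
          // dL/din(t+j,c) receives outG(t,c) weighted by W(j,c).
          inG[(t + j) * d + c] += outG[t * d + c] * filter[j * d + c];
        }
      }
    }
  }
}
// ---------------------------------------------------------------------------
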
template <DeviceType Device>
class RowConvGradFunc : public FunctionBase {
  // TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc
public:
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    // check
    CHECK_EQ(3UL, inputs.size());
    CHECK_EQ(2UL, outputs.size());
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
    CHECK_EQ(outputs[1].getArgType(), ADD_TO);
    CHECK(inputs[0].isSequenceArg() && inputs[1].isSequenceArg() &&
          outputs[0].isSequenceArg())
        << "SequenceArg required here.";

    const auto outGrad = dynamic_cast<const SequenceArg&>(inputs[0]);
    const auto in = dynamic_cast<const SequenceArg&>(inputs[1]);
    const auto w = inputs[2];
    auto inGrad = dynamic_cast<const SequenceArg&>(outputs[0]);
    auto wGrad = outputs[1];

    CHECK_EQ(in.shape().ndims(), 2UL);
    CHECK(in.shape() == inGrad.shape());
    CHECK(in.shape() == outGrad.shape());
    CHECK_EQ(wGrad.shape()[1], in.shape()[1]);

    const auto outGMat = outGrad.matrix<Device>();
    const auto inMat = in.matrix<Device>();
    const auto wMat = w.matrix<Device>();
    auto inGMat = inGrad.data()
                      ? inGrad.matrix<Device>()
                      : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
    auto wGMat = wGrad.data()
                     ? wGrad.matrix<Device>()
                     : typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
    const auto seqId = in.getSequenceId().vector<int, Device>();

    RowConvGrad<Device>(outGMat, inMat, wMat, inGMat, wGMat, seqId);
  }
};

REGISTER_TYPED_FUNC(RowConv, CPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, CPU, RowConvGradFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(RowConv, GPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, GPU, RowConvGradFunc);
#endif

} // namespace paddle
@@ -0,0 +1,56 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Function.h"

namespace paddle {

/**
 * \brief The forward of row convolution.
 *
 * \param[out] out     The output data, with shape h x d. h is the sum of
 *                     the time steps of all samples in one mini-batch.
 * \param[in]  in      The input data, with shape h x d.
 * \param[in]  filter  The filter, with shape k x d. The lookahead step
 *                     number plus one equals k.
 * \param[in]  seq     The sequence start positions.
 *
 */
template <DeviceType DType>
void RowConv(typename Tensor<real, DType>::Matrix& out,
             const typename Tensor<real, DType>::Matrix& in,
             const typename Tensor<real, DType>::Matrix& filter,
             const typename Tensor<int, DType>::Vector& seq);

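// Editor's note (illustrative, not from this commit): for two sequences of
// lengths 5 and 4 packed into one mini-batch, h = 9 and seq = {0, 5, 9};
// out and in are then 9 x d, and the k x d filter is shared by all sequences.
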
/**
 * \brief The backward of row convolution.
 *
 * \param[in]  outG     The gradient w.r.t output data.
 * \param[in]  in       The input data.
 * \param[in]  filter   The filter.
 * \param[out] inG      The gradient w.r.t input data.
 * \param[out] filterG  The gradient w.r.t filter.
 * \param[in]  seq      The sequence start positions.
 *
 */
template <DeviceType DType>
void RowConvGrad(const typename Tensor<real, DType>::Matrix& outG,
                 const typename Tensor<real, DType>::Matrix& in,
                 const typename Tensor<real, DType>::Matrix& filter,
                 typename Tensor<real, DType>::Matrix& inG,
                 typename Tensor<real, DType>::Matrix& filterG,
                 const typename Tensor<int, DType>::Vector& seq);
} // namespace paddle
File diff suppressed because it is too large
@@ -0,0 +1,62 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "FunctionTest.h"

namespace paddle {

void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
  FunctionCompare test("RowConv", FuncConfig());

  test.addSequence(SequenceIdArg(TensorShape{batchSize}));
  test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim}));

  test.addOutputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}),
                  ADD_TO);

  test.run();
}

void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) {
  FunctionCompare test("RowConvGrad", FuncConfig());

  test.addSequence(SequenceIdArg(TensorShape{batchSize}));
  test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
  test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
  test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim}));

  test.addOutputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}),
                  ADD_TO);
  test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim}),
                  ADD_TO);

  test.run();
}

TEST(RowConv, real) {
  for (size_t numSamples : {17, 129, 2020}) {
    for (size_t dim : {16, 512, 2560}) {
      for (size_t context : {3, 19, 65}) {
        VLOG(3) << " numSamples=" << numSamples << " dim=" << dim
                << " context length=" << context;
        testRowConvFw(numSamples, dim, context);
        testRowConvBw(numSamples, dim, context);
      }
    }
  }
}

} // namespace paddle
@@ -0,0 +1,106 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "RowConvLayer.h"
#include "paddle/utils/Stat.h"

namespace paddle {

REGISTER_LAYER(row_conv, RowConvLayer);

bool RowConvLayer::init(const LayerMap& layerMap,
                        const ParameterMap& parameterMap) {
  /* Initialize the basic parent class */
  Layer::init(layerMap, parameterMap);

  contexLength_ = config_.inputs(0).row_conv_conf().context_length();

  CHECK_EQ(inputLayers_.size(), 1UL);
  weight_.reset(new Weight(contexLength_, getSize(), parameters_[0]));
  createFunction(forward_, "RowConv", FuncConfig());
  createFunction(backward_, "RowConvGrad", FuncConfig());

  return true;
}

void RowConvLayer::forward(PassType passType) {
  Layer::forward(passType);
  MatrixPtr input = getInputValue(0);
  size_t height = input->getHeight();
  size_t width = input->getWidth();
  CHECK_EQ(width, getSize());
  resetOutput(height, width);

  const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_);
  MatrixPtr w = weight_->getW();
  wDims_ = TensorShape({w->getHeight(), w->getWidth()});

  MatrixPtr outV = getOutputValue();
  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(*getInputValue(0), *startPos);
  inputs.addArg(*w, wDims_);
  outputs.addArg(*getOutputValue(), *startPos, ADD_TO);

  {
    REGISTER_TIMER_INFO("RowConvForward", getName().c_str());
    forward_[0]->calc(inputs, outputs);
  }

  /* activation */ {
    REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str());
    forwardActivation();
  }
}

void RowConvLayer::backward(const UpdateCallback& callback) {
  /* Do derivation */ {
    REGISTER_TIMER_INFO("BpAvtTimer", getName().c_str());
    backwardActivation();
  }

  const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_);

  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(*getOutputGrad(), *startPos);
  inputs.addArg(*getInputValue(0), *startPos);
  inputs.addArg(*weight_->getW(), wDims_);

  MatrixPtr inGrad = getInputGrad(0);
  MatrixPtr wGrad = weight_->getWGrad();
  size_t h = getInputValue(0)->getHeight();
  size_t w = getInputValue(0)->getWidth();
  outputs.addArg(
      inGrad ? (*inGrad) : *(Matrix::create(nullptr, h, w, false, useGpu_)),
      *startPos,
      ADD_TO);
  outputs.addArg(
      wGrad ? (*wGrad)
            : *(Matrix::create(nullptr, contexLength_, w, false, useGpu_)),
      wDims_,
      ADD_TO);

  {
    REGISTER_TIMER_INFO("RowConvBackward", getName().c_str());
    backward_[0]->calc(inputs, outputs);
  }

  {
    REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
    weight_->getParameterPtr()->incUpdate(callback);
  }
}

} // namespace paddle
@@ -0,0 +1,44 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"

namespace paddle {

/**
 * \brief Row Convolution Layer.
 */
class RowConvLayer : public Layer {
public:
  explicit RowConvLayer(const LayerConfig& config) : Layer(config) {}

  ~RowConvLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;
  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;

protected:
  // Row convolution weight, contexLength_ x fan_out.
  // fan_out is the size of the output feature.
  std::unique_ptr<Weight> weight_;

  // The number of steps to look ahead plus one equals contexLength_.
  size_t contexLength_;
  TensorShape wDims_;
};
} // namespace paddle
@@ -0,0 +1,2 @@
cc_library(stringpiece SRCS stringpiece.cc)
cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags)
@@ -0,0 +1,141 @@
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "paddle/strings/stringpiece.h"

#include <string.h>

#include <algorithm>
#include <iosfwd>
#include <stdexcept>

namespace paddle {

StringPiece::StringPiece() : data_(NULL), size_(0) {}

StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) {
  if (d == NULL && n != 0)
    throw std::invalid_argument(
        "StringPiece requires len to be 0 for NULL data");
}

StringPiece::StringPiece(const char* s) : data_(s) {
  size_ = (s == NULL) ? 0 : strlen(s);
}

StringPiece::StringPiece(const std::string& s)
    : data_(s.data()), size_(s.size()) {}

char StringPiece::operator[](size_t n) const {
  if (n >= len())
    throw std::invalid_argument("index out of StringPiece length");
  return data_[n];
}

int Compare(StringPiece a, StringPiece b) {
  const size_t min_len = (a.len() < b.len()) ? a.len() : b.len();
  int r = memcmp(a.data(), b.data(), min_len);
  if (r == 0) {
    if (a.len() < b.len())
      return -1;
    else if (a.len() > b.len())
      return 1;
  }
  return r;
}

bool operator==(StringPiece x, StringPiece y) {
  return ((x.len() == y.len()) &&
          (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0));
}

bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }

bool operator<(StringPiece x, StringPiece y) { return Compare(x, y) < 0; }
bool operator>(StringPiece x, StringPiece y) { return Compare(x, y) > 0; }

bool operator<=(StringPiece x, StringPiece y) { return Compare(x, y) <= 0; }
bool operator>=(StringPiece x, StringPiece y) { return Compare(x, y) >= 0; }

bool HasPrefix(StringPiece s, StringPiece x) {
  return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0));
}

bool HasSuffix(StringPiece s, StringPiece x) {
  return ((s.len() >= x.len()) &&
          (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0));
}

StringPiece SkipPrefix(StringPiece s, size_t n) {
  if (n > s.len())
    throw std::invalid_argument("Skip distance larger than StringPiece length");
  return StringPiece(s.data() + n, s.len() - n);
}

StringPiece SkipSuffix(StringPiece s, size_t n) {
  if (n > s.len())
    throw std::invalid_argument("Skip distance larger than StringPiece length");
  return StringPiece(s.data(), s.len() - n);
}

StringPiece TrimPrefix(StringPiece s, StringPiece x) {
  return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s;
}

StringPiece TrimSuffix(StringPiece s, StringPiece x) {
  return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s;
}

bool Contains(StringPiece s, StringPiece sub) {
  return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end();
}

size_t Index(StringPiece s, StringPiece sub) {
  auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end());
  return e != s.end() ? e - s.data() : StringPiece::npos;
}

size_t Find(StringPiece s, char c, size_t pos) {
  if (pos >= s.len()) {
    return StringPiece::npos;
  }
  const char* result =
      reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos));
  return result != nullptr ? result - s.data() : StringPiece::npos;
}

size_t RFind(StringPiece s, char c, size_t pos) {
  if (s.len() == 0) return StringPiece::npos;
  for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data();
       p--) {
    if (*p == c) {
      return p - s.data();
    }
  }
  return StringPiece::npos;
}

StringPiece SubStr(StringPiece s, size_t pos, size_t n) {
  if (pos > s.len()) pos = s.len();
  if (n > s.len() - pos) n = s.len() - pos;
  return StringPiece(s.data() + pos, n);
}

std::ostream& operator<<(std::ostream& o, StringPiece piece) {
  return o << piece.ToString();
}

} // namespace paddle
@@ -0,0 +1,105 @@
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#pragma once

#include <ostream>
#include <string>

namespace paddle {

// StringPiece points into a std::string object but doesn't own the
// string. It is for efficient access to strings. Like Go's string
// type. Note that StringPiece doesn't mutate the underlying string,
// so it is thread-safe given that the underlying string doesn't
// change. Because StringPiece contains only a few data members and
// doesn't own/manage the string, it is cheap to construct
// StringPieces and pass them around.
class StringPiece {
public:
  static const size_t npos = static_cast<size_t>(-1);

  // We provide non-explicit singleton constructors so users can
  // pass in a "const char*" or a "string" wherever a "StringPiece"
  // is expected. These constructors ensure that if data_ is NULL,
  // size_ is 0.
  StringPiece();
  StringPiece(const char* d, size_t n);
  StringPiece(const char* d);
  StringPiece(const std::string& s);

  const char* data() const { return data_; }
  size_t len() const { return size_; }

  char operator[](size_t n) const;

  // StringPiece doesn't own the string, so both iterator and const
  // iterator are const char* indeed.
  typedef const char* const_iterator;
  typedef const char* iterator;
  iterator begin() const { return data_; }
  iterator end() const { return data_ + size_; }

  // Return a string that contains a copy of the referenced data.
  std::string ToString() const { return std::string(data_, size_); }

private:
  const char* data_;
  size_t size_;

  // Intentionally copyable
};

int Compare(StringPiece a, StringPiece b);

bool operator==(StringPiece x, StringPiece y);
bool operator!=(StringPiece x, StringPiece y);
bool operator<(StringPiece x, StringPiece y);
bool operator>(StringPiece x, StringPiece y);
bool operator<=(StringPiece x, StringPiece y);
bool operator>=(StringPiece x, StringPiece y);

bool HasPrefix(StringPiece s, StringPiece prefix);
bool HasSuffix(StringPiece s, StringPiece suffix);

StringPiece SkipPrefix(StringPiece s, size_t n);
StringPiece SkipSuffix(StringPiece s, size_t n);

// Skip the prefix (or suffix) if it matches the string.
StringPiece TrimPrefix(StringPiece s, StringPiece prefix);
StringPiece TrimSuffix(StringPiece s, StringPiece suffix);

// Returns whether s contains sub. Every s, except for the empty s,
// contains the empty sub.
bool Contains(StringPiece s, StringPiece sub);

// Returns the first occurrence of sub in s, or npos. If both s and
// sub are empty, it returns npos; otherwise, if only sub is empty, it
// returns 0.
size_t Index(StringPiece s, StringPiece sub);

// Returns the first occurrence of c in s[pos:end], or npos.
size_t Find(StringPiece s, char c, size_t pos);

// The search range is [0..pos] inclusive. If pos == npos, search everything.
size_t RFind(StringPiece s, char c, size_t pos);

StringPiece SubStr(StringPiece s, size_t pos, size_t n);

// Allow StringPiece to be logged.
std::ostream& operator<<(std::ostream& o, StringPiece piece);

} // namespace paddle
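
// ---------------------------------------------------------------------------
// Editor's illustrative usage sketch -- not part of this commit. It exercises
// only the API declared above; the expected values follow from the documented
// semantics and are assumptions, not captured test output.
#include <cassert>
#include <iostream>

#include "paddle/strings/stringpiece.h"

int main() {
  paddle::StringPiece s("paddle/strings/stringpiece.h");  // no copy is made

  assert(HasPrefix(s, "paddle/"));  // free functions are found via ADL
  assert(HasSuffix(s, ".h"));

  paddle::StringPiece rel = TrimPrefix(s, "paddle/");
  assert(Index(rel, "stringpiece") != paddle::StringPiece::npos);

  size_t slash = Find(rel, '/', 0);  // 7: first '/' in "strings/..."
  paddle::StringPiece dir = SubStr(rel, 0, slash);
  std::cout << dir << "\n";  // prints "strings"
  return 0;
}
// ---------------------------------------------------------------------------
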
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,36 @@
type: "nn"
layers {
  name: "input"
  type: "data"
  size: 300
  active_type: ""
}
layers {
  name: "__prelu_layer_0__"
  type: "prelu"
  size: 300
  active_type: ""
  inputs {
    input_layer_name: "input"
    input_parameter_name: "___prelu_layer_0__.w0"
  }
}
parameters {
  name: "___prelu_layer_0__.w0"
  size: 300
  initial_mean: 0.0
  initial_std: 0.057735026919
  initial_strategy: 0
  initial_smart: true
}
input_layer_names: "input"
output_layer_names: "__prelu_layer_0__"
sub_models {
  name: "root"
  layer_names: "input"
  layer_names: "__prelu_layer_0__"
  input_layer_names: "input"
  output_layer_names: "__prelu_layer_0__"
  is_recurrent_layer_group: false
}
Some files were not shown because too many files have changed in this diff