Merge pull request #1064 from hedaoyuan/buffer

Add BufferArg as the Function argument type and modify the Function prototype to remove the inouts argument.
avx_docs
hedaoyuan 8 years ago committed by GitHub
commit 7df67bae8d

@ -0,0 +1,31 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "BufferArg.h"
namespace paddle {
const SequenceArg& BufferArg::sequence() const {
// CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA);
return dynamic_cast<const SequenceArg&>(*this);
}
const SparseMatrixArg& BufferArg::sparse() const {
// CHECK_EQ(bufferType_, TENSOR_SPARSE);
return dynamic_cast<const SparseMatrixArg&>(*this);
}
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -0,0 +1,90 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "BufferArg.h"
#include <gtest/gtest.h>
#include "Function.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle {
TEST(BufferTest, BufferArg) {
TensorShape shape({8, 10});
CpuMemoryHandle memory(shape.getElements() *
sizeOfValuType(VALUE_TYPE_FLOAT));
BufferArg buffer(memory.getBuf(), VALUE_TYPE_FLOAT, shape);
EXPECT_EQ(buffer.data(), memory.getBuf());
}
TEST(BufferTest, SequenceIdArg) {
TensorShape shape({10});
CpuMemoryHandle memory(shape.getElements() *
sizeOfValuType(VALUE_TYPE_INT32));
SequenceIdArg buffer(memory.getBuf(), shape);
EXPECT_EQ(buffer.data(), memory.getBuf());
EXPECT_EQ(buffer.numSeqs(), 9);
}
TEST(BufferTest, asArgument) {
MatrixPtr matrix = Matrix::create(100, 200);
VectorPtr vector = Vector::create(100, false);
CpuSparseMatrix sparse(200, 300, 50);
// prepare arguments
BufferArgs argments;
argments.addArg(*matrix);
argments.addArg(*vector);
argments.addArg(sparse);
// function
auto function = [=](const BufferArgs& inputs) {
EXPECT_EQ(inputs.size(), 3);
// check inputs[0]
EXPECT_EQ(inputs[0].shape().ndims(), 2);
EXPECT_EQ(inputs[0].shape()[0], 100);
EXPECT_EQ(inputs[0].shape()[1], 200);
EXPECT_EQ(inputs[0].data(), matrix->getData());
EXPECT_EQ(inputs[0].matrix<DEVICE_TYPE_CPU>().getHeight(),
matrix->getHeight());
EXPECT_EQ(inputs[0].matrix<DEVICE_TYPE_CPU>().getWidth(),
matrix->getWidth());
EXPECT_EQ(inputs[0].matrix<DEVICE_TYPE_CPU>().getData(), matrix->getData());
// check inputs[1]
EXPECT_EQ(inputs[1].shape().ndims(), 1);
EXPECT_EQ(inputs[1].shape()[0], 100);
EXPECT_EQ(inputs[1].data(), vector->getData());
CpuVector inVector = inputs[1].vector<real, DEVICE_TYPE_CPU>();
EXPECT_EQ(inVector.getSize(), vector->getSize());
EXPECT_EQ(inVector.getData(), vector->getData());
// check inputs[2]
EXPECT_EQ(inputs[2].shape().ndims(), 2);
EXPECT_EQ(inputs[2].shape()[0], 200);
EXPECT_EQ(inputs[2].shape()[1], 300);
EXPECT_EQ(inputs[2].data(), sparse.getData());
// CHECK_EQ(inputs[2].sparse().nnz(), 50);
// CHECK_EQ(inputs[2].sparse().dataFormat(), SPARSE_CSR_FORMAT);
// CHECK_EQ(inputs[2].sparse().dataType(), SPARSE_FLOAT_VALUE);
EXPECT_EQ(inputs[2].sparse().getRowBuf(), sparse.getRows());
EXPECT_EQ(inputs[2].sparse().getColBuf(), sparse.getCols());
};
// call function
function(argments);
}
} // namespace paddle

@ -3,6 +3,7 @@ file(GLOB cpp_files . *Op.cpp)
list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp)
list(APPEND cpp_files BufferArg.cpp)
if(WITH_GPU)
file(GLOB cu_files . *OpGpu.cu)
@ -18,8 +19,12 @@ if(WITH_TESTING)
# TODO:
# file(GLOB test_files . *OpTest.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
add_simple_unittest(CrossMapNormalOpTest)
add_simple_unittest(ContextProjectionOpTest)
# add_simple_unittest(CrossMapNormalOpTest)
add_simple_unittest(TensorShapeTest)
add_simple_unittest(TensorTypeTest)
add_simple_unittest(BufferArgTest)
add_simple_unittest(FunctionTest)
# add_simple_unittest(ContextProjectionOpTest)
endif()
endif()

File diff suppressed because it is too large Load Diff

@ -31,14 +31,15 @@ namespace paddle {
* \param[in] is_padding whether padding 0 or not.
*
*/
template <DeviceType Device>
void ContextProjectionForward(typename MatrixT<Device>::type* output,
const typename MatrixT<Device>::type* input,
const typename MatrixT<Device>::type* weight,
const typename SequenceT<Device>::type& sequence,
size_t context_length,
int context_start,
size_t begin_pad);
template <DeviceType DType>
void ContextProjectionForward(
typename Tensor<real, DType>::Matrix& output,
const typename Tensor<real, DType>::Matrix& input,
const typename Tensor<real, DType>::Matrix& weight,
const typename Tensor<int, DType>::Vector& sequence,
size_t context_length,
int context_start,
size_t begin_pad);
/**
* \brief Context Projection Backward.
@ -53,30 +54,31 @@ void ContextProjectionForward(typename MatrixT<Device>::type* output,
* \param[in] is_padding whether padding 0 or not.
*
*/
template <DeviceType Device>
void ContextProjectionBackward(typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* in_grad,
typename MatrixT<Device>::type* w_grad,
const typename SequenceT<Device>::type& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad);
template <DeviceType DType>
void ContextProjectionBackward(
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& in_grad,
typename Tensor<real, DType>::Matrix& w_grad,
const typename Tensor<int, DType>::Vector& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad);
template <DeviceType Device>
template <DeviceType DType>
void ContextProjectionBackwardData(
typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* in_grad,
const typename SequenceT<Device>::type& sequence,
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& in_grad,
const typename Tensor<int, DType>::Vector& sequence,
size_t context_length,
int context_start);
template <DeviceType Device>
template <DeviceType DType>
void ContextProjectionBackwardWeight(
typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* w_grad,
const typename SequenceT<Device>::type& seq_vec,
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& w_grad,
const typename Tensor<int, DType>::Vector& seq_vec,
size_t context_length,
int context_start,
size_t total_pad,

@ -120,20 +120,19 @@ void hl_context_projection_forward(const real* input,
}
template <>
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix* output,
const GpuMatrix* input,
const GpuMatrix* weight,
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix& output,
const GpuMatrix& input,
const GpuMatrix& weight,
const GpuIVector& sequence,
size_t context_length,
int context_start,
size_t begin_pad) {
CHECK(input && output);
hl_context_projection_forward(input->getData(),
hl_context_projection_forward(input.getData(),
sequence.getData(),
weight ? weight->getData() : nullptr,
output->getData(),
weight ? weight.getData() : nullptr,
output.getData(),
sequence.getSize() - 1,
input->getWidth(),
input.getWidth(),
context_length,
context_start,
begin_pad);
@ -217,17 +216,16 @@ void hl_context_projection_backward_data(real* out_grad,
}
template <>
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
GpuMatrix* in_grad,
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
GpuMatrix& in_grad,
const GpuIVector& sequence,
size_t context_length,
int context_start) {
CHECK(in_grad && out_grad);
hl_context_projection_backward_data(out_grad->getData(),
hl_context_projection_backward_data(out_grad.getData(),
sequence.getData(),
in_grad->getData(),
in_grad.getData(),
sequence.getSize() - 1,
in_grad->getWidth(),
in_grad.getWidth(),
context_length,
context_start);
}
@ -348,19 +346,18 @@ void hl_context_projection_backward_weight(real* out_grad,
template <>
void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
GpuMatrix* out_grad,
GpuMatrix* w_grad,
GpuMatrix& out_grad,
GpuMatrix& w_grad,
const GpuIVector& seq_vec,
size_t context_length,
int context_start,
size_t total_pad,
size_t begin_pad) {
CHECK(out_grad && w_grad);
hl_context_projection_backward_weight(out_grad->getData(),
hl_context_projection_backward_weight(out_grad.getData(),
seq_vec.getData(),
w_grad->getData(),
w_grad.getData(),
seq_vec.getSize() - 1,
w_grad->getWidth(),
w_grad.getWidth(),
total_pad,
context_length,
context_start,
@ -368,16 +365,15 @@ void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
}
template <>
void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
GpuMatrix* in_grad,
GpuMatrix* w_grad,
void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
GpuMatrix& in_grad,
GpuMatrix& w_grad,
const GpuIVector& sequence,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad) {
CHECK(out_grad);
if (in_grad) {
ContextProjectionBackwardData<DEVICE_TYPE_GPU>(
out_grad,

@ -112,6 +112,8 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
}
/**
* \brief {o_0, o_1} = calc(i_0)
*
* \param inputs[0] input value.
* \param outputs[0] output value.
* \param outputs[1] denoms.
@ -125,27 +127,24 @@ public:
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(2, static_cast<int>(outputs.size()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
CrossMapNormal<Device>(outputs[0].getData(),
outputs[1].getData(),
inputs[0].getData(),
CrossMapNormal<Device>(outputs[0].data<real>(),
outputs[1].data<real>(),
inputs[0].data<real>(),
samples,
channels,
height,
@ -162,6 +161,8 @@ private:
};
/**
* \brief {o_0} = calc(i_0, i_1, i_2, i_3)
*
* \param inputs[0] input value.
* \param inputs[1] output value.
* \param inputs[2] output grad.
@ -177,31 +178,29 @@ public:
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(4, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
}
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
CrossMapNormalGrad<Device>(outputs[0].getData(),
inputs[0].getData(),
inputs[1].getData(),
inputs[2].getData(),
inputs[3].getData(),
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)4, inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape());
CHECK(inputs[0].shape() == inputs[2].shape());
CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape());
// TODO(hedaoyuan): need support ASSIGN_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
inputs[1].data<real>(),
inputs[2].data<real>(),
inputs[3].data<real>(),
samples,
channels,
height,

@ -76,6 +76,20 @@ FuncConfig& FuncConfig::set<bool>(const std::string& key, bool v) {
return *this;
}
void BufferArgs::addArg(const Matrix& arg,
const TensorShape& shape,
ArgType argType) {
args_.push_back(std::make_shared<BufferArg>(arg, shape, argType));
}
void BufferArgs::addArg(const CpuSparseMatrix& arg, ArgType argType) {
args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
}
void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) {
args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
}
ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
} // namespace paddle

@ -16,57 +16,17 @@ limitations under the License. */
#include <map>
#include <vector>
#include "BufferArg.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/ClassRegistrar.h"
namespace paddle {
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2,
};
template <DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
template <DeviceType Device>
struct SequenceT;
template <>
struct SequenceT<DEVICE_TYPE_CPU> {
using type = CpuIVector;
};
template <>
struct SequenceT<DEVICE_TYPE_GPU> {
using type = GpuIVector;
};
typedef std::vector<size_t> Dims;
class Tensor {
public:
Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
real* getData() const { return buf_; }
real* buf_;
Dims dims_;
};
typedef std::vector<Tensor> Arguments;
/**
* Function Configuration.
* The argument type of Function::init.
* Follow-up will consider moving this data structure to Proto inside.
*/
class FuncConfig {
public:
union value {
@ -86,15 +46,70 @@ protected:
std::map<std::string, value> valueMap_;
};
/**
* Argument type for Function::calc().
* A BufferArgs contains a set of BufferArg,
* because Function can have multiple inputs and outputs.
*/
class BufferArgs {
public:
BufferArgs() {}
size_t size() const { return args_.size(); }
// add argument into BufferArgs
// Tensor can be Matrix, Vector, IVector.
// For inputs, do not need argType.
// For outputs, the argType needs to be specified as ASSIGN_TO or ADD_TO.
template <typename Tensor>
void addArg(const Tensor& arg, ArgType argType = UNSPECIFIED) {
args_.push_back(std::make_shared<BufferArg>(arg, argType));
}
// Add arg into BufferArgs and reshape the arg.
//
// For example, arg represents an image buffer,
// but Matrix can only represent a two-dimensional Tensor.
// So need an extra argument to describe the shape of the image buffer.
void addArg(const Matrix& arg,
const TensorShape& shape,
ArgType argType = UNSPECIFIED);
void addArg(const CpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
void addArg(const GpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
// get argument
const BufferArg& operator[](size_t num) const {
CHECK_LT(num, args_.size());
return *args_[num];
}
private:
std::vector<BufferArgPtr> args_;
};
/**
* \brief Base class for Function.
* The basic Function implementation requires override init and calc interfaces.
*
* Function inputs are readonly, Function outputs have two modes: ASSIGN_TO
* and ADD_TO.
* If output.getArgType() == ASSIGN_TO, this is assign mode, and the calculation
* result of Function assigned to the output BufferArg.
* If output.getArgType() == ADD_TO, this is add mode, and the calculation
* result of Function need added to the output BufferArg.
*
* For example:
* ASSIGN_TO: output = Function(inputs)
* ADD_TO: output += Function(inputs)
* If Function has more than one output, each output can have different modes.
*/
class FunctionBase {
public:
virtual ~FunctionBase() {}
virtual void init(const FuncConfig& config) {}
virtual void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {}
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
static ClassRegistrar<FunctionBase> funcRegistrar_;
};

@ -0,0 +1,59 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
#include <gtest/gtest.h>
namespace paddle {
template <DeviceType DType>
void FunctionApi(typename Tensor<real, DType>::Matrix& output,
const typename Tensor<real, DType>::Matrix& input);
template <>
void FunctionApi<DEVICE_TYPE_CPU>(CpuMatrix& output, const CpuMatrix& input) {
EXPECT_EQ(output.getHeight(), 100);
EXPECT_EQ(output.getWidth(), 200);
}
template <>
void FunctionApi<DEVICE_TYPE_GPU>(GpuMatrix& output, const GpuMatrix& input) {
EXPECT_EQ(output.getHeight(), 10);
EXPECT_EQ(output.getWidth(), 20);
}
template <DeviceType DType>
void Function(const BufferArgs& arguments) {
const auto input = arguments[0].matrix<DType>();
auto output = arguments[1].matrix<DType>();
FunctionApi<DType>(output, input);
}
TEST(Function, BufferArgs) {
CpuMatrix cpuInput = CpuMatrix(100, 200);
CpuMatrix cpuOutput = CpuMatrix(100, 200);
BufferArgs cpuArgments;
cpuArgments.addArg(cpuInput);
cpuArgments.addArg(cpuOutput);
Function<DEVICE_TYPE_CPU>(cpuArgments);
GpuMatrix gpuInput = GpuMatrix(10, 20);
GpuMatrix gpuOutput = GpuMatrix(10, 20);
BufferArgs gpuArgments;
gpuArgments.addArg(gpuInput);
gpuArgments.addArg(gpuOutput);
Function<DEVICE_TYPE_GPU>(gpuArgments);
}
} // namespace paddle

@ -0,0 +1,97 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
namespace paddle {
/**
* TensorShape used to represent shape of normal tensor.
*/
class TensorShape {
public:
TensorShape() : ndims_(0), nelements_(0) { initDims(0); }
TensorShape(size_t ndims) : ndims_(ndims), nelements_(1) { initDims(ndims); };
TensorShape(std::initializer_list<size_t> dims) {
ndims_ = dims.size();
initDims(ndims_);
dims_.assign(dims);
numElements();
};
TensorShape(const TensorShape& t)
: ndims_(t.ndims_), nelements_(t.nelements_) {
initDims(ndims_);
dims_.assign(t.dims_.begin(), t.dims_.end());
};
// get the size of specified dimension
size_t operator[](size_t dim) const {
CHECK_GE(dim, (size_t)0);
CHECK_LT(dim, ndims_);
return dims_[dim];
}
// set the size of specified dimension
void setDim(size_t dim, size_t size) {
CHECK_GE(dim, (size_t)0);
CHECK_LT(dim, ndims_);
dims_[dim] = size;
numElements();
}
// number of dimensions of the tensor
size_t ndims() const { return ndims_; }
size_t getElements() const { return nelements_; }
bool operator==(const TensorShape& t) const {
if (ndims() != t.ndims()) return false;
for (size_t i = 0; i < ndims(); i++) {
if (dims_[i] != t.dims_[i]) return false;
}
return true;
}
bool operator!=(const TensorShape& t) const { return !(*this == t); }
private:
// compute number of elements
void numElements() {
nelements_ = 1;
for (size_t n = 0; n < ndims_; n++) {
nelements_ *= dims_[n];
}
}
// init dims_
void initDims(size_t ndims) {
size_t count = ndims < 4 ? 4 : ndims;
dims_.assign(count, 1);
}
// number of dimensions
// ndims_ may be not equeal dims_.size()
size_t ndims_;
// number of elements
size_t nelements_;
std::vector<size_t> dims_;
};
} // namespace paddle

@ -0,0 +1,53 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "TensorShape.h"
#include <gtest/gtest.h>
namespace paddle {
TEST(TensorShape, Constructor) {
TensorShape t1;
EXPECT_EQ(t1.ndims(), 0);
EXPECT_EQ(t1.getElements(), 0);
TensorShape t2(3);
EXPECT_EQ(t2.ndims(), 3);
EXPECT_EQ(t2.getElements(), 1);
TensorShape t3({8, 10});
EXPECT_EQ(t3.ndims(), 2);
EXPECT_EQ(t3.getElements(), 80);
TensorShape t4(t3);
EXPECT_EQ(t4.ndims(), t3.ndims());
EXPECT_EQ(t4.getElements(), t3.getElements());
TensorShape t5({1, 2, 3, 4, 5});
EXPECT_EQ(t5.ndims(), 5);
EXPECT_EQ(t5.getElements(), 120);
}
TEST(TensorShape, GetAndSet) {
TensorShape t({1, 2, 3});
EXPECT_EQ(t.ndims(), 3);
EXPECT_EQ(t.getElements(), 6);
EXPECT_EQ(t[1], 2);
t.setDim(1, 100);
EXPECT_EQ(t.getElements(), 300);
EXPECT_EQ(t[1], 100);
}
} // namespace paddle

@ -0,0 +1,121 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/math/Matrix.h"
namespace paddle {
enum ValueType {
VALUE_TYPE_INT32 = 0,
VALUE_TYPE_FLOAT = 1,
VALUE_TYPE_DOUBLE = 2,
VALUE_TYPE_BYTE = 3
};
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2
};
inline int sizeOfValuType(ValueType valueType) {
if (valueType == VALUE_TYPE_INT32) {
return 4;
} else if (valueType == VALUE_TYPE_FLOAT) {
return 4;
} else if (valueType == VALUE_TYPE_DOUBLE) {
return 8;
} else {
LOG(FATAL) << "Unknown type: " << valueType;
return 0;
}
}
template <typename T>
struct DataType;
template <>
struct DataType<float> {
static const ValueType value = VALUE_TYPE_FLOAT;
};
template <>
struct DataType<double> {
static const ValueType value = VALUE_TYPE_DOUBLE;
};
template <>
struct DataType<int> {
static const ValueType value = VALUE_TYPE_INT32;
};
namespace detail {
template <typename VType, DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<real, DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<real, DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
template <>
struct MatrixT<int, DEVICE_TYPE_CPU> {
using type = void; // Not implemented
};
template <>
struct MatrixT<int, DEVICE_TYPE_GPU> {
using type = void; // Not implemented
};
template <typename VType, DeviceType Device>
struct VectorT;
template <>
struct VectorT<real, DEVICE_TYPE_CPU> {
using type = CpuVector;
};
template <>
struct VectorT<real, DEVICE_TYPE_GPU> {
using type = GpuVector;
};
template <>
struct VectorT<int, DEVICE_TYPE_CPU> {
using type = CpuIVector;
};
template <>
struct VectorT<int, DEVICE_TYPE_GPU> {
using type = GpuIVector;
};
} // namespace detail
template <typename VType, DeviceType DType>
struct Tensor {
typedef typename detail::MatrixT<VType, DType>::type Matrix;
typedef typename detail::VectorT<VType, DType>::type Vector;
};
} // namespace paddle

@ -0,0 +1,64 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "TensorType.h"
#include <gtest/gtest.h>
namespace paddle {
TEST(TensorType, Matrix) {
Tensor<real, DEVICE_TYPE_CPU>::Matrix matrix(100, 200);
EXPECT_EQ(matrix.getHeight(), 100);
EXPECT_EQ(matrix.getWidth(), 200);
EXPECT_EQ(matrix.getElementCnt(), 100 * 200);
EXPECT_EQ(matrix.useGpu(), false);
Tensor<real, DEVICE_TYPE_GPU>::Matrix testGpu(100, 200);
EXPECT_EQ(testGpu.useGpu(), true);
}
TEST(TensorType, Vector) {
Tensor<real, DEVICE_TYPE_CPU>::Vector cpuVector(100);
Tensor<real, DEVICE_TYPE_GPU>::Vector gpuVector(100);
EXPECT_EQ(cpuVector.useGpu(), false);
EXPECT_EQ(gpuVector.useGpu(), true);
EXPECT_EQ(cpuVector.getSize(), 100);
EXPECT_EQ(gpuVector.getSize(), 100);
Tensor<int, DEVICE_TYPE_CPU>::Vector cpuIVector(100);
Tensor<int, DEVICE_TYPE_GPU>::Vector gpuIVector(100);
EXPECT_EQ(cpuIVector.useGpu(), false);
EXPECT_EQ(gpuIVector.useGpu(), true);
EXPECT_EQ(cpuIVector.getSize(), 100);
EXPECT_EQ(gpuIVector.getSize(), 100);
}
TEST(TensorType, EmptyMatrix) {
CpuMatrix empty(nullptr, 0, 0);
CpuMatrix nonEmpty(10, 10);
EXPECT_EQ(empty.isEmpty(), true);
EXPECT_EQ(nonEmpty.isEmpty(), false);
CHECK(nonEmpty);
auto function = [](const CpuMatrix& matrix) {
if (matrix) {
EXPECT_NE(matrix.getData(), nullptr);
} else {
EXPECT_EQ(matrix.getData(), nullptr);
}
};
function(empty);
function(nonEmpty);
}
} // namespace paddle

@ -110,9 +110,8 @@ void ContextProjection::forward() {
size_t input_dim = in_->value->getWidth();
size_t dim = out_->value->getWidth();
CHECK_EQ(dim, input_dim * config_.context_length());
size_t batch_size = in_->value->getHeight();
CHECK_EQ(static_cast<int>(forward_.size()), 1)
<< "Only one forward function here";
// size_t batch_size = in_->value->getHeight();
CHECK_EQ(forward_.size(), (size_t)1) << "Only one forward function here";
REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str());
bool is_padding = config_.trainable_padding();
@ -120,14 +119,16 @@ void ContextProjection::forward() {
auto w_ptr =
state_ ? state_.get() : is_padding ? weight_->getW().get() : nullptr;
auto start_pos = in_->sequenceStartPositions;
forward_[0]->calc({Tensor(in_->value->getData(), Dims{batch_size, input_dim}),
Tensor(w_ptr ? w_ptr->getData() : nullptr,
Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}),
Tensor(reinterpret_cast<real*>(
const_cast<int*>(start_pos->getData(useGpu_))),
Dims{start_pos->getSize()})},
{Tensor(out_->value->getData(), Dims{batch_size, dim})},
{});
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*in_->value);
inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
input_dim));
inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
outputs.addArg(*out_->value, ADD_TO);
forward_[0]->calc(inputs, outputs);
if (state_ && config_.context_start() < 0) {
CHECK_EQ(1, in_->getNumSequences());
@ -162,15 +163,17 @@ void ContextProjection::backward(const UpdateCallback& callback) {
bool is_padding = config_.trainable_padding();
auto start_pos = in_->sequenceStartPositions;
auto w_ptr = is_padding ? weight_->getWGrad() : nullptr;
backward_[0]->calc({Tensor(in_->grad ? in_->grad->getData() : nullptr,
Dims{batch_size, input_dim}),
Tensor(w_ptr ? w_ptr->getData() : nullptr,
Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}),
Tensor(reinterpret_cast<real*>(
const_cast<int*>(start_pos->getData(useGpu_))),
Dims{start_pos->getSize()})},
{Tensor(out_->grad->getData(), Dims{batch_size, dim})},
{});
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(CpuMatrix(
in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim));
inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
input_dim));
inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
outputs.addArg(*out_->grad, ADD_TO);
backward_[0]->calc(inputs, outputs);
if (config_.trainable_padding()) {
weight_->getParameterPtr()->incUpdate(callback);

@ -59,7 +59,6 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
void CMRProjectionNormLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
/* note: one sample correspond to one row */
MatrixPtr input = inputLayers_[0]->getOutputValue();
@ -67,34 +66,36 @@ void CMRProjectionNormLayer::forward(PassType passType) {
int size = getSize();
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
forward_[0]->calc(
{Tensor(input->getData(), dims_)},
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
{});
shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
// prepare forward arguments
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), shape_);
outputs.addArg(*getOutputValue(), shape_, ASSIGN_TO);
outputs.addArg(*denoms_, shape_, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == inputLayers_[0]->getOutputGrad()) {
if (NULL == getInputGrad(0)) {
return;
}
/* Do derivation */
MatrixPtr preOutGrad = inputLayers_[0]->getOutputGrad();
MatrixPtr localGrad = getOutputGrad();
MatrixPtr localOutV = getOutputValue();
MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
backward_[0]->calc({Tensor(preOutV->getData(), dims_),
Tensor(localOutV->getData(), dims_),
Tensor(localGrad->getData(), dims_),
Tensor(denoms_->getData(), dims_)},
{Tensor(preOutGrad->getData(), dims_)},
{});
// prepare backward arguments
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), shape_);
inputs.addArg(*getOutputValue(), shape_);
inputs.addArg(*getOutputGrad(), shape_);
inputs.addArg(*denoms_, shape_);
outputs.addArg(*getInputGrad(0), shape_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
} // namespace paddle

@ -41,6 +41,6 @@ public:
void backward(const UpdateCallback& callback = nullptr);
protected:
Dims dims_;
TensorShape shape_;
};
} // namespace paddle

@ -1091,6 +1091,10 @@ public:
TensorCpuApply<real>(*this, expr);
}
}
bool isEmpty() const { return data_ == nullptr; }
explicit operator bool() const { return !isEmpty(); }
};
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {

Loading…
Cancel
Save