parent
3a5d60bc1c
commit
a7855d3ebb
@ -0,0 +1,142 @@
|
||||
/**
|
||||
* TensorAssign.h
|
||||
*
|
||||
* Author: hedaoyuan (hedaoyuan@baidu.com)
|
||||
* Created on: 2016-10-08
|
||||
*
|
||||
* Copyright (c) Baidu.com, Inc. All Rights Reserved
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include "paddle/utils/Logging.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
template<typename LhsType, typename RhsType, class T>
|
||||
class TensorAssignOp {
|
||||
public:
|
||||
explicit TensorAssignOp(const LhsType& lhs, const RhsType& rhs)
|
||||
: lhs_(lhs), rhs_(rhs) {
|
||||
#ifndef __CUDA_ARCH__
|
||||
CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
|
||||
CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
|
||||
CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
|
||||
#endif
|
||||
}
|
||||
|
||||
INLINE void apply(const int i, const int j) {
|
||||
lhs_.applyRef(i, j) = rhs_.apply(i, j);
|
||||
}
|
||||
INLINE void apply(const int index) {
|
||||
lhs_.applyRef(index) = rhs_.apply(index);
|
||||
}
|
||||
|
||||
INLINE size_t getWidth() const { return lhs_.getWidth(); }
|
||||
INLINE size_t getHeight() const { return rhs_.getHeight(); }
|
||||
INLINE bool isContiguous() const {
|
||||
return lhs_.isContiguous() && rhs_.isContiguous();
|
||||
}
|
||||
INLINE bool useGpu() const { return lhs_.useGpu(); }
|
||||
|
||||
private:
|
||||
TensorApply<LhsType, T> lhs_;
|
||||
TensorApply<const RhsType, T> rhs_;
|
||||
};
|
||||
|
||||
/**
 * Evaluate one or more assignment expressions on the CPU.
 *
 * For every element, `assign` is applied first and then each expression in
 * `args`, in pack order — so a later expression observes the values written
 * by earlier ones at the same element.
 *
 * @param height       number of rows.
 * @param width        number of columns.
 * @param isContiguous true when every expression can be addressed by a
 *                     single flat index.
 */
template <typename Assign, typename... AssignOp>
void AssignCpuEvaluate(int height, int width, bool isContiguous,
                       Assign&& assign, AssignOp&& ... args) {
  if (!isContiguous) {
    // 2-D walk: one apply(row, col) per element, per expression.
    for (int row = 0; row < height; row++) {
      for (int col = 0; col < width; col++) {
        assign.apply(row, col);
        // The unused array is only a vehicle for expanding the pack.
        __attribute__((unused)) int expand[] = {(args.apply(row, col), 0)...};
      }
    }
    return;
  }
  // Contiguous storage: a single flat index covers all elements.
  const int total = height * width;
  for (int idx = 0; idx < total; idx++) {
    assign.apply(idx);
    __attribute__((unused)) int expand[] = {(args.apply(idx), 0)...};
  }
}
|
||||
|
||||
#ifdef __NVCC__
// CUDA kernel for contiguous evaluation: one thread per flat element index.
// `border` is the total element count; threads whose index falls past it do
// nothing. Each in-range thread applies `assign` and then every extra
// expression in `args` at the same index (the unused dummy array exists
// only to expand the parameter pack).
template <typename Assign, typename... AssignOp>
__global__
void AssignGpuEvaluate1(const int border, Assign assign, AssignOp ... args) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < border) {
    assign.apply(idx);
    __attribute__((unused)) int dummy[] = { (((args)).apply(idx), 0)... };
  }
}
|
||||
|
||||
// CUDA kernel for non-contiguous evaluation: a 2-D grid-stride loop.
// Each thread starts at its own (rowIdx, colIdx) and advances by the total
// grid extent in each dimension, so any (height, width) is covered
// regardless of the launch configuration chosen by the caller.
template <typename Assign, typename... AssignOp>
__global__
void AssignGpuEvaluate2(const int height, const int width,
                        Assign assign, AssignOp ... args) {
  const int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
  const int rowIdx = blockIdx.y * blockDim.y + threadIdx.y;
  for (int i = rowIdx; i < height; i += gridDim.y * blockDim.y) {
    for (int j = colIdx; j < width; j += gridDim.x * blockDim.x) {
      assign.apply(i, j);
      // Apply every extra expression at (i, j); the unused array is only
      // a vehicle for expanding the parameter pack.
      __attribute__((unused)) int dummy[] = { (((args)).apply(i, j), 0)... };
    }
  }
}
#endif
|
||||
|
||||
// At least one assignment expression is required
|
||||
template <typename Assign, typename... AssignOp>
|
||||
void AssignEvaluate(Assign&& assign, AssignOp&& ... args) {
|
||||
const bool useGpu_ = assign.useGpu();
|
||||
bool isContiguous_ = assign.isContiguous();
|
||||
const size_t height = assign.getHeight();
|
||||
const size_t width = assign.getWidth();
|
||||
|
||||
const int packSize = sizeof...(args);
|
||||
const bool packUseGpu[] = { ((args)).useGpu()... };
|
||||
const bool packIsContiguous[] = { ((args)).isContiguous()... };
|
||||
const size_t packHeight[] = { ((args)).getHeight()... };
|
||||
const size_t packWidth[] = { ((args)).getWidth()... };
|
||||
|
||||
for (int i = 0; i < packSize; i++) {
|
||||
CHECK_EQ(useGpu_, packUseGpu[i]);
|
||||
CHECK_EQ(height, packHeight[i]);
|
||||
CHECK_EQ(width, packWidth[i]);
|
||||
isContiguous_ = isContiguous_ && packIsContiguous[i];
|
||||
}
|
||||
|
||||
if (useGpu_) {
|
||||
#ifdef __NVCC__
|
||||
if (isContiguous_) {
|
||||
int size = height * width;
|
||||
int blockSize = size <= 1024 ? size : 1024;
|
||||
int gridSize = (size + 1024 - 1) / 1024;
|
||||
AssignGpuEvaluate1
|
||||
<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(size, assign, args...);
|
||||
} else {
|
||||
int blockSizeY = std::min(32, (int)height);
|
||||
int blockSizeX = (32 / blockSizeY) * 32;
|
||||
int gridSizeX = std::min(32, (int)(width + blockSizeX - 1) / blockSizeX);
|
||||
int gridSizeY = std::min(32, (int)(height + blockSizeY - 1) / blockSizeY);
|
||||
dim3 threads(blockSizeX, blockSizeY);
|
||||
dim3 grid(gridSizeX, gridSizeY);
|
||||
AssignGpuEvaluate2
|
||||
<<<grid, threads, 0, STREAM_DEFAULT>>>(height, width, assign, args...);
|
||||
}
|
||||
|
||||
CHECK_SYNC("AssignEvaluate failed");
|
||||
#endif
|
||||
} else {
|
||||
AssignCpuEvaluate(height, width, isContiguous_, assign, args...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace paddle
|
||||
|
@ -0,0 +1,179 @@
|
||||
/**
|
||||
* test_Tensor.cpp
|
||||
*
|
||||
* Author: hedaoyuan (hedaoyuan@baidu.com)
|
||||
* Created on: 2016-06-06
|
||||
*
|
||||
* Copyright (c) Baidu.com, Inc. All Rights Reserved
|
||||
*/
|
||||
|
||||
#include <cmath>

#include <gtest/gtest.h>
#include "paddle/math/Matrix.h"
|
||||
using namespace paddle; // NOLINT
|
||||
using namespace std; // NOLINT
|
||||
|
||||
template<typename Tensor>
|
||||
extern void TensorCheckEqual(const Tensor& tensor1, const Tensor& tensor2);
|
||||
|
||||
// Compare two CPU matrices element-for-element with exact equality;
// fails the current gtest case if any element differs.
void TensorCheckEqual(const CpuMatrix& matrix1, const CpuMatrix& matrix2) {
  CHECK(matrix1.getHeight() == matrix2.getHeight());
  CHECK(matrix1.getWidth() == matrix2.getWidth());

  int height = matrix1.getHeight();
  int width = matrix1.getWidth();
  // NOTE(review): indexing assumes dense row-major storage with no row
  // padding (stride == width) — confirm for strided matrix types.
  const real* data1 = matrix1.getData();
  const real* data2 = matrix2.getData();
  int count = 0;
  for (int i = 0; i < height; i++) {
    for (int j = 0; j < width; j++) {
      if (data1[i * width + j] != data2[i * width + j]) {
        count++;
      }
    }
  }
  EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
|
||||
|
||||
// Compare two GPU matrices for exact equality by copying both to host
// buffers and delegating to the CpuMatrix overload.
void TensorCheckEqual(const GpuMatrix& matrix1, const GpuMatrix& matrix2) {
  CpuMatrix cpu1(matrix1.getHeight(), matrix1.getWidth());
  CpuMatrix cpu2(matrix2.getHeight(), matrix2.getWidth());
  cpu1.copyFrom(matrix1);
  cpu2.copyFrom(matrix2);
  TensorCheckEqual(cpu1, cpu2);
}
|
||||
|
||||
void TensorCheckErr(const CpuMatrix& matrix1, const CpuMatrix& matrix2) {
|
||||
CHECK(matrix1.getHeight() == matrix2.getHeight());
|
||||
CHECK(matrix1.getWidth() == matrix2.getWidth());
|
||||
#ifndef PADDLE_TYPE_DOUBLE
|
||||
real err = 1e-5;
|
||||
#else
|
||||
real err = 1e-10;
|
||||
#endif
|
||||
|
||||
int height = matrix1.getHeight();
|
||||
int width = matrix1.getWidth();
|
||||
const real* data1 = matrix1.getData();
|
||||
const real* data2 = matrix2.getData();
|
||||
int count = 0;
|
||||
for (int i = 0; i < height; i++) {
|
||||
for (int j = 0; j < width; j++) {
|
||||
real a = data1[i * width + j];
|
||||
real b = data2[i * width + j];
|
||||
if (fabs(a - b) > err) {
|
||||
if ((fabsf(a - b) / fabsf(a)) > (err / 10.0f)) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
|
||||
}
|
||||
|
||||
// Compare two GPU matrices with a tolerance by copying both to host
// buffers and delegating to the CpuMatrix overload.
void TensorCheckErr(const GpuMatrix& matrix1, const GpuMatrix& matrix2) {
  CpuMatrix cpu1(matrix1.getHeight(), matrix1.getWidth());
  CpuMatrix cpu2(matrix2.getHeight(), matrix2.getWidth());
  cpu1.copyFrom(matrix1);
  cpu2.copyFrom(matrix2);
  TensorCheckErr(cpu1, cpu2);
}
|
||||
|
||||
// Compare two CPU vectors element-for-element with exact equality;
// fails the current gtest case if any element differs.
template<class T>
void TensorCheckEqual(const CpuVectorT<T>& vector1,
                      const CpuVectorT<T>& vector2) {
  CHECK(vector1.getSize() == vector2.getSize());

  const T* data1 = vector1.getData();
  const T* data2 = vector2.getData();
  size_t size = vector1.getSize();
  int count = 0;
  for (size_t i = 0; i < size; i++) {
    if (data1[i] != data2[i]) {
      count++;
    }
  }
  EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
|
||||
|
||||
// Compare two GPU vectors for exact equality by copying both to host
// buffers and delegating to the CpuVectorT overload.
template<class T>
void TensorCheckEqual(const GpuVectorT<T>& vector1,
                      const GpuVectorT<T>& vector2) {
  CpuVectorT<T> cpu1(vector1.getSize());
  CpuVectorT<T> cpu2(vector2.getSize());
  cpu1.copyFrom(vector1);
  cpu2.copyFrom(vector2);
  TensorCheckEqual(cpu1, cpu2);
}
|
||||
|
||||
int VectorCheckErr(const Vector& vector1, const Vector& vector2) {
|
||||
CHECK(vector1.getSize() == vector2.getSize());
|
||||
|
||||
const real* data1 = vector1.getData();
|
||||
const real* data2 = vector2.getData();
|
||||
size_t size = vector1.getSize();
|
||||
int count = 0;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
real a = data1[i];
|
||||
real b = data2[i];
|
||||
if (fabs(a - b) > FLAGS_max_diff) {
|
||||
if ((fabsf(a - b) / fabsf(a)) > (FLAGS_max_diff / 10.0f)) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
// Fixture helpers: each macro declares `Tensor` variables sized
// (height, width) in the current scope. A1 is randomized and A2 is a copy
// of it, so a test can run the reference path on A1 and the path under
// test on A2, then compare. B/C/D are independent random operands.
#define INIT_UNARY(A1, A2) \
  Tensor A1(height, width); \
  Tensor A2(height, width); \
  A1.randomizeUniform(); \
  A2.copyFrom(A1)
#define INIT_BINARY(A1, A2, B) \
  INIT_UNARY(A1, A2); \
  Tensor B(height, width); \
  B.randomizeUniform()
#define INIT_TERNARY(A1, A2, B, C) \
  INIT_BINARY(A1, A2, B); \
  Tensor C(height, width); \
  C.randomizeUniform()
#define INIT_QUATERNARY(A1, A2, B, C, D) \
  INIT_TERNARY(A1, A2, B, C); \
  Tensor D(height, width); \
  D.randomizeUniform()
|
||||
|
||||
// Performance Check
#ifdef PADDLE_DISABLE_TIMER

// Timer disabled: run expressions once and verify results.
#define CHECK_VECTORPTR(vector1, vector2) \
  EXPECT_EQ(VectorCheckErr(vector1, vector2), 0)

#define EXPRESSION_PERFORMANCE(expression) \
  expression;

#else

#include "paddle/utils/Stat.h"

// Timer enabled: result checking is a no-op; expressions are timed instead.
#define CHECK_VECTORPTR(vector1, vector2)

// Run `expression` once to warm up, then 20 timed iterations, and log the
// accumulated statistics. The stringified expression — truncated to fit a
// 30-byte buffer with a ".." marker when longer — is the timer name.
#define EXPRESSION_PERFORMANCE(expression) \
  do {\
    char expr[30];\
    strncpy(expr, #expression, 30);\
    if (expr[29] != '\0') {\
      expr[27] = '.'; expr[28] = '.'; expr[29] = '\0';\
    }\
    expression;\
    for (int i = 0; i < 20; i++) {\
      REGISTER_TIMER(expr);\
      expression;\
    }\
    LOG(INFO) << std::setiosflags(std::ios::left) << std::setfill(' ')\
        << *globalStat.getStat(expr);\
    globalStat.reset();\
  } while (0)

#endif
|
||||
|
@ -0,0 +1,131 @@
|
||||
/**
|
||||
* test_lazyAssign.cpp
|
||||
*
|
||||
* Author: hedaoyuan (hedaoyuan@baidu.com)
|
||||
* Created on: 2016-10-15
|
||||
*
|
||||
* Copyright (c) Baidu.com, Inc. All Rights Reserved
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/math/Matrix.h"
|
||||
#include "paddle/math/TensorAssign.h"
|
||||
#include "TensorCheck.h"
|
||||
|
||||
using namespace paddle; // NOLINT
|
||||
using namespace std; // NOLINT
|
||||
|
||||
/// Signature of a matrix test case: invoked once per (height, width) size.
typedef std::function<void(int height, int width)> testMatrixFunc;

/// Drive `matrixFunc` across a fixed sweep of matrix sizes: height 1,
/// widths ranging from a single element up to 8M elements.
void testMatrixCase(testMatrixFunc matrixFunc) {
  const int heights[] = {1};
  const int widths[] = {1, 32, 64, 128, 512, 1024, 4096, 32768, 65536, 131072,
                        262144, 524288, 1048576, 2097152, 4194304, 8388608};
  for (const int height : heights) {
    for (const int width : widths) {
      matrixFunc(height, width);
    }
  }
}
|
||||
|
||||
// Check that two lazyAssign expressions fused by AssignEvaluate produce the
// same result as running the equivalent eager assignments in sequence:
//   A = B + C;  A = A * D;
template<typename Tensor>
void testLazyAssign(int height, int width) {
  INIT_QUATERNARY(A1, A2, B, C, D);

  // Reference: eager evaluation on A1.
  EXPRESSION_PERFORMANCE(A1 = B + C; A1 = A1 * D;);

  // Under test: both assignments evaluated in one fused pass. expr2 reads
  // A2, so it relies on AssignEvaluate applying the expressions in
  // argument order at each element.
  EXPRESSION_PERFORMANCE(
    auto expr1 = A2.lazyAssign(B + C);
    auto expr2 = A2.lazyAssign(A2 * D);
    AssignEvaluate(expr1, expr2););

  TensorCheckErr(A1, A2);
}
|
||||
|
||||
TEST(lazyAssign, CPU) {
  testMatrixCase(testLazyAssign<CpuMatrix>);
}

// GPU variant only when built with GPU support.
#ifndef PADDLE_ONLY_CPU
TEST(lazyAssign, GPU) {
  testMatrixCase(testLazyAssign<GpuMatrix>);
}
#endif
|
||||
|
||||
// SGD update written as eager tensor expressions:
//   C = p2 * C - p1 * D .* (B + p3 * A)   (element-wise)
//   A = A + C
template<typename Tensor>
void sgdUpdateTensor(Tensor& A, Tensor& B, Tensor& C, Tensor& D,
                     real p1, real p2, real p3) {
  C = C * p2 - D * (B + A * p3) * p1;
  A += C;
}
|
||||
|
||||
// The same SGD update as sgdUpdateTensor, expressed as two lazy assignments
// evaluated in one fused pass. expr2 reads C, which expr1 writes, so it
// relies on AssignEvaluate applying the expressions in argument order at
// each element.
void sgdUpdateLazyAssign(BaseMatrix& A, BaseMatrix& B,
                         BaseMatrix& C, BaseMatrix& D,
                         real p1, real p2, real p3) {
  auto expr1 = C.lazyAssign(C * p2 - D * (B + A * p3) * p1);
  auto expr2 = A.lazyAssign(A + C);
  AssignEvaluate(expr1, expr2);
}
|
||||
|
||||
// Run the same SGD update through three code paths and check they agree:
//   1. the hand-written BaseMatrix::sgdUpdate,
//   2. eager tensor expressions (sgdUpdateTensor),
//   3. fused lazy assignment (sgdUpdateLazyAssign).
// A1/A2/A3 and C1/C2/C3 start as identical copies so the paths are
// directly comparable.
template<typename Tensor>
void testSgdUpdate(int height, int width) {
  Tensor A1(height, width);
  Tensor A2(height, width);
  Tensor A3(height, width);
  A1.randomizeUniform();
  A2.copyFrom(A1);
  A3.copyFrom(A1);

  Tensor B(height, width);
  B.randomizeUniform();

  Tensor C1(height, width);
  Tensor C2(height, width);
  Tensor C3(height, width);
  C1.randomizeUniform();
  C2.copyFrom(C1);
  C3.copyFrom(C1);

  Tensor D(height, width);
  D.randomizeUniform();

  // Scalar coefficients of the update below.
  real p1 = 0.2;
  real p2 = 0.3;
  real p3 = 0.5;

  /**
   * c = p2 * c - p1 * (b + p3 * a);
   * a = a + c;
   */
  // BaseMatrix API
  EXPRESSION_PERFORMANCE(
    A1.sgdUpdate(B, C1, D, p1, p2, p3););

  // Tensor expression
  EXPRESSION_PERFORMANCE(
    sgdUpdateTensor(A2, B, C2, D, p1, p2, p3));

  // lazyAssign
  EXPRESSION_PERFORMANCE(
    sgdUpdateLazyAssign(A3, B, C3, D, p1, p2, p3));

  // All three paths must produce the same A and C (within tolerance).
  TensorCheckErr(A1, A2);
  TensorCheckErr(A1, A3);
  TensorCheckErr(C1, C2);
  TensorCheckErr(C1, C3);
}
|
||||
|
||||
TEST(sgdUpdate, CPU) {
  testMatrixCase(testSgdUpdate<CpuMatrix>);
}

// GPU variant only when built with GPU support.
#ifndef PADDLE_ONLY_CPU
TEST(sgdUpdate, GPU) {
  testMatrixCase(testSgdUpdate<GpuMatrix>);
}
#endif
|
||||
|
||||
// gtest entry point: initializes the hl_* runtime before running the
// tests (needed by the GPU cases).
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  hl_start();
  hl_init(0);  // NOTE(review): presumably selects device 0 — confirm
  return RUN_ALL_TESTS();
}
|
||||
|
Loading…
Reference in new issue