|
|
|
@ -24,7 +24,6 @@ limitations under the License. */
|
|
|
|
|
#include "TestUtils.h"
|
|
|
|
|
|
|
|
|
|
using namespace paddle; // NOLINT
|
|
|
|
|
using namespace std; // NOLINT
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Test member functions which prototype is
|
|
|
|
@ -32,8 +31,8 @@ using namespace std; // NOLINT
|
|
|
|
|
*/
|
|
|
|
|
TEST(BaseMatrix, void) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)();
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXCOMPARE(neg);
|
|
|
|
|
BASEMATRIXCOMPARE(exp);
|
|
|
|
@ -46,7 +45,7 @@ TEST(BaseMatrix, void) {
|
|
|
|
|
BASEMATRIXCOMPARE(zero);
|
|
|
|
|
BASEMATRIXCOMPARE(one);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -55,8 +54,8 @@ TEST(BaseMatrix, void) {
|
|
|
|
|
*/
|
|
|
|
|
TEST(BaseMatrix, real) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(real);
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXCOMPARE(pow);
|
|
|
|
|
BASEMATRIXCOMPARE(subScalar);
|
|
|
|
@ -67,7 +66,7 @@ TEST(BaseMatrix, real) {
|
|
|
|
|
BASEMATRIXCOMPARE(biggerThanScalar);
|
|
|
|
|
BASEMATRIXCOMPARE(downClip);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -76,13 +75,13 @@ TEST(BaseMatrix, real) {
|
|
|
|
|
*/
|
|
|
|
|
TEST(BaseMatrix, real_real) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(real, real);
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXCOMPARE(add);
|
|
|
|
|
BASEMATRIXCOMPARE(clip);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -91,8 +90,8 @@ TEST(BaseMatrix, real_real) {
|
|
|
|
|
*/
|
|
|
|
|
TEST(BaseMatrix, BaseMatrix) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(BaseMatrix&);
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXCOMPARE(assign);
|
|
|
|
|
BASEMATRIXCOMPARE(add);
|
|
|
|
@ -129,7 +128,7 @@ TEST(BaseMatrix, BaseMatrix) {
|
|
|
|
|
BASEMATRIXCOMPARE(addP2P);
|
|
|
|
|
BASEMATRIXCOMPARE(invSqrt);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -138,8 +137,8 @@ TEST(BaseMatrix, BaseMatrix) {
|
|
|
|
|
*/
|
|
|
|
|
TEST(BaseMatrix, BaseMatrix_real) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(BaseMatrix&, real);
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXCOMPARE(addBias);
|
|
|
|
|
BASEMATRIXCOMPARE(add);
|
|
|
|
@ -154,7 +153,7 @@ TEST(BaseMatrix, BaseMatrix_real) {
|
|
|
|
|
|
|
|
|
|
BASEMATRIXCOMPARE(isEqualTo);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -163,8 +162,8 @@ TEST(BaseMatrix, BaseMatrix_real) {
|
|
|
|
|
*/
|
|
|
|
|
TEST(BaseMatrix, BaseMatrix_BaseMatrix) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(BaseMatrix&, BaseMatrix&);
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
#define BASEMATRIXCOMPARE(function) \
|
|
|
|
|
BaseMatrixCompare<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXCOMPARE(softCrossEntropy);
|
|
|
|
|
BASEMATRIXCOMPARE(softCrossEntropyBp);
|
|
|
|
@ -181,69 +180,25 @@ TEST(BaseMatrix, BaseMatrix_BaseMatrix) {
|
|
|
|
|
BASEMATRIXCOMPARE(dotMulSquare);
|
|
|
|
|
BASEMATRIXCOMPARE(dotSquareSquare);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
#undef BASEMATRIXCOMPARE
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Test aggregate member functions which prototype is
|
|
|
|
|
* void (BaseMatrix::*)(BaseMatrix&).
|
|
|
|
|
*/
|
|
|
|
|
TEST(Aggregate, BaseMatrix) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(BaseMatrix&);
|
|
|
|
|
#define BASEMATRIXAPPLYROW(function) \
|
|
|
|
|
BaseMatrixApplyRow<0>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
#define BASEMATRIXAPPLYCOL(function) \
|
|
|
|
|
BaseMatrixApplyCol<0>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXAPPLYROW(maxRows);
|
|
|
|
|
BASEMATRIXAPPLYROW(minRows);
|
|
|
|
|
|
|
|
|
|
BASEMATRIXAPPLYCOL(sumCols);
|
|
|
|
|
BASEMATRIXAPPLYCOL(maxCols);
|
|
|
|
|
BASEMATRIXAPPLYCOL(minCols);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXAPPLYROW
|
|
|
|
|
#undef BASEMATRIXAPPLYCOL
|
|
|
|
|
// member function without overloaded
|
|
|
|
|
TEST(BaseMatrix, Other) {
|
|
|
|
|
BaseMatrixCompare<0, 1, 2>(&BaseMatrix::rowScale);
|
|
|
|
|
BaseMatrixCompare<0, 1, 2>(&BaseMatrix::rowDotMul);
|
|
|
|
|
BaseMatrixCompare<0, 1, 2, 3>(&BaseMatrix::binaryClassificationError);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Test aggregate member functions which prototype is
|
|
|
|
|
* void (BaseMatrix::*)(BaseMatrix&, BaseMatrix&).
|
|
|
|
|
*/
|
|
|
|
|
TEST(Aggregate, BaseMatrix_BaseMatrix) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(BaseMatrix&, BaseMatrix&);
|
|
|
|
|
#define BASEMATRIXAPPLYROW(function) \
|
|
|
|
|
BaseMatrixApplyRow<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
#define BASEMATRIXAPPLYCOL(function) \
|
|
|
|
|
BaseMatrixApplyCol<0, 1>(static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXAPPLYCOL(addDotMulVMM);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXAPPLYROW
|
|
|
|
|
#undef BASEMATRIXAPPLYCOL
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Test aggregate member functions which prototype is
|
|
|
|
|
* void (BaseMatrix::*)(BaseMatrix&, real, real).
|
|
|
|
|
*/
|
|
|
|
|
TEST(Aggregate, BaseMatrix_real_real) {
|
|
|
|
|
typedef void (BaseMatrix::*FunctionProto)(BaseMatrix&, real, real);
|
|
|
|
|
#define BASEMATRIXAPPLYROW(function) \
|
|
|
|
|
BaseMatrixApplyRow<0, 1, 2>(\
|
|
|
|
|
static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
#define BASEMATRIXAPPLYCOL(function) \
|
|
|
|
|
BaseMatrixApplyCol<0, 1, 2>(\
|
|
|
|
|
static_cast<FunctionProto>(&BaseMatrix::function));
|
|
|
|
|
|
|
|
|
|
BASEMATRIXAPPLYROW(sumRows);
|
|
|
|
|
BASEMATRIXAPPLYCOL(sumCols);
|
|
|
|
|
TEST(BaseMatrix, Aggregate) {
|
|
|
|
|
BaseMatrixAsColVector<0>(&BaseMatrix::maxRows);
|
|
|
|
|
BaseMatrixAsColVector<0>(&BaseMatrix::minRows);
|
|
|
|
|
BaseMatrixAsColVector<0, 1, 2>(&BaseMatrix::sumRows);
|
|
|
|
|
|
|
|
|
|
#undef BASEMATRIXAPPLYROW
|
|
|
|
|
#undef BASEMATRIXAPPLYCOL
|
|
|
|
|
BaseMatrixAsRowVector<0>(&BaseMatrix::maxCols);
|
|
|
|
|
BaseMatrixAsRowVector<0>(&BaseMatrix::minCols);
|
|
|
|
|
BaseMatrixAsRowVector<0, 1>(&BaseMatrix::addDotMulVMM);
|
|
|
|
|
BaseMatrixAsRowVector<0, 1, 2>(&BaseMatrix::sumCols);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int main(int argc, char** argv) {
|
|
|
|
|