|
|
|
@ -65,15 +65,24 @@ public:
|
|
|
|
|
// construct a argument
|
|
|
|
|
template <typename T>
|
|
|
|
|
T construct(int height, int width);
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
float construct(int height, int width) {
|
|
|
|
|
return 0.0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
size_t construct(int height, int width) {
|
|
|
|
|
size_t offset = std::rand() % (height < width ? height : width);
|
|
|
|
|
return offset;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
CpuMatrix construct(int height, int width) {
|
|
|
|
|
CpuMatrix a(height, width);
|
|
|
|
|
return a;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
GpuMatrix construct(int height, int width) {
|
|
|
|
|
GpuMatrix a(height, width);
|
|
|
|
@ -83,14 +92,22 @@ GpuMatrix construct(int height, int width) {
|
|
|
|
|
// init a argument
|
|
|
|
|
template <typename T>
|
|
|
|
|
void init(T& v);
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void init(float& v) {
|
|
|
|
|
v = 0.5;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void init(size_t& v) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void init(CpuMatrix& v) {
|
|
|
|
|
v.randomizeUniform();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void init(GpuMatrix& v) {
|
|
|
|
|
v.randomizeUniform();
|
|
|
|
@ -111,10 +128,17 @@ template <std::size_t I = 0, typename... Args>
|
|
|
|
|
// copy a argument, copy src to dest
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
|
void copy(T1& dest, T2& src);
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void copy(float& dest, float& src) {
|
|
|
|
|
dest = src;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void copy(size_t& dest, size_t& src) {
|
|
|
|
|
dest = src;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void copy(GpuMatrix& dest, CpuMatrix& src) {
|
|
|
|
|
dest.copyFrom(src);
|
|
|
|
@ -165,8 +189,8 @@ R call(C& obj, R (FC::*f)(FArgs...), Args&&... args) {
|
|
|
|
|
return (obj.*f)(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <bool ApplyRow,
|
|
|
|
|
bool ApplyCol,
|
|
|
|
|
template <bool AsRowVector,
|
|
|
|
|
bool AsColVector,
|
|
|
|
|
std::size_t... I,
|
|
|
|
|
typename C,
|
|
|
|
|
typename R,
|
|
|
|
@ -177,8 +201,8 @@ void BaseMatrixCompare(R (C::*f)(Args...),
|
|
|
|
|
bool checkArgs = false) {
|
|
|
|
|
for (auto height : {1, 11, 73, 128, 200, 330}) {
|
|
|
|
|
for (auto width : {1, 3, 32, 100, 512, 1000}) {
|
|
|
|
|
CpuMatrix obj1(ApplyCol ? 1 : height, ApplyRow ? 1 : width);
|
|
|
|
|
GpuMatrix obj2(ApplyCol ? 1 : height, ApplyRow ? 1 : width);
|
|
|
|
|
CpuMatrix obj1(AsRowVector ? 1 : height, AsColVector ? 1 : width);
|
|
|
|
|
GpuMatrix obj2(AsRowVector ? 1 : height, AsColVector ? 1 : width);
|
|
|
|
|
init(obj1);
|
|
|
|
|
copy(obj2, obj1);
|
|
|
|
|
|
|
|
|
@ -227,7 +251,7 @@ void BaseMatrixCompare(R (C::*f)(Args...), bool checkArgs = false) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <std::size_t... I, typename C, typename R, typename... Args>
|
|
|
|
|
void BaseMatrixApplyRow(R (C::*f)(Args...)) {
|
|
|
|
|
void BaseMatrixAsColVector(R (C::*f)(Args...)) {
|
|
|
|
|
static_assert(sizeof...(I) == sizeof...(Args),
|
|
|
|
|
"size of parameter packs are not equal");
|
|
|
|
|
|
|
|
|
@ -237,11 +261,11 @@ void BaseMatrixApplyRow(R (C::*f)(Args...)) {
|
|
|
|
|
autotest::AssertEqual compare(1e-8);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
autotest::BaseMatrixCompare<true, false, I...>(f, compare);
|
|
|
|
|
autotest::BaseMatrixCompare<false, true, I...>(f, compare);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <std::size_t... I, typename C, typename R, typename... Args>
|
|
|
|
|
void BaseMatrixApplyCol(R (C::*f)(Args...)) {
|
|
|
|
|
void BaseMatrixAsRowVector(R (C::*f)(Args...)) {
|
|
|
|
|
static_assert(sizeof...(I) == sizeof...(Args),
|
|
|
|
|
"size of parameter packs are not equal");
|
|
|
|
|
|
|
|
|
@ -250,5 +274,5 @@ void BaseMatrixApplyCol(R (C::*f)(Args...)) {
|
|
|
|
|
#else
|
|
|
|
|
autotest::AssertEqual compare(1e-8);
|
|
|
|
|
#endif
|
|
|
|
|
autotest::BaseMatrixCompare<false, true, I...>(f, compare);
|
|
|
|
|
autotest::BaseMatrixCompare<true, false, I...>(f, compare);
|
|
|
|
|
}
|
|
|
|
|