You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
317 lines
5.7 KiB
317 lines
5.7 KiB
9 years ago
|
/**
 * hl_tensor_ops.h
 *
 * Element-wise unary (hppl::unary) and binary (hppl::binary) operator functors.
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-06
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 *
 */
|
||
|
|
||
|
#ifndef HL_TENSOR_OPS_H_
|
||
|
#define HL_TENSOR_OPS_H_
|
||
|
|
||
|
#include <cmath>
|
||
|
#include "hl_matrix_type.cuh"
|
||
|
|
||
|
namespace hppl {
|
||
|
namespace unary {
|
||
|
|
||
|
/**
 * Unary functor that shifts its argument by a fixed scalar:
 * f(x) = x + scale.
 */
template<class T>
class add_scale {
private:
  const T scale_;

public:
  INLINE add_scale(const T scale) : scale_(scale) {}

  INLINE T operator()(const T x) const {
    return x + scale_;
  }
};
|
||
|
|
||
|
/**
 * Unary functor that subtracts a fixed scalar from its argument:
 * f(x) = x - offset.
 */
template<class T>
class sub_scale {
private:
  const T offset_;

public:
  INLINE sub_scale(const T offset) : offset_(offset) {}

  INLINE T operator()(const T x) const {
    return x - offset_;
  }
};
|
||
|
|
||
|
/**
 * Unary functor that multiplies its argument by a fixed scalar:
 * f(x) = x * factor.
 */
template<class T>
class mul_scale {
private:
  const T factor_;

public:
  INLINE mul_scale(const T factor) : factor_(factor) {}

  INLINE T operator()(const T x) const {
    return x * factor_;
  }
};
|
||
|
|
||
|
/**
 * Unary functor that divides its argument by a fixed scalar:
 * f(x) = x / divisor. No check is made for a zero divisor.
 */
template<class T>
class div_scale {
private:
  const T divisor_;

public:
  INLINE div_scale(const T divisor) : divisor_(divisor) {}

  INLINE T operator()(const T x) const {
    return x / divisor_;
  }
};
|
||
|
|
||
|
/**
 * Unary functor returning the arithmetic negation of its argument:
 * f(x) = -x.
 */
template<class T>
class neg {
public:
  INLINE T operator()(const T x) const {
    return -x;
  }
};
|
||
|
|
||
|
/**
 * Unary functor computing the natural exponential via std::exp:
 * f(x) = e^x.
 */
template<class T>
class exp_op {
public:
  INLINE T operator()(const T x) const {
    return std::exp(x);
  }
};
|
||
|
|
||
|
/**
 * Unary functor computing the natural logarithm via std::log:
 * f(x) = ln(x). Domain errors follow std::log's behavior.
 */
template<class T>
class log_op {
public:
  INLINE T operator()(const T x) const {
    return std::log(x);
  }
};
|
||
|
|
||
|
/**
 * Unary functor computing the square root via std::sqrt:
 * f(x) = sqrt(x).
 */
template<class T>
class sqrt_op {
public:
  INLINE T operator()(const T x) const {
    return std::sqrt(x);
  }
};
|
||
|
|
||
|
/**
 * Unary functor squaring its argument by self-multiplication:
 * f(x) = x * x.
 */
template<class T>
class square {
public:
  INLINE T operator()(const T x) const {
    return x * x;
  }
};
|
||
|
|
||
|
/**
 * Unary functor returning the multiplicative inverse of its argument:
 * f(x) = 1 / x. No check is made for x == 0.
 */
template<class T>
class reciprocal {
public:
  INLINE T operator()(const T x) const {
    const T one = T(1);
    return one / x;
  }
};
|
||
|
|
||
|
/**
 * Unary functor returning the absolute value of its argument.
 * Values greater than zero pass through; everything else is negated.
 */
template<class T>
class abs {
public:
  INLINE T operator()(const T x) const {
    if (x > 0) {
      return x;
    }
    return -x;
  }
};
|
||
|
|
||
|
/**
 * Unary functor returning the sign of its argument as T:
 * 1 if x > 0, -1 if x < 0, and 0 otherwise (including when neither
 * comparison holds).
 */
template<class T>
class sign {
public:
  INLINE T operator()(const T x) const {
    const int positive = (x > 0) ? 1 : 0;
    const int negative = (x < 0) ? 1 : 0;
    return positive - negative;
  }
};
|
||
|
|
||
|
/**
 * Unary functor clamping its argument from above by a fixed bound:
 * returns the bound when x > bound, otherwise x.
 */
template<class T>
class min {
private:
  const T upper_;

public:
  INLINE min(const T upper) : upper_(upper) {}

  INLINE T operator()(const T x) const {
    if (x > upper_) {
      return upper_;
    }
    return x;
  }
};
|
||
|
|
||
|
/**
 * Unary functor clamping its argument from below by a fixed bound:
 * returns the bound when x < bound, otherwise x.
 */
template<class T>
class max {
private:
  const T lower_;

public:
  INLINE max(const T lower) : lower_(lower) {}

  INLINE T operator()(const T x) const {
    if (x < lower_) {
      return lower_;
    }
    return x;
  }
};
|
||
|
|
||
|
/**
 * Unary functor raising its argument to a fixed power via std::pow:
 * f(x) = x ^ exponent.
 */
template<class T>
class pow_op {
private:
  const T exponent_;

public:
  INLINE pow_op(const T exponent) : exponent_(exponent) {}

  INLINE T operator()(const T x) const {
    return std::pow(x, exponent_);
  }
};
|
||
|
|
||
|
/**
 * Functor that always yields the value it was constructed with,
 * regardless of the integer argument(s) passed to it (presumably
 * element indices — confirm against callers).
 */
template<class T>
class constant {
private:
  const T p;

public:
  INLINE constant(const T s) : p(s) {}

  // The integer parameters are deliberately left unnamed: the result never
  // depends on them, and unnamed parameters avoid -Wunused-parameter
  // warnings (and failures under -Werror) without changing the signature.
  INLINE T operator()(int /*i*/) const { return p; }

  INLINE T operator()(int /*i*/, int /*j*/) const { return p; }
};
|
||
|
|
||
|
/**
 * Unary predicate: true when the argument equals the stored reference
 * value (x == ref).
 */
template<class T>
class cmp_eq {
private:
  const T ref_;

public:
  INLINE cmp_eq(const T ref) : ref_(ref) {}

  INLINE bool operator()(const T x) const {
    return x == ref_;
  }
};
|
||
|
|
||
|
/**
 * Unary predicate: true when the argument differs from the stored
 * reference value (x != ref).
 */
template<class T>
class cmp_ne {
private:
  const T ref_;

public:
  INLINE cmp_ne(const T ref) : ref_(ref) {}

  INLINE bool operator()(const T x) const {
    return x != ref_;
  }
};
|
||
|
|
||
|
/**
 * Unary predicate: true when the argument is less than or equal to the
 * stored reference value (x <= ref).
 */
template<class T>
class cmp_le {
private:
  const T ref_;

public:
  INLINE cmp_le(const T ref) : ref_(ref) {}

  INLINE bool operator()(const T x) const {
    return x <= ref_;
  }
};
|
||
|
|
||
|
/**
 * Unary predicate: true when the argument is strictly less than the
 * stored reference value (x < ref).
 */
template<class T>
class cmp_lt {
private:
  const T ref_;

public:
  INLINE cmp_lt(const T ref) : ref_(ref) {}

  INLINE bool operator()(const T x) const {
    return x < ref_;
  }
};
|
||
|
|
||
|
/**
 * Unary predicate: true when the argument is greater than or equal to
 * the stored reference value (x >= ref).
 */
template<class T>
class cmp_ge {
private:
  const T ref_;

public:
  INLINE cmp_ge(const T ref) : ref_(ref) {}

  INLINE bool operator()(const T x) const {
    return x >= ref_;
  }
};
|
||
|
|
||
|
/**
 * Unary predicate: true when the argument is strictly greater than the
 * stored reference value (x > ref).
 */
template<class T>
class cmp_gt {
private:
  const T ref_;

public:
  INLINE cmp_gt(const T ref) : ref_(ref) {}

  INLINE bool operator()(const T x) const {
    return x > ref_;
  }
};
|
||
|
|
||
|
/**
 * Unary predicate: logical AND of the argument with the stored operand,
 * each treated as a truth value (x && operand).
 */
template<class T>
class and_op {
private:
  const T operand_;

public:
  INLINE and_op(const T operand) : operand_(operand) {}

  INLINE bool operator()(const T x) const {
    return x && operand_;
  }
};
|
||
|
|
||
|
/**
 * Unary predicate: logical OR of the argument with the stored operand,
 * each treated as a truth value (x || operand).
 */
template<class T>
class or_op {
private:
  const T operand_;

public:
  INLINE or_op(const T operand) : operand_(operand) {}

  INLINE bool operator()(const T x) const {
    return x || operand_;
  }
};
|
||
|
|
||
|
} // namespace unary
|
||
|
|
||
|
namespace binary {
|
||
|
/**
 * Binary functor returning the sum of its two arguments: f(a, b) = a + b.
 */
template<class T>
class add {
public:
  INLINE T operator()(const T lhs, const T rhs) const {
    return lhs + rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary functor computing a weighted sum of its two arguments with
 * fixed coefficients: f(a, b) = alpha * a + beta * b.
 */
template<class T>
class add_scale {
private:
  const T alpha_;
  const T beta_;

public:
  INLINE add_scale(const T alpha, const T beta) : alpha_(alpha), beta_(beta) {}

  INLINE T operator()(const T lhs, const T rhs) const {
    return alpha_ * lhs + beta_ * rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary functor returning the difference of its arguments:
 * f(a, b) = a - b.
 */
template<class T>
class sub {
public:
  INLINE T operator()(const T lhs, const T rhs) const {
    return lhs - rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary functor returning the product of its arguments: f(a, b) = a * b.
 */
template<class T>
class mul {
public:
  INLINE T operator()(const T lhs, const T rhs) const {
    return lhs * rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary functor returning the quotient of its arguments:
 * f(a, b) = a / b. No check is made for b == 0.
 */
template<class T>
class div {
public:
  INLINE T operator()(const T lhs, const T rhs) const {
    return lhs / rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: true when the two arguments compare equal (a == b).
 */
template<class T>
class cmp_eq {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs == rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: true when the two arguments compare unequal (a != b).
 */
template<class T>
class cmp_ne {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs != rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: true when the first argument is less than or equal
 * to the second (a <= b).
 */
template<class T>
class cmp_le {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs <= rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: true when the first argument is strictly less than
 * the second (a < b).
 */
template<class T>
class cmp_lt {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs < rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: true when the first argument is greater than or
 * equal to the second (a >= b).
 */
template<class T>
class cmp_ge {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs >= rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: true when the first argument is strictly greater
 * than the second (a > b).
 */
template<class T>
class cmp_gt {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs > rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: logical AND of the two arguments, each treated as a
 * truth value (a && b).
 */
template<class T>
class and_op {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs && rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary predicate: logical OR of the two arguments, each treated as a
 * truth value (a || b).
 */
template<class T>
class or_op {
public:
  INLINE bool operator()(const T lhs, const T rhs) const {
    return lhs || rhs;
  }
};
|
||
|
|
||
|
/**
 * Binary functor returning the smaller of its arguments: yields b when
 * a > b, otherwise a.
 */
template<class T>
class min {
public:
  INLINE T operator()(const T lhs, const T rhs) const {
    if (lhs > rhs) {
      return rhs;
    }
    return lhs;
  }
};
|
||
|
|
||
|
/**
 * Binary functor returning the larger of its arguments: yields b when
 * a < b, otherwise a.
 */
template<class T>
class max {
public:
  INLINE T operator()(const T lhs, const T rhs) const {
    if (lhs < rhs) {
      return rhs;
    }
    return lhs;
  }
};
|
||
|
|
||
|
} // namespace binary
|
||
|
} // namespace hppl
|
||
|
|
||
|
#endif // HL_TENSOR_OPS_H_
|
||
|
|