You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/math/TensorExpression.h

410 lines
12 KiB

/**
* TensorExpression.h
*
* Author: hedaoyuan (hedaoyuan@baidu.com)
* Created on: 2016-06-06
*
* Copyright (c) Baidu.com, Inc. All Rights Reserved
*
*/
#pragma once
#include <cstddef>
#include <stdint.h>
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Logging.h"
#include "hl_tensor_ops.h"
namespace paddle {
template<class OP, typename ExprType, class T> class TensorConstant;
template<class OP, typename ExprType, class T> class TensorUnaryOp;
template<
class OP, typename LhsType, typename RhsType, class T> class TensorBinaryOp;
template<
typename ExprType1,
typename ExprType2,
typename ExprType3,
class T> class TensorTernaryOp;
/**
* \brief Tensor base class.
*
* This is the base class of all Tensor and Expression class.
*/
template<typename Derived, class T>
class TensorExpression {
public:
/**
* Element wise unary expression.
*/
template<typename UnaryOp>
const TensorUnaryOp<UnaryOp, const Derived, T>
unaryExpression(const UnaryOp& op) const {
return TensorUnaryOp<UnaryOp, const Derived, T>(op, derived());
}
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
operator+(T p) const {
return unaryExpression(hppl::unary::add_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::sub_scale<T>, const Derived, T>
operator-(T p) const {
return unaryExpression(hppl::unary::sub_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
operator*(T p) const {
return unaryExpression(hppl::unary::mul_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::div_scale<T>, const Derived, T>
operator/(T p) const {
return unaryExpression(hppl::unary::div_scale<T>(p));
}
const TensorUnaryOp<hppl::unary::neg<T>, const Derived, T>
operator-() const {
return unaryExpression(hppl::unary::neg<T>());
}
const TensorUnaryOp<hppl::unary::exp_op<T>, const Derived, T>
exp() const {
return unaryExpression(hppl::unary::exp_op<T>());
}
const TensorUnaryOp<hppl::unary::log_op<T>, const Derived, T>
log() const {
return unaryExpression(hppl::unary::log_op<T>());
}
const TensorUnaryOp<hppl::unary::sqrt_op<T>, const Derived, T>
sqrt() const {
return unaryExpression(hppl::unary::sqrt_op<T>());
}
const TensorUnaryOp<hppl::unary::square<T>, const Derived, T>
square() const {
return unaryExpression(hppl::unary::square<T>());
}
const TensorUnaryOp<hppl::unary::reciprocal<T>, const Derived, T>
reciprocal() const {
return unaryExpression(hppl::unary::reciprocal<T>());
}
const TensorUnaryOp<hppl::unary::abs<T>, const Derived, T>
abs() const {
return unaryExpression(hppl::unary::abs<T>());
}
const TensorUnaryOp<hppl::unary::sign<T>, const Derived, T>
sign() const {
return unaryExpression(hppl::unary::sign<T>());
}
const TensorUnaryOp<hppl::unary::pow_op<T>, const Derived, T>
pow(T p) const {
return unaryExpression(hppl::unary::pow_op<T>(p));
}
const TensorUnaryOp<hppl::unary::min<T>, const Derived, T>
min(T p) const {
return unaryExpression(hppl::unary::min<T>(p));
}
const TensorUnaryOp<hppl::unary::max<T>, const Derived, T>
max(T p) const {
return unaryExpression(hppl::unary::max<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_eq<T>, const Derived, T>
operator==(T p) const {
return unaryExpression(hppl::unary::cmp_eq<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_ne<T>, const Derived, T>
operator!=(T p) const {
return unaryExpression(hppl::unary::cmp_ne<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_le<T>, const Derived, T>
operator<=(T p) const {
return unaryExpression(hppl::unary::cmp_le<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_lt<T>, const Derived, T>
operator<(T p) const {
return unaryExpression(hppl::unary::cmp_lt<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_ge<T>, const Derived, T>
operator>=(T p) const {
return unaryExpression(hppl::unary::cmp_ge<T>(p));
}
const TensorUnaryOp<hppl::unary::cmp_gt<T>, const Derived, T>
operator>(T p) const {
return unaryExpression(hppl::unary::cmp_gt<T>(p));
}
const TensorUnaryOp<hppl::unary::and_op<T>, const Derived, T>
operator&&(T p) const {
return unaryExpression(hppl::unary::and_op<T>(p));
}
const TensorUnaryOp<hppl::unary::or_op<T>, const Derived, T>
operator||(T p) const {
return unaryExpression(hppl::unary::or_op<T>(p));
}
/**
* Element wise binary expression.
*/
template<typename BinaryOp, typename ExpressionType>
const TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>
binaryExpression(const BinaryOp& op, const ExpressionType& expr) const {
return TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>(
op, derived(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_eq<T>, const Derived, const ExpressionType, T>
operator==(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_eq<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_ne<T>, const Derived, const ExpressionType, T>
operator!=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_ne<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_le<T>, const Derived, const ExpressionType, T>
operator<=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_le<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_lt<T>, const Derived, const ExpressionType, T>
operator<(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_lt<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_ge<T>, const Derived, const ExpressionType, T>
operator>=(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_ge<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::cmp_gt<T>, const Derived, const ExpressionType, T>
operator>(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::cmp_gt<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::and_op<T>, const Derived, const ExpressionType, T>
operator&&(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::and_op<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::or_op<T>, const Derived, const ExpressionType, T>
operator||(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::or_op<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::add<T>, const Derived, const ExpressionType, T>
operator+(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::add<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::sub<T>, const Derived, const ExpressionType, T>
operator-(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::sub<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::mul<T>, const Derived, const ExpressionType, T>
operator*(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::mul<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::div<T>, const Derived, const ExpressionType, T>
operator/(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::div<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::min<T>, const Derived, const ExpressionType, T>
min(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::min<T>(), expr);
}
template<typename ExpressionType>
const TensorBinaryOp<
hppl::binary::max<T>, const Derived, const ExpressionType, T>
max(const ExpressionType& expr) const {
return binaryExpression(hppl::binary::max<T>(), expr);
}
/**
* Element wise ternary expression.
*
* ternary conditional operator(?: operator).
* The conditional expression returns one of two values depending on
* the result of derived expression.
* If derived expression evaluates to true, then expression1 is evaluated.
* If derived expression evaluates to false, then expression2 is evaluated.
*/
template<typename ExprType1, typename ExprType2>
const TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
condition(const ExprType1& expr1, const ExprType2& expr2) const {
return TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
(derived(), expr1, expr2);
}
template<typename ExprType>
const TensorTernaryOp<
const Derived,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
const ExprType,
T>
condition(T p, const ExprType& expr) const {
return condition(constant(p), expr);
}
template<typename ExprType>
const TensorTernaryOp<
const Derived,
const ExprType,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
T>
condition(const ExprType& expr, T p) const {
return condition(expr, constant(p));
}
const TensorTernaryOp<
const Derived,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
T>
condition(T p1, T p2) const {
return condition(constant(p1), constant(p2));
}
const TensorConstant<hppl::unary::constant<T>, const Derived, T>
constant(T p) const {
return TensorConstant<hppl::unary::constant<T>, const Derived, T>
(hppl::unary::constant<T>(p), derived());
}
protected:
const Derived& derived() const { return *static_cast<const Derived*>(this); }
};
/**
* \brief Unary Operator Expression
*/
template<class OP, typename ExprType, class T>
class TensorUnaryOp
: public TensorExpression<TensorUnaryOp<OP, ExprType, T>, T> {
public:
explicit TensorUnaryOp(const OP op, const ExprType& expr)
: op_(op), expr_(expr) {}
const OP op_;
const ExprType expr_;
};
/**
* \brief Binary Operator Expression
*/
template<class OP, typename LhsType, typename RhsType, class T>
class TensorBinaryOp
: public TensorExpression<TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
public:
explicit TensorBinaryOp(const OP op, const LhsType& lhs, const RhsType& rhs)
: op_(op), lhs_(lhs), rhs_(rhs) {}
const OP op_;
const LhsType lhs_;
const RhsType rhs_;
};
/**
* \brief Ternary Operator Expression
*/
template<typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp
: public TensorExpression<
TensorTernaryOp<ExprType1, ExprType2, ExprType3, T>, T> {
public:
explicit TensorTernaryOp(
const ExprType1& expr1, const ExprType2& expr2, const ExprType3& expr3)
: expr1_(expr1), expr2_(expr2), expr3_(expr3) {}
const ExprType1 expr1_;
const ExprType2 expr2_;
const ExprType3 expr3_;
};
/**
* \brief Constant Expression
*/
template<class OP, typename ExprType, class T>
class TensorConstant
: public TensorExpression<TensorConstant<OP, ExprType, T>, T> {
public:
explicit TensorConstant(const OP op, const ExprType& expr)
: op_(op), expr_(expr) {}
const OP op_;
const ExprType expr_;
};
/**
* \brief operator+ overload
* \return a unary operator expression
*/
template<typename Derived, class T>
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
operator+(T p, const TensorExpression<Derived, T>& expr) {
return expr + p;
}
/**
* \brief operator* overload
* \return a unary operator expression
*/
template<typename Derived, class T>
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
operator*(T p, const TensorExpression<Derived, T>& expr) {
return expr * p;
}
} // namespace paddle
#include "TensorApply.h"
#include "TensorEvaluate.h"