You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
410 lines
12 KiB
410 lines
12 KiB
/**
|
|
* TensorExpression.h
|
|
*
|
|
* Author: hedaoyuan (hedaoyuan@baidu.com)
|
|
* Created on: 2016-06-06
|
|
*
|
|
* Copyright (c) Baidu.com, Inc. All Rights Reserved
|
|
*
|
|
*/
|
|
|
|
#pragma once

#include <stdint.h>
#include <cstddef>
#include <type_traits>

#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Logging.h"
#include "hl_tensor_ops.h"
|
|
|
|
namespace paddle {
|
|
|
|
// Expression node types produced by TensorExpression; they are only
// declared here so TensorExpression can name them in its return types.
// Their definitions follow TensorExpression in this header.
template<class OP, typename ExprType, class T>
class TensorUnaryOp;

template<class OP, typename LhsType, typename RhsType, class T>
class TensorBinaryOp;

template<typename ExprType1, typename ExprType2, typename ExprType3, class T>
class TensorTernaryOp;

template<class OP, typename ExprType, class T>
class TensorConstant;
|
|
|
|
/**
|
|
* \brief Tensor base class.
|
|
*
|
|
* This is the base class of all Tensor and Expression class.
|
|
*/
|
|
template<typename Derived, class T>
|
|
class TensorExpression {
|
|
public:
|
|
/**
|
|
* Element wise unary expression.
|
|
*/
|
|
template<typename UnaryOp>
|
|
const TensorUnaryOp<UnaryOp, const Derived, T>
|
|
unaryExpression(const UnaryOp& op) const {
|
|
return TensorUnaryOp<UnaryOp, const Derived, T>(op, derived());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
|
|
operator+(T p) const {
|
|
return unaryExpression(hppl::unary::add_scale<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::sub_scale<T>, const Derived, T>
|
|
operator-(T p) const {
|
|
return unaryExpression(hppl::unary::sub_scale<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
|
|
operator*(T p) const {
|
|
return unaryExpression(hppl::unary::mul_scale<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::div_scale<T>, const Derived, T>
|
|
operator/(T p) const {
|
|
return unaryExpression(hppl::unary::div_scale<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::neg<T>, const Derived, T>
|
|
operator-() const {
|
|
return unaryExpression(hppl::unary::neg<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::exp_op<T>, const Derived, T>
|
|
exp() const {
|
|
return unaryExpression(hppl::unary::exp_op<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::log_op<T>, const Derived, T>
|
|
log() const {
|
|
return unaryExpression(hppl::unary::log_op<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::sqrt_op<T>, const Derived, T>
|
|
sqrt() const {
|
|
return unaryExpression(hppl::unary::sqrt_op<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::square<T>, const Derived, T>
|
|
square() const {
|
|
return unaryExpression(hppl::unary::square<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::reciprocal<T>, const Derived, T>
|
|
reciprocal() const {
|
|
return unaryExpression(hppl::unary::reciprocal<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::abs<T>, const Derived, T>
|
|
abs() const {
|
|
return unaryExpression(hppl::unary::abs<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::sign<T>, const Derived, T>
|
|
sign() const {
|
|
return unaryExpression(hppl::unary::sign<T>());
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::pow_op<T>, const Derived, T>
|
|
pow(T p) const {
|
|
return unaryExpression(hppl::unary::pow_op<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::min<T>, const Derived, T>
|
|
min(T p) const {
|
|
return unaryExpression(hppl::unary::min<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::max<T>, const Derived, T>
|
|
max(T p) const {
|
|
return unaryExpression(hppl::unary::max<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::cmp_eq<T>, const Derived, T>
|
|
operator==(T p) const {
|
|
return unaryExpression(hppl::unary::cmp_eq<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::cmp_ne<T>, const Derived, T>
|
|
operator!=(T p) const {
|
|
return unaryExpression(hppl::unary::cmp_ne<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::cmp_le<T>, const Derived, T>
|
|
operator<=(T p) const {
|
|
return unaryExpression(hppl::unary::cmp_le<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::cmp_lt<T>, const Derived, T>
|
|
operator<(T p) const {
|
|
return unaryExpression(hppl::unary::cmp_lt<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::cmp_ge<T>, const Derived, T>
|
|
operator>=(T p) const {
|
|
return unaryExpression(hppl::unary::cmp_ge<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::cmp_gt<T>, const Derived, T>
|
|
operator>(T p) const {
|
|
return unaryExpression(hppl::unary::cmp_gt<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::and_op<T>, const Derived, T>
|
|
operator&&(T p) const {
|
|
return unaryExpression(hppl::unary::and_op<T>(p));
|
|
}
|
|
|
|
const TensorUnaryOp<hppl::unary::or_op<T>, const Derived, T>
|
|
operator||(T p) const {
|
|
return unaryExpression(hppl::unary::or_op<T>(p));
|
|
}
|
|
|
|
/**
|
|
* Element wise binary expression.
|
|
*/
|
|
template<typename BinaryOp, typename ExpressionType>
|
|
const TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>
|
|
binaryExpression(const BinaryOp& op, const ExpressionType& expr) const {
|
|
return TensorBinaryOp<BinaryOp, const Derived, const ExpressionType, T>(
|
|
op, derived(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::cmp_eq<T>, const Derived, const ExpressionType, T>
|
|
operator==(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::cmp_eq<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::cmp_ne<T>, const Derived, const ExpressionType, T>
|
|
operator!=(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::cmp_ne<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::cmp_le<T>, const Derived, const ExpressionType, T>
|
|
operator<=(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::cmp_le<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::cmp_lt<T>, const Derived, const ExpressionType, T>
|
|
operator<(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::cmp_lt<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::cmp_ge<T>, const Derived, const ExpressionType, T>
|
|
operator>=(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::cmp_ge<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::cmp_gt<T>, const Derived, const ExpressionType, T>
|
|
operator>(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::cmp_gt<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::and_op<T>, const Derived, const ExpressionType, T>
|
|
operator&&(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::and_op<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::or_op<T>, const Derived, const ExpressionType, T>
|
|
operator||(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::or_op<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::add<T>, const Derived, const ExpressionType, T>
|
|
operator+(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::add<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::sub<T>, const Derived, const ExpressionType, T>
|
|
operator-(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::sub<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::mul<T>, const Derived, const ExpressionType, T>
|
|
operator*(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::mul<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::div<T>, const Derived, const ExpressionType, T>
|
|
operator/(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::div<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::min<T>, const Derived, const ExpressionType, T>
|
|
min(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::min<T>(), expr);
|
|
}
|
|
|
|
template<typename ExpressionType>
|
|
const TensorBinaryOp<
|
|
hppl::binary::max<T>, const Derived, const ExpressionType, T>
|
|
max(const ExpressionType& expr) const {
|
|
return binaryExpression(hppl::binary::max<T>(), expr);
|
|
}
|
|
|
|
/**
|
|
* Element wise ternary expression.
|
|
*
|
|
* ternary conditional operator(?: operator).
|
|
* The conditional expression returns one of two values depending on
|
|
* the result of derived expression.
|
|
* If derived expression evaluates to true, then expression1 is evaluated.
|
|
* If derived expression evaluates to false, then expression2 is evaluated.
|
|
*/
|
|
template<typename ExprType1, typename ExprType2>
|
|
const TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
|
|
condition(const ExprType1& expr1, const ExprType2& expr2) const {
|
|
return TensorTernaryOp<const Derived, const ExprType1, const ExprType2, T>
|
|
(derived(), expr1, expr2);
|
|
}
|
|
|
|
template<typename ExprType>
|
|
const TensorTernaryOp<
|
|
const Derived,
|
|
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
|
|
const ExprType,
|
|
T>
|
|
condition(T p, const ExprType& expr) const {
|
|
return condition(constant(p), expr);
|
|
}
|
|
|
|
template<typename ExprType>
|
|
const TensorTernaryOp<
|
|
const Derived,
|
|
const ExprType,
|
|
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
|
|
T>
|
|
condition(const ExprType& expr, T p) const {
|
|
return condition(expr, constant(p));
|
|
}
|
|
|
|
const TensorTernaryOp<
|
|
const Derived,
|
|
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
|
|
const TensorConstant<hppl::unary::constant<T>, const Derived, T>,
|
|
T>
|
|
condition(T p1, T p2) const {
|
|
return condition(constant(p1), constant(p2));
|
|
}
|
|
|
|
const TensorConstant<hppl::unary::constant<T>, const Derived, T>
|
|
constant(T p) const {
|
|
return TensorConstant<hppl::unary::constant<T>, const Derived, T>
|
|
(hppl::unary::constant<T>(p), derived());
|
|
}
|
|
|
|
protected:
|
|
const Derived& derived() const { return *static_cast<const Derived*>(this); }
|
|
};
|
|
|
|
/**
|
|
* \brief Unary Operator Expression
|
|
*/
|
|
template<class OP, typename ExprType, class T>
|
|
class TensorUnaryOp
|
|
: public TensorExpression<TensorUnaryOp<OP, ExprType, T>, T> {
|
|
public:
|
|
explicit TensorUnaryOp(const OP op, const ExprType& expr)
|
|
: op_(op), expr_(expr) {}
|
|
|
|
const OP op_;
|
|
const ExprType expr_;
|
|
};
|
|
|
|
/**
|
|
* \brief Binary Operator Expression
|
|
*/
|
|
template<class OP, typename LhsType, typename RhsType, class T>
|
|
class TensorBinaryOp
|
|
: public TensorExpression<TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
|
|
public:
|
|
explicit TensorBinaryOp(const OP op, const LhsType& lhs, const RhsType& rhs)
|
|
: op_(op), lhs_(lhs), rhs_(rhs) {}
|
|
|
|
const OP op_;
|
|
const LhsType lhs_;
|
|
const RhsType rhs_;
|
|
};
|
|
|
|
/**
|
|
* \brief Ternary Operator Expression
|
|
*/
|
|
template<typename ExprType1, typename ExprType2, typename ExprType3, class T>
|
|
class TensorTernaryOp
|
|
: public TensorExpression<
|
|
TensorTernaryOp<ExprType1, ExprType2, ExprType3, T>, T> {
|
|
public:
|
|
explicit TensorTernaryOp(
|
|
const ExprType1& expr1, const ExprType2& expr2, const ExprType3& expr3)
|
|
: expr1_(expr1), expr2_(expr2), expr3_(expr3) {}
|
|
|
|
const ExprType1 expr1_;
|
|
const ExprType2 expr2_;
|
|
const ExprType3 expr3_;
|
|
};
|
|
|
|
/**
|
|
* \brief Constant Expression
|
|
*/
|
|
template<class OP, typename ExprType, class T>
|
|
class TensorConstant
|
|
: public TensorExpression<TensorConstant<OP, ExprType, T>, T> {
|
|
public:
|
|
explicit TensorConstant(const OP op, const ExprType& expr)
|
|
: op_(op), expr_(expr) {}
|
|
|
|
const OP op_;
|
|
const ExprType expr_;
|
|
};
|
|
|
|
/**
|
|
* \brief operator+ overload
|
|
* \return a unary operator expression
|
|
*/
|
|
template<typename Derived, class T>
|
|
const TensorUnaryOp<hppl::unary::add_scale<T>, const Derived, T>
|
|
operator+(T p, const TensorExpression<Derived, T>& expr) {
|
|
return expr + p;
|
|
}
|
|
|
|
/**
|
|
* \brief operator* overload
|
|
* \return a unary operator expression
|
|
*/
|
|
template<typename Derived, class T>
|
|
const TensorUnaryOp<hppl::unary::mul_scale<T>, const Derived, T>
|
|
operator*(T p, const TensorExpression<Derived, T>& expr) {
|
|
return expr * p;
|
|
}
|
|
|
|
} // namespace paddle
|
|
|
|
#include "TensorApply.h"
|
|
#include "TensorEvaluate.h"
|
|
|