|
|
|
@ -38,16 +38,40 @@ enum SparseDataType {
|
|
|
|
|
|
|
|
|
|
enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* BufferArg used as the argument type for Function.
|
|
|
|
|
*/
|
|
|
|
|
class BufferArg;
|
|
|
|
|
class SequenceArg;
|
|
|
|
|
class SparseMatrixArg;
|
|
|
|
|
typedef std::shared_ptr<BufferArg> BufferArgPtr;
|
|
|
|
|
|
|
|
|
|
// an array of arbitrary dimensions
|
|
|
|
|
/**
|
|
|
|
|
* \brief BufferArg used as the argument type of Function.
|
|
|
|
|
*
|
|
|
|
|
* The arguments of the Paddle Function have four Buffer types.
|
|
|
|
|
* 1. BufferArg for a dense Buffer of any dimension.
|
|
|
|
|
* 2. SequenceIdArg for a Buffer of sequence start positions.
|
|
|
|
|
* 3. SequenceArg for a Buffer of sequence data.
|
|
|
|
|
* 4. SparseMatrixArg for a Buffer of sparse matrix.
|
|
|
|
|
*
|
|
|
|
|
* There is an ArgType property for the BufferArg used as Function Output.
|
|
|
|
|
* Whether the result of the Function calculation is assigned to the
|
|
|
|
|
* output Buffer or added to the output Buffer is determined by the
|
|
|
|
|
* argType_ property of the output BufferArg.
|
|
|
|
|
*/
|
|
|
|
|
class BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
// ArgType is only used by output BufferArg.
|
|
|
|
|
// For input argument, argType_ is ignored.
|
|
|
|
|
// For output argument, need to set the argType_ of the BufferArg.
|
|
|
|
|
enum ArgType {
|
|
|
|
|
UNSPECIFIED = 0,
|
|
|
|
|
ASSIGN_TO = 1,
|
|
|
|
|
ADD_TO = 2,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void setArgType(ArgType argType) { argType_ = argType; }
|
|
|
|
|
|
|
|
|
|
ArgType getArgType() const { return argType_; }
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
BufferArg(void* buf, ValueType valueType, const TensorShape& shape)
|
|
|
|
|
: buf_(buf), valueType_(valueType), shape_(shape) {}
|
|
|
|
@ -56,7 +80,8 @@ public:
|
|
|
|
|
: buf_(buf), valueType_(valueType) {}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Matrix& matrix)
|
|
|
|
|
: buf_(reinterpret_cast<void*>(matrix.getData())),
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(2) {
|
|
|
|
|
shape_.setDim(0, matrix.getHeight());
|
|
|
|
@ -64,21 +89,24 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Matrix& matrix, const TensorShape& shape)
|
|
|
|
|
: buf_(reinterpret_cast<void*>(matrix.getData())),
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(shape) {
|
|
|
|
|
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Vector& vector)
|
|
|
|
|
: buf_(reinterpret_cast<void*>(vector.getData())),
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(1) {
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const IVector& vector)
|
|
|
|
|
: buf_(reinterpret_cast<void*>(vector.getData())),
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
|
|
|
|
|
valueType_(VALUE_TYPE_INT32),
|
|
|
|
|
shape_(1) {
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
@ -124,6 +152,7 @@ protected:
|
|
|
|
|
ValueType valueType_;
|
|
|
|
|
TensorShape shape_;
|
|
|
|
|
BufferType bufferType_;
|
|
|
|
|
ArgType argType_ = UNSPECIFIED;
|
|
|
|
|
// leading dimensions. The size is dims_.size()
|
|
|
|
|
// Dims lds_;
|
|
|
|
|
};
|
|
|
|
|