|
|
|
@ -72,19 +72,21 @@ public:
|
|
|
|
|
BufferArg(ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(nullptr),
|
|
|
|
|
valueType_(valueType),
|
|
|
|
|
shape_(shape),
|
|
|
|
|
argType_(argType) {}
|
|
|
|
|
: buf_(nullptr), valueType_(valueType), shape_(shape), argType_(argType) {
|
|
|
|
|
bufferType_ = TENSOR_NORMAL;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(void* buf,
|
|
|
|
|
ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
|
|
|
|
|
: buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {
|
|
|
|
|
bufferType_ = TENSOR_NORMAL;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(void* buf, ValueType valueType)
|
|
|
|
|
: buf_(buf), valueType_(valueType) {}
|
|
|
|
|
BufferArg(void* buf, ValueType valueType) : buf_(buf), valueType_(valueType) {
|
|
|
|
|
bufferType_ = TENSOR_NORMAL;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(
|
|
|
|
@ -173,7 +175,7 @@ protected:
|
|
|
|
|
TensorShape shape_;
|
|
|
|
|
BufferType bufferType_{TENSOR_UNKNOWN};
|
|
|
|
|
ArgType argType_{UNSPECIFIED};
|
|
|
|
|
// todo(tianbing), add deviceType_
|
|
|
|
|
// TODO(tianbing), add deviceType_
|
|
|
|
|
// leading dimensions. The size is dims_.size()
|
|
|
|
|
// Dims lds_;
|
|
|
|
|
};
|
|
|
|
@ -186,6 +188,7 @@ class SequenceIdArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_ID;
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)1);
|
|
|
|
|
CHECK_GT(shape_[0], 1);
|
|
|
|
|
numSeqs_ = shape_[0] - 1;
|
|
|
|
@ -223,7 +226,9 @@ public:
|
|
|
|
|
SequenceArg(ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}
|
|
|
|
|
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_DATA;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SequenceArg(void* buf,
|
|
|
|
|
ValueType valueType,
|
|
|
|
@ -271,16 +276,16 @@ public:
|
|
|
|
|
row_(row),
|
|
|
|
|
col_(col),
|
|
|
|
|
nnz_(nnz),
|
|
|
|
|
format_(format),
|
|
|
|
|
type_(type) {
|
|
|
|
|
format_(static_cast<SparseDataFormat>(format)),
|
|
|
|
|
type_(static_cast<SparseDataType>(type)) {
|
|
|
|
|
bufferType_ = TENSOR_SPARSE;
|
|
|
|
|
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(row_.shape().ndims(), (size_t)1);
|
|
|
|
|
CHECK_EQ(col_.shape().ndims(), (size_t)1);
|
|
|
|
|
if (format == SPARSE_CSR) {
|
|
|
|
|
if (format_ == T_SPARSE_CSR) {
|
|
|
|
|
CHECK_EQ(nnz, col.shape()[0]);
|
|
|
|
|
} else if (format == SPARSE_CSC) {
|
|
|
|
|
} else if (format_ == T_SPARSE_CSC) {
|
|
|
|
|
CHECK_EQ(nnz, row.shape()[0]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -292,23 +297,23 @@ public:
|
|
|
|
|
SparseValueType type,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(valueType, shape, argType),
|
|
|
|
|
/// len of row_ : height + 1 (CSR), buf_ == nullptr
|
|
|
|
|
row_(format == SPARSE_CSR
|
|
|
|
|
? BufferArg(VALUE_TYPE_INT32, TensorShape{shape[0] + 1})
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})),
|
|
|
|
|
/// len of col_ : width + 1 (CSC), buf_ == nullptr
|
|
|
|
|
col_(format == SPARSE_CSR
|
|
|
|
|
? BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, TensorShape{shape[1] + 1})),
|
|
|
|
|
row_(BufferArg(nullptr, VALUE_TYPE_INT32)),
|
|
|
|
|
col_(BufferArg(nullptr, VALUE_TYPE_INT32)),
|
|
|
|
|
nnz_(nnz),
|
|
|
|
|
format_(format),
|
|
|
|
|
type_(type) {
|
|
|
|
|
format_(static_cast<SparseDataFormat>(format)),
|
|
|
|
|
type_(static_cast<SparseDataType>(type)) {
|
|
|
|
|
bufferType_ = TENSOR_SPARSE;
|
|
|
|
|
/// todo(tianbing)
|
|
|
|
|
/// valueType and shape_.ndims() == 2 need to check before
|
|
|
|
|
/// this constructor to make sure row_ and col_ are right
|
|
|
|
|
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)2);
|
|
|
|
|
|
|
|
|
|
/// len of row_ : height + 1 (CSR) or nnz (CSC), buf_ == nullptr
|
|
|
|
|
row_ = (format_ == T_SPARSE_CSR
|
|
|
|
|
? BufferArg(VALUE_TYPE_INT32, TensorShape{shape_[0] + 1})
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, TensorShape{nnz}));
|
|
|
|
|
/// len of col_ : width + 1 (CSC) or nnz (CSR), buf_ == nullptr
|
|
|
|
|
col_ = (format_ == T_SPARSE_CSR
|
|
|
|
|
? BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, TensorShape{shape_[1] + 1}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
|
|
|
|
@ -328,8 +333,8 @@ public:
|
|
|
|
|
shape_[0],
|
|
|
|
|
shape_[1],
|
|
|
|
|
nnz_,
|
|
|
|
|
type_,
|
|
|
|
|
format_,
|
|
|
|
|
static_cast<SparseValueType>(type_),
|
|
|
|
|
static_cast<SparseFormat>(format_),
|
|
|
|
|
false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -343,16 +348,16 @@ public:
|
|
|
|
|
|
|
|
|
|
size_t numElements() const override { return nnz_; }
|
|
|
|
|
|
|
|
|
|
SparseFormat dataFormat() const { return format_; }
|
|
|
|
|
SparseDataFormat dataFormat() const { return format_; }
|
|
|
|
|
|
|
|
|
|
SparseValueType dataType() const { return type_; }
|
|
|
|
|
SparseDataType dataType() const { return type_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
BufferArg row_;
|
|
|
|
|
BufferArg col_;
|
|
|
|
|
size_t nnz_;
|
|
|
|
|
SparseFormat format_;
|
|
|
|
|
SparseValueType type_;
|
|
|
|
|
SparseDataFormat format_;
|
|
|
|
|
SparseDataType type_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|