|
|
|
@ -23,10 +23,11 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
enum BufferType {
|
|
|
|
|
TENSOR_NORMAL = 0,
|
|
|
|
|
TENSOR_SEQUENCE_ID = 1,
|
|
|
|
|
TENSOR_SEQUENCE_DATA = 2,
|
|
|
|
|
TENSOR_SPARSE = 3
|
|
|
|
|
TENSOR_UNKNOWN = 0,
|
|
|
|
|
TENSOR_NORMAL = 1,
|
|
|
|
|
TENSOR_SEQUENCE_ID = 2,
|
|
|
|
|
TENSOR_SEQUENCE_DATA = 3,
|
|
|
|
|
TENSOR_SPARSE = 4
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
enum SparseDataType {
|
|
|
|
@ -86,6 +87,7 @@ public:
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(2),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
bufferType_ = TENSOR_NORMAL;
|
|
|
|
|
shape_.setDim(0, matrix.getHeight());
|
|
|
|
|
shape_.setDim(1, matrix.getWidth());
|
|
|
|
|
}
|
|
|
|
@ -98,6 +100,7 @@ public:
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(shape),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
bufferType_ = TENSOR_NORMAL;
|
|
|
|
|
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -107,6 +110,7 @@ public:
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(1),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
bufferType_ = TENSOR_NORMAL;
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -116,6 +120,7 @@ public:
|
|
|
|
|
valueType_(VALUE_TYPE_INT32),
|
|
|
|
|
shape_(1),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
bufferType_ = TENSOR_NORMAL;
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -150,6 +155,8 @@ public:
|
|
|
|
|
ValueType valueType() const { return valueType_; }
|
|
|
|
|
BufferType bufferType() const { return bufferType_; }
|
|
|
|
|
const TensorShape& shape() const { return shape_; }
|
|
|
|
|
bool isSparse() const { return (TENSOR_SPARSE == bufferType_); }
|
|
|
|
|
bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; }
|
|
|
|
|
|
|
|
|
|
const SequenceArg& sequence() const;
|
|
|
|
|
const SparseMatrixArg& sparse() const;
|
|
|
|
@ -158,8 +165,8 @@ protected:
|
|
|
|
|
void* buf_;
|
|
|
|
|
ValueType valueType_;
|
|
|
|
|
TensorShape shape_;
|
|
|
|
|
BufferType bufferType_;
|
|
|
|
|
ArgType argType_ = UNSPECIFIED;
|
|
|
|
|
BufferType bufferType_{TENSOR_UNKNOWN};
|
|
|
|
|
ArgType argType_{UNSPECIFIED};
|
|
|
|
|
// leading dimensions. The size is dims_.size()
|
|
|
|
|
// Dims lds_;
|
|
|
|
|
};
|
|
|
|
@ -174,11 +181,13 @@ public:
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_ID;
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)1);
|
|
|
|
|
numSeqs_ = shape_[0] - 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SequenceIdArg(const IVector& vector) : BufferArg(vector) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_ID;
|
|
|
|
|
numSeqs_ = shape_[0] - 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -190,7 +199,7 @@ private:
|
|
|
|
|
size_t numSeqs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// sequence data
|
|
|
|
|
// sequence data {seqId(vec), buf(matrix)}
|
|
|
|
|
class SequenceArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
SequenceArg(void* buf,
|
|
|
|
@ -199,17 +208,22 @@ public:
|
|
|
|
|
const SequenceIdArg& startPositions,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(buf, valueType, shape, argType),
|
|
|
|
|
startPositions_(startPositions) {}
|
|
|
|
|
startPositions_(startPositions) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_DATA;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SequenceArg(const Matrix& matrix,
|
|
|
|
|
const IVector& vector,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(matrix, argType), startPositions_(vector) {}
|
|
|
|
|
: BufferArg(matrix, argType), startPositions_(vector) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_DATA;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~SequenceArg() {}
|
|
|
|
|
|
|
|
|
|
void* getIdBuf() const { return startPositions_.data(); }
|
|
|
|
|
size_t numSeqs() const { return startPositions_.numSeqs(); }
|
|
|
|
|
const SequenceIdArg& getSequenceIds() const { return startPositions_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
SequenceIdArg startPositions_;
|
|
|
|
@ -235,6 +249,7 @@ public:
|
|
|
|
|
nnz_(nnz),
|
|
|
|
|
format_(format),
|
|
|
|
|
type_(type) {
|
|
|
|
|
bufferType_ = TENSOR_SPARSE;
|
|
|
|
|
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(row_.shape().ndims(), (size_t)1);
|
|
|
|
|