|
|
|
@ -57,58 +57,67 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
|
|
|
|
|
* output Buffer or added to the output Buffer is determined by the
|
|
|
|
|
* argType_ property of the output BufferArg.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
// ArgType is only used by output BufferArg.
|
|
|
|
|
// For input argument, argType_ is ignored.
|
|
|
|
|
// For output argument, need to set the argType_ of the BufferArg.
|
|
|
|
|
enum ArgType {
|
|
|
|
|
UNSPECIFIED = 0,
|
|
|
|
|
ASSIGN_TO = 1,
|
|
|
|
|
ADD_TO = 2,
|
|
|
|
|
};
|
|
|
|
|
class BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
// ArgType is only used by output BufferArg.
|
|
|
|
|
// For input argument, argType_ is ignored.
|
|
|
|
|
// For output argument, need to set the argType_ of the BufferArg.
|
|
|
|
|
enum ArgType {
|
|
|
|
|
UNSPECIFIED = 0,
|
|
|
|
|
ASSIGN_TO = 1,
|
|
|
|
|
ADD_TO = 2,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void setArgType(ArgType argType) { argType_ = argType; }
|
|
|
|
|
|
|
|
|
|
ArgType getArgType() const { return argType_; }
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
BufferArg(void* buf, ValueType valueType, const TensorShape& shape)
|
|
|
|
|
: buf_(buf), valueType_(valueType), shape_(shape) {}
|
|
|
|
|
BufferArg(void* buf,
|
|
|
|
|
ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
|
|
|
|
|
|
|
|
|
|
BufferArg(void* buf, ValueType valueType)
|
|
|
|
|
: buf_(buf), valueType_(valueType) {}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Matrix& matrix)
|
|
|
|
|
BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(2) {
|
|
|
|
|
shape_(2),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
shape_.setDim(0, matrix.getHeight());
|
|
|
|
|
shape_.setDim(1, matrix.getWidth());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Matrix& matrix, const TensorShape& shape)
|
|
|
|
|
BufferArg(const Matrix& matrix,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(shape) {
|
|
|
|
|
shape_(shape),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Vector& vector)
|
|
|
|
|
BufferArg(const Vector& vector, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(1) {
|
|
|
|
|
shape_(1),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const IVector& vector)
|
|
|
|
|
BufferArg(const IVector& vector, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(
|
|
|
|
|
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
|
|
|
|
|
valueType_(VALUE_TYPE_INT32),
|
|
|
|
|
shape_(1) {
|
|
|
|
|
shape_(1),
|
|
|
|
|
argType_(argType) {
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -163,8 +172,10 @@ protected:
|
|
|
|
|
// if a < b then value_.buf_[a] < value_.buf_[b]
|
|
|
|
|
class SequenceIdArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
SequenceIdArg(void* buf, const TensorShape& shape)
|
|
|
|
|
: BufferArg(buf, VALUE_TYPE_INT32, shape) {
|
|
|
|
|
SequenceIdArg(void* buf,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
|
|
|
|
|
CHECK_EQ(shape_.ndims(), 1);
|
|
|
|
|
numSeqs_ = shape_[0] - 1;
|
|
|
|
|
}
|
|
|
|
@ -187,11 +198,15 @@ public:
|
|
|
|
|
SequenceArg(void* buf,
|
|
|
|
|
ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
const SequenceIdArg& startPositions)
|
|
|
|
|
: BufferArg(buf, valueType, shape), startPositions_(startPositions) {}
|
|
|
|
|
const SequenceIdArg& startPositions,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(buf, valueType, shape, argType),
|
|
|
|
|
startPositions_(startPositions) {}
|
|
|
|
|
|
|
|
|
|
SequenceArg(const Matrix& matrix, const IVector& vector)
|
|
|
|
|
: BufferArg(matrix), startPositions_(vector) {}
|
|
|
|
|
SequenceArg(const Matrix& matrix,
|
|
|
|
|
const IVector& vector,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(matrix, argType), startPositions_(vector) {}
|
|
|
|
|
|
|
|
|
|
~SequenceArg() {}
|
|
|
|
|
|
|
|
|
@ -214,8 +229,9 @@ public:
|
|
|
|
|
const BufferArg& col,
|
|
|
|
|
size_t nnz,
|
|
|
|
|
SparseDataFormat format,
|
|
|
|
|
SparseDataType type)
|
|
|
|
|
: BufferArg(buf, valueType, shape),
|
|
|
|
|
SparseDataType type,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(buf, valueType, shape, argType),
|
|
|
|
|
row_(row),
|
|
|
|
|
col_(col),
|
|
|
|
|
nnz_(nnz),
|
|
|
|
@ -232,13 +248,13 @@ public:
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SparseMatrixArg(const CpuSparseMatrix& sparse)
|
|
|
|
|
: BufferArg(sparse),
|
|
|
|
|
SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(sparse, argType),
|
|
|
|
|
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
|
|
|
|
|
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
|
|
|
|
|
|
|
|
|
|
SparseMatrixArg(const GpuSparseMatrix& sparse)
|
|
|
|
|
: BufferArg(sparse),
|
|
|
|
|
SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(sparse, argType),
|
|
|
|
|
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
|
|
|
|
|
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
|
|
|
|
|
|
|
|
|
|