|
|
|
@ -56,7 +56,7 @@ public:
|
|
|
|
|
: buf_(buf), valueType_(valueType) {}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Matrix& matrix)
|
|
|
|
|
: buf_((void*)matrix.getData()),
|
|
|
|
|
: buf_(reinterpret_cast<void*>(matrix.getData())),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(2) {
|
|
|
|
|
shape_.setDim(0, matrix.getHeight());
|
|
|
|
@ -64,21 +64,23 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Matrix& matrix, const TensorShape& shape)
|
|
|
|
|
: buf_((void*)matrix.getData()),
|
|
|
|
|
: buf_(reinterpret_cast<void*>(matrix.getData())),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(shape) {
|
|
|
|
|
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const Vector& vector)
|
|
|
|
|
: buf_((void*)vector.getData()),
|
|
|
|
|
: buf_(reinterpret_cast<void*>(vector.getData())),
|
|
|
|
|
valueType_(DataType<real>::value),
|
|
|
|
|
shape_(1) {
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BufferArg(const IVector& vector)
|
|
|
|
|
: buf_((void*)vector.getData()), valueType_(VALUE_TYPE_INT32), shape_(1) {
|
|
|
|
|
: buf_(reinterpret_cast<void*>(vector.getData())),
|
|
|
|
|
valueType_(VALUE_TYPE_INT32),
|
|
|
|
|
shape_(1) {
|
|
|
|
|
shape_.setDim(0, vector.getSize());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -129,7 +131,7 @@ protected:
|
|
|
|
|
// sequence start positions in a mini-batch of sequences
|
|
|
|
|
// shape_.ndims() == 1
|
|
|
|
|
// valueType_ = int32
|
|
|
|
|
// if a < b than value_.buf_[a] < value_.buf_[b]
|
|
|
|
|
// if a < b then value_.buf_[a] < value_.buf_[b]
|
|
|
|
|
class SequenceIdArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
SequenceIdArg(void* buf, const TensorShape& shape)
|
|
|
|
@ -203,13 +205,13 @@ public:
|
|
|
|
|
|
|
|
|
|
SparseMatrixArg(const CpuSparseMatrix& sparse)
|
|
|
|
|
: BufferArg(sparse),
|
|
|
|
|
row_((void*)sparse.getRows(), VALUE_TYPE_INT32),
|
|
|
|
|
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {}
|
|
|
|
|
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
|
|
|
|
|
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
|
|
|
|
|
|
|
|
|
|
SparseMatrixArg(const GpuSparseMatrix& sparse)
|
|
|
|
|
: BufferArg(sparse),
|
|
|
|
|
row_((void*)sparse.getRows(), VALUE_TYPE_INT32),
|
|
|
|
|
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {}
|
|
|
|
|
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
|
|
|
|
|
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
|
|
|
|
|
|
|
|
|
|
~SparseMatrixArg() {}
|
|
|
|
|
|
|
|
|
|