|
|
|
@ -189,8 +189,8 @@ public:
|
|
|
|
|
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_ID;
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)1);
|
|
|
|
|
CHECK_GT(shape_[0], 1);
|
|
|
|
|
CHECK_EQ(shape_.ndims(), 1UL);
|
|
|
|
|
CHECK_GT(shape_[0], 1UL);
|
|
|
|
|
numSeqs_ = shape_[0] - 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -199,7 +199,7 @@ public:
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_ID;
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)1);
|
|
|
|
|
CHECK_EQ(shape_.ndims(), 1UL);
|
|
|
|
|
numSeqs_ = shape_[0] - 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -280,9 +280,9 @@ public:
|
|
|
|
|
type_(static_cast<SparseDataType>(type)) {
|
|
|
|
|
bufferType_ = TENSOR_SPARSE;
|
|
|
|
|
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(row_.shape().ndims(), (size_t)1);
|
|
|
|
|
CHECK_EQ(col_.shape().ndims(), (size_t)1);
|
|
|
|
|
CHECK_EQ(shape_.ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(row_.shape().ndims(), 1UL);
|
|
|
|
|
CHECK_EQ(col_.shape().ndims(), 1UL);
|
|
|
|
|
if (format_ == T_SPARSE_CSR) {
|
|
|
|
|
CHECK_EQ(nnz, col.shape()[0]);
|
|
|
|
|
} else if (format_ == T_SPARSE_CSC) {
|
|
|
|
@ -304,7 +304,7 @@ public:
|
|
|
|
|
type_(static_cast<SparseDataType>(type)) {
|
|
|
|
|
bufferType_ = TENSOR_SPARSE;
|
|
|
|
|
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(shape_.ndims(), 2UL);
|
|
|
|
|
|
|
|
|
|
/// len of row_ : height + 1 (CSR) or nnz (CSC), buf_ == nullptr
|
|
|
|
|
row_ = (format_ == T_SPARSE_CSR
|
|
|
|
@ -325,7 +325,7 @@ public:
|
|
|
|
|
CHECK(buf_);
|
|
|
|
|
CHECK(valueType_ == DataType<real>::value);
|
|
|
|
|
// CHECK(deviceType_ == DType);
|
|
|
|
|
CHECK_EQ(2, shape_.ndims());
|
|
|
|
|
CHECK_EQ(2UL, shape_.ndims());
|
|
|
|
|
return typename Tensor<real, DType>::SparseMatrix(
|
|
|
|
|
reinterpret_cast<real*>(buf_),
|
|
|
|
|
reinterpret_cast<int*>(row_.data()),
|
|
|
|
|