|
|
|
@ -18,9 +18,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "TensorShape.h"
|
|
|
|
|
#include "TensorType.h"
|
|
|
|
|
#include "paddle/math/CpuSparseMatrix.h"
|
|
|
|
|
#include "paddle/math/Matrix.h"
|
|
|
|
|
#include "paddle/math/SparseMatrix.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
@ -248,15 +246,9 @@ public:
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(sparse, argType),
|
|
|
|
|
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
|
|
|
|
|
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
|
|
|
|
|
SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
|
|
|
|
|
|
|
|
|
|
SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(sparse, argType),
|
|
|
|
|
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
|
|
|
|
|
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
|
|
|
|
|
SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
|
|
|
|
|
|
|
|
|
|
~SparseMatrixArg() {}
|
|
|
|
|
|
|
|
|
|