|
|
|
@ -63,12 +63,12 @@ enum ArgType {
|
|
|
|
|
ADD_TO = 2,
|
|
|
|
|
};
|
|
|
|
|
class BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
void setArgType(ArgType argType) { argType_ = argType; }
|
|
|
|
|
|
|
|
|
|
ArgType getArgType() const { return argType_; }
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
BufferArg(ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
@ -169,7 +169,7 @@ public:
|
|
|
|
|
const SequenceArg& sequence() const;
|
|
|
|
|
const SparseMatrixArg& sparse() const;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
protected:
|
|
|
|
|
void* buf_;
|
|
|
|
|
ValueType valueType_;
|
|
|
|
|
TensorShape shape_;
|
|
|
|
@ -185,7 +185,7 @@ protected:
|
|
|
|
|
// valueType_ = int32
|
|
|
|
|
// if a < b then value_.buf_[a] < value_.buf_[b]
|
|
|
|
|
class SequenceIdArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
|
|
|
|
|
bufferType_ = TENSOR_SEQUENCE_ID;
|
|
|
|
@ -212,7 +212,7 @@ public:
|
|
|
|
|
|
|
|
|
|
size_t numSeqs() const { return numSeqs_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
private:
|
|
|
|
|
size_t numSeqs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -222,7 +222,7 @@ private:
|
|
|
|
|
// SequenceArg can be used to represent sequences that contain multiple
|
|
|
|
|
// unequal lengths.
|
|
|
|
|
class SequenceArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
SequenceArg(ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
@ -255,7 +255,7 @@ public:
|
|
|
|
|
SequenceIdArg& getSequenceId() { return startPositions_; }
|
|
|
|
|
const SequenceIdArg& getSequenceId() const { return startPositions_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
private:
|
|
|
|
|
SequenceIdArg startPositions_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -263,7 +263,7 @@ private:
|
|
|
|
|
// valueType_ == float or double
|
|
|
|
|
// shape_.ndims() == 2
|
|
|
|
|
class SparseMatrixArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
SparseMatrixArg(void* buf,
|
|
|
|
|
ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
@ -353,7 +353,7 @@ public:
|
|
|
|
|
|
|
|
|
|
SparseDataType dataType() const { return type_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
private:
|
|
|
|
|
BufferArg row_;
|
|
|
|
|
BufferArg col_;
|
|
|
|
|
size_t nnz_;
|
|
|
|
|