|
|
|
@ -39,7 +39,6 @@ enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
|
|
|
|
|
class BufferArg;
|
|
|
|
|
class SequenceArg;
|
|
|
|
|
class SparseMatrixArg;
|
|
|
|
|
typedef std::shared_ptr<BufferArg> BufferArgPtr;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* \brief BufferArg used as the argument type of Function.
|
|
|
|
@ -50,6 +49,11 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
|
|
|
|
|
* 3. SequenceArg for a Buffer of sequence data.
|
|
|
|
|
* 4. SparseMatrixArg for a Buffer of sparse matrix.
|
|
|
|
|
*
|
|
|
|
|
* Buffer shape
|
|
|
|
|
* For most buffers, the first dimension `shape()[0]` represents
|
|
|
|
|
* the size of the mini-batch.
|
|
|
|
|
*
|
|
|
|
|
* Buffer argType
|
|
|
|
|
* There is an ArgType property for the BufferArg used as Function Output.
|
|
|
|
|
* Whether the result of the Function calculation is assigned to the
|
|
|
|
|
* output Buffer or added to the output Buffer is determined by the
|
|
|
|
@ -71,6 +75,14 @@ public:
|
|
|
|
|
ArgType getArgType() const { return argType_; }
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
BufferArg(ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: buf_(nullptr),
|
|
|
|
|
valueType_(valueType),
|
|
|
|
|
shape_(shape),
|
|
|
|
|
argType_(argType) {}
|
|
|
|
|
|
|
|
|
|
BufferArg(void* buf,
|
|
|
|
|
ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
@ -170,6 +182,12 @@ protected:
|
|
|
|
|
// if a < b then value_.buf_[a] < value_.buf_[b]
|
|
|
|
|
class SequenceIdArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
|
|
|
|
|
CHECK_EQ(shape_.ndims(), (size_t)1);
|
|
|
|
|
numSeqs_ = shape_[0] - 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SequenceIdArg(void* buf,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
@ -190,9 +208,18 @@ private:
|
|
|
|
|
size_t numSeqs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// sequence data
|
|
|
|
|
// sequences data
|
|
|
|
|
// For mini-batch calculate,
|
|
|
|
|
// one batch can contain more than one sequence of data.
|
|
|
|
|
// SequenceArg can be used to represent sequences that contain multiple
|
|
|
|
|
// unequal lengths.
|
|
|
|
|
class SequenceArg : public BufferArg {
|
|
|
|
|
public:
|
|
|
|
|
SequenceArg(ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
|
ArgType argType = UNSPECIFIED)
|
|
|
|
|
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}
|
|
|
|
|
|
|
|
|
|
SequenceArg(void* buf,
|
|
|
|
|
ValueType valueType,
|
|
|
|
|
const TensorShape& shape,
|
|
|
|
@ -210,6 +237,8 @@ public:
|
|
|
|
|
|
|
|
|
|
void* getIdBuf() const { return startPositions_.data(); }
|
|
|
|
|
size_t numSeqs() const { return startPositions_.numSeqs(); }
|
|
|
|
|
SequenceIdArg& getSequenceId() { return startPositions_; }
|
|
|
|
|
const SequenceIdArg& getSequenceId() const { return startPositions_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
SequenceIdArg startPositions_;
|
|
|
|
|