follow comments

avx_docs
hedaoyuan 9 years ago
parent f8c9c889c3
commit 1c5a7c4316

@ -192,6 +192,7 @@ public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED) SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) { : BufferArg(VALUE_TYPE_INT32, shape, argType) {
CHECK_EQ(shape_.ndims(), (size_t)1); CHECK_EQ(shape_.ndims(), (size_t)1);
CHECK_GT(shape_[0], 1);
numSeqs_ = shape_[0] - 1; numSeqs_ = shape_[0] - 1;
} }

@ -85,6 +85,7 @@ void testBufferArgs(const BufferArgs& inputs,
} }
void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) { void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) {
EXPECT_EQ(inputs.size(), 1);
check(inputs[0]); check(inputs[0]);
} }

@ -172,7 +172,7 @@ protected:
void initArg(SequenceIdArg& arg, size_t batchSize) { void initArg(SequenceIdArg& arg, size_t batchSize) {
size_t numSeqs = arg.numSeqs(); size_t numSeqs = arg.numSeqs();
int* buf = (int*)arg.data(); int* buf = reinterpret_cast<int*>(arg.data());
int pos = 0; int pos = 0;
size_t maxLen = 2 * batchSize / numSeqs; size_t maxLen = 2 * batchSize / numSeqs;
for (int i = 0; i < (int)numSeqs; ++i) { for (int i = 0; i < (int)numSeqs; ++i) {

Loading…
Cancel
Save