|
|
|
@ -18,7 +18,7 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
|
|
|
|
|
FunctionCompare test("RowConv", FuncConfig());
|
|
|
|
|
CpuGpuFuncCompare test("RowConv", FuncConfig());
|
|
|
|
|
|
|
|
|
|
test.addSequence(SequenceIdArg(TensorShape{batchSize}));
|
|
|
|
|
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
|
|
|
|
@ -31,7 +31,7 @@ void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) {
|
|
|
|
|
FunctionCompare test("RowConvGrad", FuncConfig());
|
|
|
|
|
CpuGpuFuncCompare test("RowConvGrad", FuncConfig());
|
|
|
|
|
|
|
|
|
|
test.addSequence(SequenceIdArg(TensorShape{batchSize}));
|
|
|
|
|
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
|
|
|
|
|