|
|
|
@ -136,6 +136,7 @@ public:
|
|
|
|
|
// check
|
|
|
|
|
CHECK_EQ(2UL, inputs.size());
|
|
|
|
|
CHECK_EQ(1UL, outputs.size());
|
|
|
|
|
// TODO(qingqing): support ASSIGN_TO.
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
|
|
|
|
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
|
|
|
|
|
<< "SequenceArg required here.";
|
|
|
|
@ -144,9 +145,7 @@ public:
|
|
|
|
|
auto w = inputs[1];
|
|
|
|
|
CHECK(in.data() && out.data() && in.getSequenceId().data());
|
|
|
|
|
CHECK_EQ(in.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(out.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(in.shape()[1], out.shape()[1]);
|
|
|
|
|
CHECK_EQ(in.shape()[0], out.shape()[0]);
|
|
|
|
|
CHECK(in.shape() == out.shape());
|
|
|
|
|
CHECK_EQ(w.shape()[1], in.shape()[1]);
|
|
|
|
|
|
|
|
|
|
auto outMat = out.matrix<Device>();
|
|
|
|
@ -176,6 +175,7 @@ public:
|
|
|
|
|
|
|
|
|
|
template <DeviceType Device>
|
|
|
|
|
class RowConvGradFunc : public FunctionBase {
|
|
|
|
|
// TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {}
|
|
|
|
|
|
|
|
|
@ -196,9 +196,8 @@ public:
|
|
|
|
|
auto wGrad = outputs[1];
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(in.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(outGrad.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(in.shape()[1], outGrad.shape()[1]);
|
|
|
|
|
CHECK_EQ(in.shape()[0], outGrad.shape()[0]);
|
|
|
|
|
CHECK(in.shape() == inGrad.shape());
|
|
|
|
|
CHECK(in.shape() == outGrad.shape());
|
|
|
|
|
CHECK_EQ(wGrad.shape()[1], in.shape()[1]);
|
|
|
|
|
|
|
|
|
|
const auto outGMat = outGrad.matrix<Device>();
|
|
|
|
|