|
|
|
@ -188,8 +188,13 @@ public:
|
|
|
|
|
CHECK(inputs[0].shape() == inputs[3].shape());
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[0].shape());
|
|
|
|
|
|
|
|
|
|
// TODO(hedaoyuan): need support ASSIGN_TO mode.
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
|
|
|
|
if (outputs[0].getArgType() != ADD_TO) {
|
|
|
|
|
// Currently, some algorithm implementations are ASSIGN_TO mode,
|
|
|
|
|
// if need to support the ADD_TO calculation, need to clear the output.
|
|
|
|
|
typename Tensor<real, Device>::Vector tmp(
|
|
|
|
|
outputs[0].shape().getElements(), outputs[0].data<real>());
|
|
|
|
|
tmp.zero();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t samples = inputs[0].shape()[0];
|
|
|
|
|
size_t channels = inputs[0].shape()[1];
|
|
|
|
|