@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/math/Vector.h"
namespace paddle {
* Context Projection Forward with CPU Matrix Device.
@ -208,10 +207,10 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat,
* Context Projection Backward Function.
* Update the weight gradient and input layer gradient with backprop
* \param inputs[0] input sequence.
* \param inputs[1] output layer grad.
* \param outputs[0] input layer grad.
* \param outputs[1] weight grad.
* \param inputs[0].seq input sequence.
* \param inputs[0].matrix output layer grad.
* \param outputs[0] input layer grad.
* \param outputs[1] weight grad.
template <DeviceType Device>
class ContextProjectionBackwardFunc : public FunctionBase {
@ -225,27 +224,28 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)2, inputs.size());
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());
CHECK(inputs[0].data() && inputs[1].data());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)1);
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
const auto seqArg = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(seqArg.data() && inputs[0].data());
CHECK_EQ(seqArg.shape().ndims(), (size_t)2);
CHECK_EQ(seqArg.getSequenceIds().shape().ndims(), (size_t)1);
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
/// dim of input grad == dim of weight
CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]);
/// input grad and output grad have the same batch_size
CHECK_EQ(outputs[0].shape()[0], inputs[1].shape()[0]);
CHECK_EQ(outputs[0].shape()[0], seqArg.shape()[0]);
/// dim of output grad == dim of input grad * context_length
CHECK_EQ(inputs[1].shape()[1], outputs[0].shape()[1] * context_length_);
CHECK_EQ(seqArg.shape()[1], outputs[0].shape()[1] * context_length_);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
const auto seq_vec = inputs[0].vector<int, Device>();
const auto out_grad_mat = inputs[1].matrix<Device>();
const auto seq_vec = seqArg.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = seqArg.matrix<Device>();
auto in_grad_mat =
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)