@@ -17,7 +17,6 @@ limitations under the License. */
 #include "paddle/math/Vector.h"
 
 namespace paddle {
 
 /**
  * Context Projection Forward with CPU Matrix Device.
  *
@@ -208,10 +207,10 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat,
  * Context Projection Backward Function.
  * Update the weight gradient and input layer gradient with backprop
  *
- * \param inputs[0] input sequence.
- * \param inputs[1] output layer grad.
- * \param outputs[0] input layer grad.
- * \param outputs[1] weight grad.
+ * \param inputs[0].seq input sequence.
+ * \param inputs[0].matrix output layer grad.
+ * \param outputs[0] input layer grad.
+ * \param outputs[1] weight grad.
  */
 template <DeviceType Device>
 class ContextProjectionBackwardFunc : public FunctionBase {
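
The \param block above is the heart of this patch: the input sequence ids and the output-layer gradient, previously two separate entries of inputs, now travel together as a single SequenceArg. A minimal stand-alone sketch of the idea, using hypothetical toy types (ToySeqIds, ToyMatrix, ToySequenceArg) rather than Paddle's real SequenceArg/Matrix classes:

#include <cstdio>
#include <vector>

// Toy stand-ins, purely illustrative -- these are hypothetical names,
// not Paddle's real classes.
struct ToySeqIds {
  std::vector<int> positions;  // sequence start offsets, like inputs[0].seq
};

struct ToyMatrix {
  size_t height, width;  // like the [batch_size, dim] gradient matrix
};

struct ToySequenceArg {
  ToySeqIds seq;
  ToyMatrix value;
  const ToySeqIds& getSequenceIds() const { return seq; }
  const ToyMatrix& matrix() const { return value; }
};

int main() {
  // Before this patch: ids in inputs[0], gradient matrix in inputs[1].
  // After: both ride in one argument, so inputs.size() drops from 2 to 1.
  ToySequenceArg arg{{{0, 3, 5}}, {5, 12}};
  std::printf("sequences: %zu, out_grad: %zu x %zu\n",
              arg.getSequenceIds().positions.size() - 1,
              arg.matrix().height, arg.matrix().width);
  return 0;
}

Bundling the two pieces keeps the id vector and the value matrix from drifting apart at call sites, which is what the tightened size and shape checks in the next hunk rely on.
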
@@ -225,27 +224,28 @@ public:
   }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ((size_t)2, inputs.size());
+    CHECK_EQ((size_t)1, inputs.size());
     CHECK_EQ((size_t)2, outputs.size());
 
-    CHECK(inputs[0].data() && inputs[1].data());
-    CHECK_EQ(inputs[0].shape().ndims(), (size_t)1);
-    CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
+    const auto seqArg = dynamic_cast<const SequenceArg&>(inputs[0]);
+    CHECK(seqArg.data() && inputs[0].data());
+    CHECK_EQ(seqArg.shape().ndims(), (size_t)2);
+    CHECK_EQ(seqArg.getSequenceIds().shape().ndims(), (size_t)1);
     CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
     CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
 
     /// dim of input grad == dim of weight
     CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]);
     /// input and output grad has the same batch_size
-    CHECK_EQ(outputs[0].shape()[0], inputs[1].shape()[0]);
+    CHECK_EQ(outputs[0].shape()[0], seqArg.shape()[0]);
     /// dim of output val = dim of input grad * context_length
-    CHECK_EQ(inputs[1].shape()[1], outputs[0].shape()[1] * context_length_);
+    CHECK_EQ(seqArg.shape()[1], outputs[0].shape()[1] * context_length_);
 
     CHECK_EQ(outputs[0].getArgType(), ADD_TO);
     CHECK_EQ(outputs[1].getArgType(), ADD_TO);
 
-    const auto seq_vec = inputs[0].vector<int, Device>();
-    const auto out_grad_mat = inputs[1].matrix<Device>();
+    const auto seq_vec = seqArg.getSequenceIds().vector<int, Device>();
+    const auto out_grad_mat = seqArg.matrix<Device>();
     auto in_grad_mat =
         !outputs[0].data()
             ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
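
The checks in this hunk encode the context-projection shape contract: out_grad is [batch_size, input_dim * context_length], and both output buffers are accumulated into (ADD_TO). A self-contained toy sketch of the input-gradient scatter-add that the backward pass performs (plain C++, not Paddle's kernel; batch, input_dim, context_length, and begin_pad are made-up values, and per-sequence boundaries from seq_vec are ignored for brevity):

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sizes: 4 input rows, input_dim 2, context_length 3,
  // begin_pad 1 so the context window covers offsets -1, 0, +1.
  const int batch = 4, input_dim = 2, context_length = 3, begin_pad = 1;

  // out_grad: [batch, input_dim * context_length], all ones for the demo.
  std::vector<std::vector<float>> out_grad(
      batch, std::vector<float>(input_dim * context_length, 1.0f));
  // in_grad: [batch, input_dim], accumulated into (the ADD_TO contract).
  std::vector<std::vector<float>> in_grad(batch,
                                          std::vector<float>(input_dim, 0.0f));

  for (int i = 0; i < batch; ++i) {
    for (int c = 0; c < context_length; ++c) {
      // Chunk c of out_grad[i] was gathered from input row i + c - begin_pad
      // in the forward pass, so its gradient flows back to that row.
      const int src = i + c - begin_pad;
      if (src < 0 || src >= batch) continue;  // out-of-range rows contribute
                                              // to the padding weight grad
      for (int d = 0; d < input_dim; ++d) {
        in_grad[src][d] += out_grad[i][c * input_dim + d];
      }
    }
  }

  for (int i = 0; i < batch; ++i) {
    std::printf("in_grad[%d] = {%g, %g}\n", i, in_grad[i][0], in_grad[i][1]);
  }
  return 0;
}

Each interior input row collects one chunk from every output row whose context window covers it, which is exactly why out_grad needs input_dim * context_length columns per input_dim columns of in_grad, the relation asserted by the CHECK_EQ on seqArg.shape()[1] above.
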