|
|
|
@ -108,26 +108,23 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK(1 == inputs.size() || 2 == inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)1, outputs.size());
|
|
|
|
|
CHECK(1UL == inputs.size() || 2UL == inputs.size());
|
|
|
|
|
CHECK_EQ(1UL, outputs.size());
|
|
|
|
|
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
|
|
|
|
|
<< "SequenceArg required here";
|
|
|
|
|
const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
|
|
|
|
|
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
|
|
|
|
|
|
|
|
|
|
CHECK(out_seq.data() && val_seqs.data() && val_seqs.getSequenceId().data());
|
|
|
|
|
CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(val_seqs.getSequenceId().shape().ndims(), (size_t)1);
|
|
|
|
|
if (2 == inputs.size()) {
|
|
|
|
|
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
|
|
|
|
|
}
|
|
|
|
|
CHECK_EQ(out_seq.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(val_seqs.shape().ndims(), 2UL);
|
|
|
|
|
/// dim of output = dim of input * context_length
|
|
|
|
|
CHECK_EQ(out_seq.shape()[1], val_seqs.shape()[1] * context_length_);
|
|
|
|
|
/// input and output have the same batch_size
|
|
|
|
|
CHECK_EQ(val_seqs.shape()[0], out_seq.shape()[0]);
|
|
|
|
|
/// dim of input == dim of weight
|
|
|
|
|
if (2 == inputs.size()) {
|
|
|
|
|
if (2UL == inputs.size()) {
|
|
|
|
|
CHECK_EQ(inputs[1].shape().ndims(), 2UL);
|
|
|
|
|
/// dim of input == dim of weight
|
|
|
|
|
CHECK_EQ(val_seqs.shape()[1], inputs[1].shape()[1]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -135,10 +132,11 @@ public:
|
|
|
|
|
auto out_mat = out_seq.matrix<Device>();
|
|
|
|
|
const auto in_mat = val_seqs.matrix<Device>();
|
|
|
|
|
const auto w_mat =
|
|
|
|
|
(2 == inputs.size())
|
|
|
|
|
(2UL == inputs.size() && inputs[1].data())
|
|
|
|
|
? inputs[1].matrix<Device>()
|
|
|
|
|
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
|
|
|
|
|
const auto seq_vec = val_seqs.getSequenceId().vector<int, Device>();
|
|
|
|
|
|
|
|
|
|
ContextProjectionForward<Device>(out_mat,
|
|
|
|
|
in_mat,
|
|
|
|
|
w_mat,
|
|
|
|
@ -235,36 +233,40 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ((size_t)1, inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)2, outputs.size());
|
|
|
|
|
CHECK_EQ(1UL, inputs.size());
|
|
|
|
|
CHECK(1UL == outputs.size() || 2UL == outputs.size());
|
|
|
|
|
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
|
|
|
|
|
<< "SequenceArg required here";
|
|
|
|
|
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
|
|
|
|
|
auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
|
|
|
|
|
CHECK(in_seq.data() && in_seq.getSequenceId().data());
|
|
|
|
|
CHECK_EQ(in_seq.shape().ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(in_seq.getSequenceId().shape().ndims(), (size_t)1);
|
|
|
|
|
CHECK_EQ(out_seq.shape().ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(out_seq.getSequenceId().shape().ndims(), (size_t)1);
|
|
|
|
|
CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
|
|
|
|
|
CHECK_EQ(in_seq.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(out_seq.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(out_seq.getSequenceId().shape().ndims(), 1UL);
|
|
|
|
|
|
|
|
|
|
/// dim of input grad == dim of weight
|
|
|
|
|
CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
|
|
|
|
|
/// input and output grad have the same batch_size
|
|
|
|
|
CHECK_EQ(out_seq.shape()[0], in_seq.shape()[0]);
|
|
|
|
|
/// dim of output grad = dim of input grad * context_length
|
|
|
|
|
CHECK_EQ(in_seq.shape()[1], out_seq.shape()[1] * context_length_);
|
|
|
|
|
CHECK_EQ(out_seq.getArgType(), ADD_TO);
|
|
|
|
|
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
|
|
|
|
|
|
|
|
|
|
if (2UL == outputs.size()) {
|
|
|
|
|
CHECK_EQ(outputs[1].shape().ndims(), 2UL);
|
|
|
|
|
/// dim of input grad == dim of weight
|
|
|
|
|
CHECK_EQ(out_seq.shape()[1], outputs[1].shape()[1]);
|
|
|
|
|
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto seq_vec = in_seq.getSequenceId().vector<int, Device>();
|
|
|
|
|
const auto out_grad_mat = in_seq.matrix<Device>();
|
|
|
|
|
auto in_grad_mat =
|
|
|
|
|
!out_seq.data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
|
|
|
|
|
: out_seq.matrix<Device>();
|
|
|
|
|
auto w_grad_mat = !outputs[1].data()
|
|
|
|
|
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
|
|
|
|
|
: outputs[1].matrix<Device>();
|
|
|
|
|
auto w_grad_mat =
|
|
|
|
|
(2UL == outputs.size() && outputs[1].data())
|
|
|
|
|
? outputs[1].matrix<Device>()
|
|
|
|
|
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
|
|
|
|
|
|
|
|
|
|
ContextProjectionBackward<Device>(out_grad_mat,
|
|
|
|
|
in_grad_mat,
|
|
|
|
|
w_grad_mat,
|
|
|
|
@ -304,17 +306,17 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(inputs.size()));
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(outputs.size()));
|
|
|
|
|
CHECK_EQ(1UL, inputs.size());
|
|
|
|
|
CHECK_EQ(1UL, outputs.size());
|
|
|
|
|
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
|
|
|
|
|
<< "SequenceArg required here";
|
|
|
|
|
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
|
|
|
|
|
const auto out_seq = dynamic_cast<const SequenceArg&>(outputs[0]);
|
|
|
|
|
|
|
|
|
|
CHECK(in_seq.data() && out_seq.data() && in_seq.getSequenceId().data());
|
|
|
|
|
CHECK_EQ(static_cast<int>(out_seq.shape().ndims()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1);
|
|
|
|
|
CHECK_EQ(out_seq.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(in_seq.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
|
|
|
|
|
/// output layer grad dim == input layer grad dim * context_length_
|
|
|
|
|
CHECK_EQ(in_seq.shape().ndims(), out_seq.shape().ndims() * context_length_);
|
|
|
|
|
/// input and output have the same batch_size
|
|
|
|
@ -355,14 +357,14 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(inputs.size()));
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(outputs.size()));
|
|
|
|
|
CHECK_EQ(1UL, inputs.size());
|
|
|
|
|
CHECK_EQ(1UL, outputs.size());
|
|
|
|
|
CHECK(inputs[0].isSequenceArg()) << "SequenceArg required here";
|
|
|
|
|
const auto in_seq = dynamic_cast<const SequenceArg&>(inputs[0]);
|
|
|
|
|
CHECK(in_seq.data() && in_seq.getSequenceId().data() && outputs[0].data());
|
|
|
|
|
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(in_seq.shape().ndims()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(in_seq.getSequenceId().shape().ndims()), 1);
|
|
|
|
|
CHECK_EQ(outputs[0].shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(in_seq.shape().ndims(), 2UL);
|
|
|
|
|
CHECK_EQ(in_seq.getSequenceId().shape().ndims(), 1UL);
|
|
|
|
|
CHECK_EQ(in_seq.shape()[0], outputs[0].shape()[0]);
|
|
|
|
|
/// output layer grad dim == weight dim * context_length_
|
|
|
|
|
CHECK_EQ(in_seq.shape()[1], outputs[0].shape()[1] * context_length_);
|
|
|
|
|