|
|
|
@ -85,15 +85,15 @@ public:
|
|
|
|
|
void calc(const Arguments& inputs,
|
|
|
|
|
const Arguments& outputs,
|
|
|
|
|
const Arguments& inouts) override {
|
|
|
|
|
CHECK_EQ(3, inputs.size());
|
|
|
|
|
CHECK_EQ(1, outputs.size());
|
|
|
|
|
CHECK_EQ(0, inouts.size());
|
|
|
|
|
CHECK_EQ(3, static_cast<int>(inputs.size()));
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(outputs.size()));
|
|
|
|
|
CHECK_EQ(0, static_cast<int>(inouts.size()));
|
|
|
|
|
|
|
|
|
|
CHECK(outputs[0].getData() && inputs[0].getData() && inputs[2].getData());
|
|
|
|
|
CHECK_EQ(outputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[1].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[2].dims_.size(), 1);
|
|
|
|
|
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[2].dims_.size()), 1);
|
|
|
|
|
/// dim of output = dim of input * context_length
|
|
|
|
|
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
|
|
|
|
|
/// dim of input == dim of weight
|
|
|
|
@ -202,15 +202,15 @@ public:
|
|
|
|
|
void calc(const Arguments& inputs,
|
|
|
|
|
const Arguments& outputs,
|
|
|
|
|
const Arguments& inouts) override {
|
|
|
|
|
CHECK_EQ(3, inputs.size());
|
|
|
|
|
CHECK_EQ(1, outputs.size());
|
|
|
|
|
CHECK_EQ(0, inouts.size());
|
|
|
|
|
CHECK_EQ(3, static_cast<int>(inputs.size()));
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(outputs.size()));
|
|
|
|
|
CHECK_EQ(0, static_cast<int>(inouts.size()));
|
|
|
|
|
|
|
|
|
|
CHECK(outputs[0].getData() && inputs[2].getData());
|
|
|
|
|
CHECK_EQ(outputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[1].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[2].dims_.size(), 1);
|
|
|
|
|
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[2].dims_.size()), 1);
|
|
|
|
|
|
|
|
|
|
/// dim of input == dim of weight
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
|
|
|
|
@ -269,13 +269,13 @@ public:
|
|
|
|
|
void calc(const Arguments& inputs,
|
|
|
|
|
const Arguments& outputs,
|
|
|
|
|
const Arguments& inouts) override {
|
|
|
|
|
CHECK_EQ(2, inputs.size());
|
|
|
|
|
CHECK_EQ(1, outputs.size());
|
|
|
|
|
CHECK_EQ(0, inouts.size());
|
|
|
|
|
CHECK_EQ(2, static_cast<int>(inputs.size()));
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(outputs.size()));
|
|
|
|
|
CHECK_EQ(0, static_cast<int>(inouts.size()));
|
|
|
|
|
CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
|
|
|
|
|
CHECK_EQ(outputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[1].dims_.size(), 1);
|
|
|
|
|
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 1);
|
|
|
|
|
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
|
|
|
|
|
/// input and output has the same batch_size
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
|
|
|
|
@ -317,14 +317,14 @@ public:
|
|
|
|
|
void calc(const Arguments& inputs,
|
|
|
|
|
const Arguments& outputs,
|
|
|
|
|
const Arguments& inouts) override {
|
|
|
|
|
CHECK_EQ(2, inputs.size());
|
|
|
|
|
CHECK_EQ(1, outputs.size());
|
|
|
|
|
CHECK_EQ(0, inouts.size());
|
|
|
|
|
CHECK_EQ(2, static_cast<int>(inputs.size()));
|
|
|
|
|
CHECK_EQ(1, static_cast<int>(outputs.size()));
|
|
|
|
|
CHECK_EQ(0, static_cast<int>(inouts.size()));
|
|
|
|
|
|
|
|
|
|
CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
|
|
|
|
|
CHECK_EQ(outputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_.size(), 2);
|
|
|
|
|
CHECK_EQ(inputs[1].dims_.size(), 1);
|
|
|
|
|
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
|
|
|
|
|
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 1);
|
|
|
|
|
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
|
|
|
|
|
|
|
|
|
|
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
|
|
|
|
|