|
|
|
@ -125,27 +125,25 @@ public:
|
|
|
|
|
pow_ = config.get<real>("pow");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const Arguments& inputs,
|
|
|
|
|
const Arguments& outputs,
|
|
|
|
|
const Arguments& inouts) override {
|
|
|
|
|
void calc(const BufferArgs& inputs,
|
|
|
|
|
const BufferArgs& outputs,
|
|
|
|
|
const BufferArgs& inouts) override {
|
|
|
|
|
CHECK_EQ(1, inputs.size());
|
|
|
|
|
CHECK_EQ(2, outputs.size());
|
|
|
|
|
CHECK_EQ(0, inouts.size());
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputs[0].dims_.size(), 4);
|
|
|
|
|
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
|
|
|
|
|
}
|
|
|
|
|
CHECK_EQ(inputs[0].shape().ndims(), 4);
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[0].shape());
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[1].shape());
|
|
|
|
|
|
|
|
|
|
size_t samples = inputs[0].dims_[0];
|
|
|
|
|
size_t channels = inputs[0].dims_[1];
|
|
|
|
|
size_t height = inputs[0].dims_[2];
|
|
|
|
|
size_t width = inputs[0].dims_[3];
|
|
|
|
|
size_t samples = inputs[0].shape()[0];
|
|
|
|
|
size_t channels = inputs[0].shape()[1];
|
|
|
|
|
size_t height = inputs[0].shape()[2];
|
|
|
|
|
size_t width = inputs[0].shape()[3];
|
|
|
|
|
|
|
|
|
|
CrossMapNormal<Device>(outputs[0].getData(),
|
|
|
|
|
outputs[1].getData(),
|
|
|
|
|
inputs[0].getData(),
|
|
|
|
|
CrossMapNormal<Device>(outputs[0].data<real>(),
|
|
|
|
|
outputs[1].data<real>(),
|
|
|
|
|
inputs[0].data<real>(),
|
|
|
|
|
samples,
|
|
|
|
|
channels,
|
|
|
|
|
height,
|
|
|
|
@ -177,31 +175,29 @@ public:
|
|
|
|
|
pow_ = config.get<real>("pow");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const Arguments& inputs,
|
|
|
|
|
const Arguments& outputs,
|
|
|
|
|
const Arguments& inouts) override {
|
|
|
|
|
void calc(const BufferArgs& inputs,
|
|
|
|
|
const BufferArgs& outputs,
|
|
|
|
|
const BufferArgs& inouts) override {
|
|
|
|
|
CHECK_EQ(4, inputs.size());
|
|
|
|
|
CHECK_EQ(1, outputs.size());
|
|
|
|
|
CHECK_EQ(0, inouts.size());
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputs[0].dims_.size(), 4);
|
|
|
|
|
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
|
|
|
|
|
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t samples = inputs[0].dims_[0];
|
|
|
|
|
size_t channels = inputs[0].dims_[1];
|
|
|
|
|
size_t height = inputs[0].dims_[2];
|
|
|
|
|
size_t width = inputs[0].dims_[3];
|
|
|
|
|
|
|
|
|
|
CrossMapNormalGrad<Device>(outputs[0].getData(),
|
|
|
|
|
inputs[0].getData(),
|
|
|
|
|
inputs[1].getData(),
|
|
|
|
|
inputs[2].getData(),
|
|
|
|
|
inputs[3].getData(),
|
|
|
|
|
CHECK_EQ(inputs[0].shape().ndims(), 4);
|
|
|
|
|
CHECK(inputs[0].shape() == inputs[1].shape());
|
|
|
|
|
CHECK(inputs[0].shape() == inputs[2].shape());
|
|
|
|
|
CHECK(inputs[0].shape() == inputs[3].shape());
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[0].shape());
|
|
|
|
|
|
|
|
|
|
size_t samples = inputs[0].shape()[0];
|
|
|
|
|
size_t channels = inputs[0].shape()[1];
|
|
|
|
|
size_t height = inputs[0].shape()[2];
|
|
|
|
|
size_t width = inputs[0].shape()[3];
|
|
|
|
|
|
|
|
|
|
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
|
|
|
|
|
inputs[0].data<real>(),
|
|
|
|
|
inputs[1].data<real>(),
|
|
|
|
|
inputs[2].data<real>(),
|
|
|
|
|
inputs[3].data<real>(),
|
|
|
|
|
samples,
|
|
|
|
|
channels,
|
|
|
|
|
height,
|
|
|
|
|