|
|
|
@ -173,13 +173,9 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ((size_t)numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)numOutputs_, outputs.size());
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[0].shape());
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[1].shape());
|
|
|
|
|
|
|
|
|
|
check(inputs, outputs);
|
|
|
|
|
// ArgType check still on here,
|
|
|
|
|
// not sure whether it is better to put inside the check.
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
|
|
|
|
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
|
|
|
|
|
size_t batchSize = inputs[0].shape()[0];
|
|
|
|
@ -199,6 +195,15 @@ public:
|
|
|
|
|
pow_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ((size_t)numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)numOutputs_, outputs.size());
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[0].shape());
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[1].shape());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Only need the shape of the input, can calculate the
|
|
|
|
|
// floating-point operation.
|
|
|
|
|
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
@ -211,6 +216,8 @@ public:
|
|
|
|
|
// number of floating-point operations
|
|
|
|
|
// an approximate value
|
|
|
|
|
size_t ops = batchSize * maps * ((rows * columns) * size_);
|
|
|
|
|
|
|
|
|
|
return ops;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|