|
|
|
@ -134,15 +134,15 @@ public:
|
|
|
|
|
beta = 0.0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t batchSize = inputs[0].shape()[0];
|
|
|
|
|
size_t inputChannels = inputs[0].shape()[1];
|
|
|
|
|
size_t inputHeight = inputs[0].shape()[2];
|
|
|
|
|
size_t inputWidth = inputs[0].shape()[3];
|
|
|
|
|
size_t filterHeight = inputs[1].shape()[2];
|
|
|
|
|
size_t filterWidth = inputs[1].shape()[3];
|
|
|
|
|
size_t outputChannels = outputs[0].shape()[1];
|
|
|
|
|
size_t outputHeight = outputs[0].shape()[2];
|
|
|
|
|
size_t outputWidth = outputs[0].shape()[3];
|
|
|
|
|
size_t batchSize = input[0];
|
|
|
|
|
size_t inputChannels = input[1];
|
|
|
|
|
size_t inputHeight = input[2];
|
|
|
|
|
size_t inputWidth = input[3];
|
|
|
|
|
size_t filterHeight = getFilterHeight(filter);
|
|
|
|
|
size_t filterWidth = getFilterWidth(filter);
|
|
|
|
|
size_t outputChannels = output[1];
|
|
|
|
|
size_t outputHeight = output[2];
|
|
|
|
|
size_t outputWidth = output[3];
|
|
|
|
|
|
|
|
|
|
real* inputData = inputs[0].data<real>();
|
|
|
|
|
real* filterData = inputs[1].data<real>();
|
|
|
|
@ -158,7 +158,8 @@ public:
|
|
|
|
|
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
|
size_t filterOffset = inputs[1].shape().getElements() / groups_;
|
|
|
|
|
size_t filterOffset = filter.getElements() / groups_;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
|
for (size_t g = 0; g < groups_; g++) {
|
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
@ -211,7 +212,9 @@ public:
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ(numOutputs_, outputs.size());
|
|
|
|
|
// CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
|
|
|
|
// Since the implementation of Col2ImFunctor is ADD_TO,
|
|
|
|
|
// this function only supports ADD_TO mode.
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
|
|
|
|
|
const TensorShape& output = inputs[0].shape();
|
|
|
|
|
const TensorShape& filter = inputs[1].shape();
|
|
|
|
|
const TensorShape& input = outputs[0].shape();
|
|
|
|
@ -221,8 +224,8 @@ public:
|
|
|
|
|
size_t inputChannels = input[1];
|
|
|
|
|
size_t inputHeight = input[2];
|
|
|
|
|
size_t inputWidth = input[3];
|
|
|
|
|
size_t filterHeight = filter[2];
|
|
|
|
|
size_t filterWidth = filter[3];
|
|
|
|
|
size_t filterHeight = getFilterHeight(filter);
|
|
|
|
|
size_t filterWidth = getFilterWidth(filter);
|
|
|
|
|
size_t outputChannels = output[1];
|
|
|
|
|
size_t outputHeight = output[2];
|
|
|
|
|
size_t outputWidth = output[3];
|
|
|
|
@ -311,8 +314,8 @@ public:
|
|
|
|
|
size_t inputChannels = input[1];
|
|
|
|
|
size_t inputHeight = input[2];
|
|
|
|
|
size_t inputWidth = input[3];
|
|
|
|
|
size_t filterHeight = filter[2];
|
|
|
|
|
size_t filterWidth = filter[3];
|
|
|
|
|
size_t filterHeight = getFilterHeight(filter);
|
|
|
|
|
size_t filterWidth = getFilterWidth(filter);
|
|
|
|
|
size_t outputChannels = output[1];
|
|
|
|
|
size_t outputHeight = output[2];
|
|
|
|
|
size_t outputWidth = output[3];
|
|
|
|
|