|
|
|
@ -110,7 +110,7 @@ public:
|
|
|
|
|
|
|
|
|
|
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
|
|
|
|
|
outputHeight * outputWidth;
|
|
|
|
|
resizeBuffer(size);
|
|
|
|
|
resizeBuffer<Device>(size);
|
|
|
|
|
real* colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<Device, real> im2col;
|
|
|
|
@ -120,7 +120,7 @@ public:
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
|
size_t filterOffset = inputs[1].shape().getElements() / groups_;
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
|
for (int g = 0; g < groups_; g++) {
|
|
|
|
|
for (size_t g = 0; g < groups_; g++) {
|
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
|
inputChannels / groups_,
|
|
|
|
|
inputHeight,
|
|
|
|
@ -138,7 +138,9 @@ public:
|
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
|
int N = outputHeight * outputWidth;
|
|
|
|
|
int K = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
gemm(M,
|
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
@ -154,19 +156,6 @@ public:
|
|
|
|
|
outputData += outputChannels * outputHeight * outputWidth;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void resizeBuffer(size_t newSize) {
|
|
|
|
|
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
|
|
|
|
|
if (Device == DEVICE_TYPE_CPU) {
|
|
|
|
|
memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
|
|
|
|
|
} else {
|
|
|
|
|
memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// Scratch buffer reused across calc() calls: its storage is reinterpreted as
// real* and handed to im2col as the column-matrix destination. It is
// (re)allocated by resizeBuffer when missing or too small.
MemoryHandlePtr memory_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
@ -202,10 +191,73 @@ public:
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ(numOutputs_, outputs.size());
|
|
|
|
|
const TensorShape& outputGrad = inputs[0].shape();
|
|
|
|
|
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
|
|
|
|
|
const TensorShape& output = inputs[0].shape();
|
|
|
|
|
const TensorShape& input = inputs[1].shape();
|
|
|
|
|
const TensorShape& filterGrad = outputs[0].shape();
|
|
|
|
|
check(input, filterGrad, outputGrad);
|
|
|
|
|
const TensorShape& filter = outputs[0].shape();
|
|
|
|
|
check(input, filter, output);
|
|
|
|
|
|
|
|
|
|
size_t batchSize = input[0];
|
|
|
|
|
size_t inputChannels = input[1];
|
|
|
|
|
size_t inputHeight = input[2];
|
|
|
|
|
size_t inputWidth = input[3];
|
|
|
|
|
size_t filterHeight = filter[2];
|
|
|
|
|
size_t filterWidth = filter[3];
|
|
|
|
|
size_t outputChannels = output[1];
|
|
|
|
|
size_t outputHeight = output[2];
|
|
|
|
|
size_t outputWidth = output[3];
|
|
|
|
|
|
|
|
|
|
real* outputGrad = inputs[0].data<real>();
|
|
|
|
|
real* inputData = inputs[1].data<real>();
|
|
|
|
|
real* filterGrad = outputs[0].data<real>();
|
|
|
|
|
|
|
|
|
|
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
|
|
|
|
|
outputHeight * outputWidth;
|
|
|
|
|
resizeBuffer<Device>(size);
|
|
|
|
|
real* colData = reinterpret_cast<real*>(memory_->getBuf());
|
|
|
|
|
|
|
|
|
|
Im2ColFunctor<Device, real> im2col;
|
|
|
|
|
GemmFunctor<Device, real> gemm;
|
|
|
|
|
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
|
|
|
|
|
size_t outputOffset =
|
|
|
|
|
(outputChannels / groups_) * outputHeight * outputWidth;
|
|
|
|
|
size_t filterOffset = filter.getElements() / groups_;
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
|
for (size_t g = 0; g < groups_; g++) {
|
|
|
|
|
im2col(inputData + g * inputOffset,
|
|
|
|
|
inputChannels / groups_,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
filterHeight,
|
|
|
|
|
filterWidth,
|
|
|
|
|
strideH(),
|
|
|
|
|
strideW(),
|
|
|
|
|
paddingH(),
|
|
|
|
|
paddingW(),
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
colData);
|
|
|
|
|
|
|
|
|
|
int M = outputChannels / groups_;
|
|
|
|
|
int K = outputHeight * outputWidth;
|
|
|
|
|
int N = inputChannels / groups_ * filterHeight * filterWidth;
|
|
|
|
|
gemm(CblasNoTrans,
|
|
|
|
|
CblasTrans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
outputGrad + g * outputOffset,
|
|
|
|
|
K,
|
|
|
|
|
colData,
|
|
|
|
|
K,
|
|
|
|
|
1.0f,
|
|
|
|
|
filterGrad + g * filterOffset,
|
|
|
|
|
N);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
inputData += inputChannels * inputHeight * inputWidth;
|
|
|
|
|
outputGrad += outputChannels * outputHeight * outputWidth;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|