Accelerate the input-backward pass of depthwise conv (remove the inner 'if' bounds check in this function)

cblas_new
xzl 8 years ago
parent dbb658805e
commit 66520af9ca

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include "DepthwiseConvOp.h" #include "DepthwiseConvOp.h"
#include "GemmFunctor.h" #include "GemmFunctor.h"
#include "paddle/math/BaseMatrix.h" #include "paddle/math/BaseMatrix.h"
@ -93,28 +94,31 @@ void ConvolutionDepthwiseInputBackward(const int nthreads,
const int c_in = (index / inputHeight / inputWidth) % inputChannels; const int c_in = (index / inputHeight / inputWidth) % inputChannels;
const int h_in = (index / inputWidth) % inputHeight; const int h_in = (index / inputWidth) % inputHeight;
const int w_in = index % inputWidth; const int w_in = index % inputWidth;
const int c_out_start = c_in * filterMultiplier; const int c_out_start = c_in * filterMultiplier;
int h_out_start = (h_in - filterHeight + paddingH + strideH)/strideH;
h_out_start = 0 > h_out_start ? 0 : h_out_start;
int h_out_end = (h_in + paddingH)/strideH;
h_out_end = outputHeight - 1 < h_out_end? outputHeight - 1 : h_out_end;
int w_out_start = (w_in - filterWidth + paddingW + strideW)/strideW;
w_out_start = 0 > w_out_start ? 0 : w_out_start;
int w_out_end = (w_in + paddingW)/strideW;
w_out_end = outputWidth - 1 < w_out_end? outputWidth - 1 : w_out_end;
T value = 0; T value = 0;
for (int c_out = c_out_start; for (int c_out = c_out_start;
c_out < c_out_start + filterMultiplier; c_out ++) { c_out < c_out_start + filterMultiplier; c_out ++) {
const T* weight = weight_data + c_out * filterHeight * filterWidth; for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) {
for (int kh = 0; kh < filterHeight; ++kh) { const int filter_h = h_in + paddingH - h_out * strideH;
for (int kw = 0; kw < filterWidth; ++kw) { for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) {
const int h_out_s = h_in + paddingH - kh; const int filter_w = w_in + paddingW - w_out * strideW;
const int w_out_s = w_in + paddingW - kw; const int filter_offset = c_out * filterHeight * filterWidth
if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) { + filter_h * filterWidth + filter_w;
const int h_out = h_out_s / strideH; const int top_diff_offset = ((batch * outputChannels + c_out) *
const int w_out = w_out_s / strideW; outputHeight + h_out)* outputWidth + w_out;
// TODO(zhaolong) : the 'if' affect the effectiveness, value += top_diff[top_diff_offset] * weight_data[filter_offset];
// it needs to optimize
if ((h_out >= 0) && (h_out < outputHeight)
&& (w_out >= 0) && (w_out < outputWidth)) {
const int offset = ((batch * outputChannels + c_out)
* outputHeight + h_out) * outputWidth + w_out;
value += (*weight) * top_diff[offset];
}
}
++weight;
} }
} }
} }

Loading…
Cancel
Save