|
|
|
@ -47,6 +47,12 @@ bool BlockExpandLayer::init(const LayerMap& layerMap,
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("blocks", blocks));
|
|
|
|
|
createFunction(backward_,
|
|
|
|
|
"ImageExpandGrad",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("blocks", blocks));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
@ -126,12 +132,12 @@ void BlockExpandLayer::forward(PassType passType) {
|
|
|
|
|
}
|
|
|
|
|
start[batchSize] = batchSize * blockNum;
|
|
|
|
|
if (!useGpu_) {
|
|
|
|
|
TensorShape inputShape({batchSize, channels_, imgSizeH_, imgSizeW_});
|
|
|
|
|
TensorShape outputShape({batchSize, blockNum, blockSize});
|
|
|
|
|
inputShape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
|
|
|
|
|
outputShape_ = TensorShape({batchSize, blockNum, blockSize});
|
|
|
|
|
BufferArgs inputs;
|
|
|
|
|
BufferArgs outputs;
|
|
|
|
|
inputs.addArg(*getInputValue(0), inputShape);
|
|
|
|
|
outputs.addArg(*getOutputValue(), outputShape, ASSIGN_TO);
|
|
|
|
|
inputs.addArg(*getInputValue(0), inputShape_);
|
|
|
|
|
outputs.addArg(*getOutputValue(), outputShape_, ASSIGN_TO);
|
|
|
|
|
forward_[0]->calc(inputs, outputs);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -144,6 +150,8 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) {
|
|
|
|
|
if (!preGrad) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (useGpu_) {
|
|
|
|
|
MatrixPtr grad = getOutputGrad();
|
|
|
|
|
MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_);
|
|
|
|
|
size_t batchSize = preGrad->getHeight();
|
|
|
|
@ -180,6 +188,13 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) {
|
|
|
|
|
1.0,
|
|
|
|
|
1.0);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
BufferArgs inputs;
|
|
|
|
|
BufferArgs outputs;
|
|
|
|
|
inputs.addArg(*getOutputGrad(), outputShape_);
|
|
|
|
|
outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
|
|
|
|
|
backward_[0]->calc(inputs, outputs);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|