|
|
|
@ -32,7 +32,7 @@ namespace paddle {
|
|
|
|
|
* \param inputs[0] Sequence data of NST format.
|
|
|
|
|
* \param outputs[0] Image data of NCHW format.
|
|
|
|
|
*/
|
|
|
|
|
class ImageExpandFunction : public FunctionBase {
|
|
|
|
|
class BlockExpandFunction : public FunctionBase {
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {
|
|
|
|
|
// function arguments
|
|
|
|
@ -100,10 +100,10 @@ protected:
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <DeviceType Device>
|
|
|
|
|
class ImageExpandForward : public ImageExpandFunction {
|
|
|
|
|
class BlockExpandForward : public BlockExpandFunction {
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {
|
|
|
|
|
ImageExpandFunction::init(config);
|
|
|
|
|
BlockExpandFunction::init(config);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
@ -148,10 +148,10 @@ public:
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <DeviceType Device>
|
|
|
|
|
class ImageExpandBackward : public ImageExpandFunction {
|
|
|
|
|
class BlockExpandBackward : public BlockExpandFunction {
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {
|
|
|
|
|
ImageExpandFunction::init(config);
|
|
|
|
|
BlockExpandFunction::init(config);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
@ -192,11 +192,11 @@ public:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
REGISTER_TYPED_FUNC(ImageExpand, CPU, ImageExpandForward);
|
|
|
|
|
REGISTER_TYPED_FUNC(ImageExpandGrad, CPU, ImageExpandBackward);
|
|
|
|
|
REGISTER_TYPED_FUNC(BlockExpand, CPU, BlockExpandForward);
|
|
|
|
|
REGISTER_TYPED_FUNC(BlockExpandGrad, CPU, BlockExpandBackward);
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
REGISTER_TYPED_FUNC(ImageExpand, GPU, ImageExpandForward);
|
|
|
|
|
REGISTER_TYPED_FUNC(ImageExpandGrad, GPU, ImageExpandBackward);
|
|
|
|
|
REGISTER_TYPED_FUNC(BlockExpand, GPU, BlockExpandForward);
|
|
|
|
|
REGISTER_TYPED_FUNC(BlockExpandGrad, GPU, BlockExpandBackward);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|