|
|
|
@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/unsqueeze_op.h"
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -327,6 +329,7 @@ REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
@ -334,12 +337,14 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
unsqueeze_grad,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
@ -347,6 +352,7 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
unsqueeze2_grad,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
|