|
|
|
|
@ -119,7 +119,13 @@ REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
|
|
|
|
|
REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
multiplex,
|
|
|
|
|
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
multiplex_grad,
|
|
|
|
|
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
|
|