Remove disable flag in test_fsp_op.py (#22171)

* fix fsp_op, test=develop

* fix fsp grad op maker, test=develop

* update op_use_default_grad_op_maker.spec, test=develop
Bai Yifan committed faba4b116a (parent 67e9247f4c) via GitHub
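Context for review (not part of the diff): the fsp op computes the FSP (Flow of Solution Procedure) matrix used in knowledge distillation. As a rough sketch, assuming inputs X of shape [batch, x_channel, h, w] and Y of shape [batch, y_channel, h, w] with the spatial dims flattened into one axis s of length h*w, each output entry averages a product of two feature maps:

    FSP(X, Y)_{b,i,j} = \frac{1}{hw} \sum_{s=1}^{hw} X_{b,i,s}\, Y_{b,j,s},
    \qquad \text{i.e.} \qquad \mathrm{Out}_b = \frac{1}{hw}\, X_b Y_b^{\top}.

This is the 1/(h * w) scaling visible in the kernel hunks below, and the np.mean over the spatial axis in the test's reference.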

op_use_default_grad_op_maker.spec
@@ -1,5 +1,4 @@
 cos_sim
-fsp
 gru
 match_matrix_tensor
 maxout

paddle/fluid/operators/fsp_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/fsp_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -114,14 +115,37 @@ class FSPOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
+class FSPGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType("fsp_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Y", this->Input("Y"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    fsp, ops::FSPOp, ops::FSPOpMaker,
-    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
-    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
+REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker,
+                  ops::FSPGradOpMaker<paddle::framework::OpDesc>,
+                  ops::FSPGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
 REGISTER_OP_CPU_KERNEL(
     fsp, ops::FSPOpKernel<paddle::platform::CPUDeviceContext, float>,

paddle/fluid/operators/fsp_op.h
@@ -46,6 +46,7 @@ class FSPOpKernel : public framework::OpKernel<T> {
     x_mat_desc.width_ = height * width;
     x_mat_desc.batch_size_ = batch_size;
     x_mat_desc.stride_ = x_channel * height * width;
+    x_mat_desc.trans_ = false;
 
     math::MatDescriptor y_mat_desc;
     y_mat_desc.height_ = height * width;
@@ -93,12 +94,14 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
     d_out_mat_desc.width_ = y_channel;
     d_out_mat_desc.batch_size_ = batch_size;
     d_out_mat_desc.stride_ = x_channel * y_channel;
+    d_out_mat_desc.trans_ = false;
 
     math::MatDescriptor y_mat_desc;
     y_mat_desc.height_ = y_channel;
     y_mat_desc.width_ = h * w;
     y_mat_desc.batch_size_ = batch_size;
     y_mat_desc.stride_ = y_channel * h * w;
+    y_mat_desc.trans_ = false;
 
     blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
                 static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
@@ -125,6 +128,7 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
     x_mat_desc.width_ = h * w;
     x_mat_desc.batch_size_ = batch_size;
     x_mat_desc.stride_ = x_channel * h * w;
+    x_mat_desc.trans_ = false;
 
     blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
                 static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
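
A sketch of the math behind these two MatMul calls, using the flattened view from the note above where \mathrm{Out}_b = \frac{1}{hw} X_b Y_b^{\top}: by standard matrix calculus,

    \frac{\partial L}{\partial X_b} = \frac{1}{hw}\,\frac{\partial L}{\partial \mathrm{Out}_b}\, Y_b,
    \qquad
    \frac{\partial L}{\partial Y_b} = \frac{1}{hw}\left(\frac{\partial L}{\partial \mathrm{Out}_b}\right)^{\top} X_b.

Both formulas need X, Y, and the output gradient, which is exactly the set of inputs FSPGradOpMaker forwards to fsp_grad. (For the d_y call, the d_out descriptor presumably carries the transpose; it is set outside the lines shown in this hunk.)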

test_fsp_op.py
@@ -34,7 +34,6 @@ def fsp_matrix(a, b):
     return np.mean(a_r * b_r, axis=1)
 
 
-@unittest.skip("Disable temporarily.")
 class TestFSPOp(OpTest):
     def setUp(self):
         self.op_type = "fsp"
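
For context, a self-contained NumPy sketch of the reference this test now exercises again. Only the return line of fsp_matrix is visible in the hunk, so the reshapes and the names a_r / b_r below are assumptions that reproduce the math sketched above, not necessarily the repository's exact code:

    import numpy as np

    def fsp_matrix_sketch(a, b):
        # a: (batch, a_channel, h, w), b: (batch, b_channel, h, w)
        batch, a_channel, h, w = a.shape
        _, b_channel, _, _ = b.shape
        # Move channels last and flatten the spatial dims, adding broadcast
        # axes so every (i, j) channel pair is multiplied at each position.
        a_r = a.transpose(0, 2, 3, 1).reshape(batch, h * w, a_channel, 1)
        b_r = b.transpose(0, 2, 3, 1).reshape(batch, h * w, 1, b_channel)
        # The mean over axis 1 (the h*w positions) supplies the 1/(h*w)
        # scaling and yields a (batch, a_channel, b_channel) FSP matrix.
        return np.mean(a_r * b_r, axis=1)

With the skip decorator removed, TestFSPOp once again compares the fsp kernels against a reference of this shape on every CI run.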
