Add batch_fc op in contrib (#24017)
* add batch fc op, test=develop
* add batch_fc_op, test=develop
* fix unittest, test=develop
* rm check_dygraph, test=develop
* fix comment, test=develop
* fix comment, test=develop
revert-22778-infer_var_type
parent f5c08c3f4d, commit 0fb9b208ab
@@ -0,0 +1,155 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/batch_fc_op.h"
#include <string>

namespace paddle {
namespace operators {

class BatchFCOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Input"), true,
        platform::errors::InvalidArgument(
            "Input(Input) of BatchFCOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of BatchFCOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("W"), true,
        platform::errors::InvalidArgument(
            "Input(W) of BatchFCOp should not be null."));

    auto input_dims = ctx->GetInputDim("Input");
    auto w_dims = ctx->GetInputDim("W");

    PADDLE_ENFORCE_EQ(input_dims.size(), 3,
                      platform::errors::InvalidArgument(
                          "Input of BatchFCOp should be a 3-D tensor."));
    PADDLE_ENFORCE_EQ(w_dims.size(), 3,
                      platform::errors::InvalidArgument(
                          "W of BatchFCOp should be a 3-D tensor."));
    PADDLE_ENFORCE_EQ(
        input_dims[0], w_dims[0],
        platform::errors::InvalidArgument(
            "Input.dim[0] and W.dim[0] of BatchFCOp should be the same."));
    PADDLE_ENFORCE_EQ(
        input_dims[2], w_dims[1],
        platform::errors::InvalidArgument(
            "Input.dim[2] and W.dim[1] of BatchFCOp should be the same."));

    auto bias_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(bias_dims[0], input_dims[0],
                      platform::errors::InvalidArgument(
                          "Bias.dim[0] should be the same as Input.dim[0]."));
    PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[2],
                      platform::errors::InvalidArgument(
                          "Bias.dim[1] should be the same as W.dim[2]."));

    ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], w_dims[2]});
    ctx->ShareLoD("Input", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
};

class BatchFCGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Input"), true,
        platform::errors::InvalidArgument("Input should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("W"), true,
        platform::errors::InvalidArgument("Input(W) should not be null."));

    ctx->SetOutputDim(framework::GradVarName("Input"),
                      ctx->GetInputDim("Input"));
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
    ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "(Tensor) Input tensor of batch_fc_op operator.");
    AddInput("W", "(Tensor) Weight tensor of batch_fc_op operator.");
    AddInput("Bias", "(Tensor) Bias tensor of batch_fc_op operator.");
    AddOutput("Out", "Output tensor of batch_fc_op operator.");
    AddComment(R"DOC(
BatchFC Operator.
Notice: It currently supports GPU device only.
This Op exists in contrib, which means that it is not exposed in the public API.
)DOC");
  }
};

template <typename T>
class BatchFCGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("batch_fc_grad");

    op->SetInput("Input", this->Input("Input"));
    op->SetInput("W", this->Input("W"));
    op->SetInput("Bias", this->Input("Bias"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    op->SetAttrMap(this->Attrs());
  }
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInference,
                                    "Bias");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker,
                  ops::BatchFCGradOpMaker<paddle::framework::OpDesc>,
                  ops::BatchFCGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp,
                  ops::BatchFCGradOpNoNeedBufferVarsInference);

REGISTER_OP_CPU_KERNEL(
    batch_fc, ops::BatchFCKernel<paddle::platform::CPUDeviceContext, float>,
    ops::BatchFCKernel<paddle::platform::CPUDeviceContext, double>);
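The diff above only registers the op at the C++ level; since it lives in contrib, no public Python API is added here. As a rough illustration of how a contrib layer could wire this operator from Python, the sketch below uses the standard LayerHelper pattern; the function name batch_fc_layer and its argument names are assumptions for illustration and are not part of this PR.

# Hypothetical wrapper sketch only -- this PR does not add a Python API.
from paddle.fluid.layer_helper import LayerHelper


def batch_fc_layer(input, w, bias):
    # input: [slot_pairs_num, ins_num, in_dim]
    # w:     [slot_pairs_num, in_dim, out_dim]
    # bias:  [slot_pairs_num, out_dim]
    helper = LayerHelper("batch_fc", **locals())
    out = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type="batch_fc",
        inputs={"Input": input, "W": w, "Bias": bias},
        outputs={"Out": out})
    return out  # [slot_pairs_num, ins_num, out_dim]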
@@ -0,0 +1,198 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cublas.h>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/batch_fc_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {
using framework::Tensor;

#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
static inline int GET_BLOCKS(const int N) {
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

// Adds the per-slot bias to data, which is laid out as
// [slot_pairs_num, ins_num, out_dim].
template <typename T>
__global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
                                int out_dim, const T* bias) {
  CUDA_KERNEL_LOOP(idx, slot_pairs_num * ins_num * out_dim) {
    int block_len = ins_num * out_dim;
    int slot_index = idx / block_len;
    int out_dim_index = (idx % block_len) % out_dim;
    T temp = data[idx] + bias[slot_index * out_dim + out_dim_index];
    data[idx] = temp;
  }
}

template <typename T>
void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num,
              int out_dim, const T* bias) {
  add_bias_kernel<<<GET_BLOCKS(slot_pairs_num * ins_num * out_dim),
                    CUDA_NUM_THREADS, 0, stream>>>(data, slot_pairs_num,
                                                   ins_num, out_dim, bias);
}

// Bias gradient: db[slot, col] = sum over ins_num of dout[slot, ins, col].
template <typename T>
__global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
                                     int ins_num, int out_dim, T* db_data) {
  CUDA_KERNEL_LOOP(idx, slot_pairs_num * out_dim) {
    int row = idx / out_dim;
    int col = idx % out_dim;
    T temp = static_cast<T>(0);
    for (int i = 0; i < ins_num; ++i) {
      // dout is laid out as [slot_pairs_num, ins_num, out_dim].
      int select_indx = (row * ins_num + i) * out_dim + col;
      temp += dout_data[select_indx];
    }
    db_data[idx] += temp;
  }
}

template <typename T>
void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num,
                   int ins_num, int out_dim, T* db_data) {
  add_bias_grad_kernel<<<GET_BLOCKS(slot_pairs_num * out_dim), CUDA_NUM_THREADS,
                         0, stream>>>(dout_data, slot_pairs_num, ins_num,
                                      out_dim, db_data);
}

template <typename DeviceContext, typename T>
class BatchFCCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Input.dims = [slot_pairs_num, ins_num, in_dim]
    // W.dims     = [slot_pairs_num, in_dim, out_dim]
    // Bias.dims  = [slot_pairs_num, out_dim]
    // Out.dims   = [slot_pairs_num, ins_num, out_dim]
    auto* input = ctx.Input<framework::LoDTensor>("Input");
    auto* w = ctx.Input<Tensor>("W");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* output = ctx.Output<framework::LoDTensor>("Out");
    auto input_dims = input->dims();
    auto w_dims = w->dims();
    auto slot_pairs_num = input_dims[0];
    auto ins_num = input_dims[1];
    auto in_dim = input_dims[2];
    auto out_dim = w_dims[2];

    // get data ptr
    const T* in_data = input->data<T>();
    const T* w_data = w->data<T>();
    const T* bias_data = bias->data<T>();

    output->Resize({slot_pairs_num, ins_num, out_dim});
    T* out_data = output->mutable_data<T>(ctx.GetPlace());
    // initialize the output to zero
    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();
    out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));

    CBLAS_TRANSPOSE transA = CblasNoTrans;
    CBLAS_TRANSPOSE transB = CblasNoTrans;

    T alpha = 1;
    T beta = 0;
    int64_t strideA = ins_num * in_dim;
    int64_t strideB = in_dim * out_dim;

    // Out[slot] = Input[slot] * W[slot], batched over slot_pairs_num.
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    blas.BatchedGEMM(transA, transB, ins_num, out_dim, in_dim, alpha, in_data,
                     w_data, beta, out_data, slot_pairs_num, strideA, strideB);
    // Out[slot] += Bias[slot], broadcast over ins_num.
    add_bias<T>(ctx.cuda_device_context().stream(), out_data, slot_pairs_num,
                ins_num, out_dim, bias_data);
  }
};

template <typename DeviceContext, typename T>
class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("Input");
    auto* w = ctx.Input<Tensor>("W");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto* dw = ctx.Output<Tensor>(framework::GradVarName("W"));
    auto* db = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto input_dims = input->dims();
    auto w_dims = w->dims();
    auto slot_pairs_num = input_dims[0];
    auto ins_num = input_dims[1];
    auto in_dim = input_dims[2];
    auto out_dim = w_dims[2];

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();
    // initialize the gradients to zero
    dx->mutable_data<T>(ctx.GetPlace());
    auto dx_eigen = framework::EigenVector<T>::Flatten(*dx);
    dx_eigen.device(place) = dx_eigen.constant(static_cast<T>(0));

    dw->mutable_data<T>(ctx.GetPlace());
    auto dw_eigen = framework::EigenVector<T>::Flatten(*dw);
    dw_eigen.device(place) = dw_eigen.constant(static_cast<T>(0));

    // get data ptr
    const T* x_data = input->data<T>();
    const T* w_data = w->data<T>();
    const T* dout_data = dout->data<T>();
    T* dx_data = dx->data<T>();
    T* dw_data = dw->data<T>();

    db->mutable_data<T>(ctx.GetPlace());
    auto db_eigen = framework::EigenVector<T>::Flatten(*db);
    db_eigen.device(place) = db_eigen.constant(static_cast<T>(0));
    T* db_data = db->data<T>();
    // dBias[slot] = sum of dOut[slot] over the ins_num axis.
    add_bias_grad<T>(ctx.cuda_device_context().stream(), dout_data,
                     slot_pairs_num, ins_num, out_dim, db_data);

    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    T alpha = 1;
    T beta = 0;

    // dInput[slot] = dOut[slot] * W[slot]^T
    blas.BatchedGEMM(CblasNoTrans, CblasTrans, ins_num, in_dim, out_dim, alpha,
                     dout_data, w_data, beta, dx_data, slot_pairs_num,
                     ins_num * out_dim, out_dim * in_dim);
    // dW[slot] = Input[slot]^T * dOut[slot]
    blas.BatchedGEMM(CblasTrans, CblasNoTrans, in_dim, out_dim, ins_num, alpha,
                     x_data, dout_data, beta, dw_data, slot_pairs_num,
                     in_dim * ins_num, ins_num * out_dim);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using GPUCtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(batch_fc, ops::BatchFCCUDAKernel<GPUCtx, float>,
                        ops::BatchFCCUDAKernel<GPUCtx, double>);

REGISTER_OP_CUDA_KERNEL(batch_fc_grad,
                        ops::BatchFCGradOpCUDAKernel<GPUCtx, float>,
                        ops::BatchFCGradOpCUDAKernel<GPUCtx, double>);
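For reference, the per-slot math implemented by the forward BatchedGEMM plus add_bias and by the two backward BatchedGEMM calls plus add_bias_grad can be written down directly. The NumPy sketch below is only a restatement of those formulas for cross-checking and is not part of the PR.

# NumPy restatement of the kernels above (illustrative sketch, not PR code).
import numpy as np


def batch_fc_forward_ref(x, w, b):
    # Out[s] = X[s] @ W[s] + broadcast(Bias[s])
    return np.matmul(x, w) + b[:, None, :]


def batch_fc_backward_ref(x, w, dout):
    dx = np.matmul(dout, np.transpose(w, (0, 2, 1)))  # dOut[s] @ W[s]^T
    dw = np.matmul(np.transpose(x, (0, 2, 1)), dout)  # X[s]^T @ dOut[s]
    db = dout.sum(axis=1)                             # reduce over ins_num
    return dx, dw, db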
@@ -0,0 +1,32 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class BatchFCKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::Unimplemented("BatchFC only supports GPU now."));
  }
};
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,106 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core


def np_cal_batchfc(input, w, bias):
    # Reference implementation: per-slot matmul followed by a bias add.
    slot_pairs_num, batch_size, in_dim = input.shape
    _, _, out_dim = w.shape
    res = np.zeros((slot_pairs_num, batch_size, out_dim))
    for slot in range(slot_pairs_num):
        res[slot, :] = np.dot(input[slot, :], w[slot, :])
    for slot in range(slot_pairs_num):
        for bindx in range(out_dim):
            res[slot, :, bindx] += bias[slot, bindx]
    return res
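
# A vectorized equivalent of np_cal_batchfc, added here only as a cross-check
# sketch (not part of the original test): one batched matmul plus a broadcast
# bias add gives the same result as the loops above.
def np_cal_batchfc_vectorized(input, w, bias):
    return np.matmul(input, w) + bias[:, np.newaxis, :]
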
class TestBatchFCOp(OpTest):
    def config(self):
        self.slot_pairs_num = 10
        self.batch_size = 5
        self.in_dim = 10
        self.out_dim = 12
        self.dtype = "float64"

    def setUp(self):
        self.config()
        self.input = np.random.random((self.slot_pairs_num, self.batch_size,
                                       self.in_dim)).astype(self.dtype)
        self.w = np.random.random((self.slot_pairs_num, self.in_dim,
                                   self.out_dim)).astype(self.dtype)
        self.bias = np.random.random((self.slot_pairs_num,
                                      self.out_dim)).astype(self.dtype)
        self.op_type = "batch_fc"
        np_out = np_cal_batchfc(self.input, self.w, self.bias)
        np_out = np_out.astype(self.dtype)
        self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias}
        self.outputs = {"Out": np_out}

    def test_check_output_gpu(self):
        if core.is_compiled_with_cuda():
            self.check_output_with_place(core.CUDAPlace(0))

    def test_check_grad_gpu(self):
        if core.is_compiled_with_cuda():
            self.check_grad_with_place(
                core.CUDAPlace(0), ["Bias", "W", "Input"], "Out")


class TestBatchFCOp1(OpTest):
    def config(self):
        self.slot_pairs_num = 10
        self.batch_size = 5
        self.in_dim = 10
        self.out_dim = 12
        self.dtype = "float64"

    def setUp(self):
        self.config()
        self.input = np.random.random((self.slot_pairs_num, self.batch_size,
                                       self.in_dim)).astype(self.dtype)
        self.w = np.random.random((self.slot_pairs_num, self.in_dim,
                                   self.out_dim)).astype(self.dtype)
        self.bias = np.random.random((self.slot_pairs_num,
                                      self.out_dim)).astype(self.dtype)
        self.op_type = "batch_fc"
        np_out = np_cal_batchfc(self.input, self.w, self.bias)
        np_out = np_out.astype(self.dtype)
        self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias}
        self.outputs = {"Out": np_out}

    def test_check_output_cpu(self):
        # The CPU kernel is a stub, so this is expected to fail; skip in that case.
        try:
            self.check_output_with_place(place=core.CPUPlace())
        except Exception:
            print("do not support cpu test, skip")

    def test_check_grad_cpu(self):
        try:
            self.check_grad_with_place(core.CPUPlace(), ["Bias", "W", "Input"],
                                       "Out")
        except Exception:
            print("do not support cpu test, skip")


if __name__ == "__main__":
    unittest.main()