|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/fc_op.h"
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
auto w_dims = ctx->GetInputDim("W");
|
|
|
|
|
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 2,
|
|
|
|
|
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
|
|
|
|
|
"Fully Connected input should be 2-D or 4-D tensor.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(w_dims.size() == 2,
|
|
|
|
|
"Fully Connected input should be 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
|
|
|
|
|
"Fully Connected input should be 2-D or 4-D tensor.");
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
|
|
|
|
|
ctx->ShareLoD("Input", "Out");
|
|
|
|
@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
|
|
|
|
|
|
|
|
|
|
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput(
|
|
|
|
|
"Input",
|
|
|
|
|
"(Tensor) The input tensor of fully connected operator. "
|
|
|
|
|
"The format of input tensor is NCHW, where N is batch size, C is the "
|
|
|
|
|
"number of channels, H is the height of the feature, "
|
|
|
|
|
"and W is the width of the feature.");
|
|
|
|
|
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
|
|
|
|
|
AddInput("W", "(Tensor), The second input tensor of fc op.");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor) The output tensor of fully connected operator. "
|
|
|
|
|
"The format of output tensor is also NCHW, "
|
|
|
|
|
"where N is batch size, C is the number of channels, "
|
|
|
|
|
"H is the height of the feature, "
|
|
|
|
|
"and W is the width of the feature.");
|
|
|
|
|
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
|
|
|
|
|
AddAttr<bool>("use_mkldnn",
|
|
|
|
|
"(bool, default false) Only used in mkldnn kernel")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|