|
|
|
@ -14,43 +14,15 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/prune.h"
|
|
|
|
|
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/op_desc.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/framework/program_desc.h"
|
|
|
|
|
#include "paddle/operators/net_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
using DeviceContext = platform::DeviceContext;
|
|
|
|
|
|
|
|
|
|
class OneOneOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
OneOneOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("input", "input");
|
|
|
|
|
AddOutput("output", "output");
|
|
|
|
|
AddComment("Op has one input and one output");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class TwoOneOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
TwoOneOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("input_1", "input_1");
|
|
|
|
|
AddInput("input_2", "input_2");
|
|
|
|
|
AddOutput("output", "output");
|
|
|
|
|
AddComment("Op has two inputs and one output");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/op_desc.h"
|
|
|
|
|
#include "paddle/framework/program_desc.h"
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
|
|
|
|
|
namespace f = paddle::framework;
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
@ -61,7 +33,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
|
|
|
|
|
// insert output
|
|
|
|
|
for (auto kv : outputs) {
|
|
|
|
|
for (auto v : kv.second) {
|
|
|
|
|
auto var = block->NewVar(v);
|
|
|
|
|
auto var = block->Var(v);
|
|
|
|
|
var->SetDataType(paddle::framework::DataType::FP32);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|