|
|
|
@ -30,27 +30,26 @@ class InterpOp : public NetOp {
|
|
|
|
|
"Input(Y) of InterpOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName,
|
|
|
|
|
"Input(W) of InterpOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("MinusOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(MinusOut) of InterpOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(SubOut) of InterpOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(MulOut) of InterpOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
|
|
|
|
|
"Output(Out) of InterpOp should not be null.");
|
|
|
|
|
|
|
|
|
|
// MinusOut = X - Y
|
|
|
|
|
// SubOut = X - Y
|
|
|
|
|
auto x = Input("X");
|
|
|
|
|
auto y = Input("Y");
|
|
|
|
|
auto minus_out = Output("MinusOut");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp("elementwise_sub",
|
|
|
|
|
{{"X", {x}}, {"Y", {y}}},
|
|
|
|
|
{{"Out", {minus_out}}}, {}));
|
|
|
|
|
auto sub_out = Output("SubOut");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {}));
|
|
|
|
|
|
|
|
|
|
// MulOut = MinusOut * W = (X - Y) * W
|
|
|
|
|
// MulOut = SubOut * W = (X - Y) * W
|
|
|
|
|
auto w = Input("W");
|
|
|
|
|
auto mul_out = Output("MulOut");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"elementwise_mul", {{"X", {minus_out}}, {"Y", {w}}},
|
|
|
|
|
{{"Out", {mul_out}}}, {{"axis", 0}}));
|
|
|
|
|
"elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}},
|
|
|
|
|
{{"axis", 0}}));
|
|
|
|
|
|
|
|
|
|
// Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W)
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp("elementwise_add",
|
|
|
|
@ -65,18 +64,26 @@ class InterpOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "A 2-D Tensor, the first input of interp_op");
|
|
|
|
|
AddInput("Y", "A 2-D Tensor, the second input of interp_op");
|
|
|
|
|
AddInput("W", "A 1-D Tensor, the interpolated values");
|
|
|
|
|
AddOutput("MinusOut",
|
|
|
|
|
"A 2-D Tensor, the intermediate outputs, saving X - Y.")
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor), 2-D Matrix of shape [batch_size, data_dim]"
|
|
|
|
|
"containing data samples, the first input of interp_op");
|
|
|
|
|
AddInput("Y",
|
|
|
|
|
"(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`"
|
|
|
|
|
"containing data samples, the second input of interp_op");
|
|
|
|
|
AddInput("W",
|
|
|
|
|
"(Tensor), 1-D Vector of shape [batch_size],"
|
|
|
|
|
"the interpolated values in the half-open interval [0.0, 1.0)");
|
|
|
|
|
AddOutput("SubOut",
|
|
|
|
|
"(Tensor), the intermediate subtraction outputs, saving X - Y.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("MulOut",
|
|
|
|
|
"A 2-D Tensor, the intermediate outputs,"
|
|
|
|
|
"saving the mul mul of (X - Y) and W")
|
|
|
|
|
"(Tensor), the intermediate multiplication outputs,"
|
|
|
|
|
"saving the elementwise multiplication of (X - Y) and W.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"A 2-D Tensor, the output of interp_op, same shape with X");
|
|
|
|
|
"(Tensor), the output of interp_op, same shape with X,"
|
|
|
|
|
"returns the first-dimensional piecewise linear interpolant "
|
|
|
|
|
"between X and Y");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Linear Interpolation with two inputs, used in NEURAL TURING MACHINE.
|
|
|
|
|
|
|
|
|
|