Add select_input_op and select_output_op (#21016)
These ops are useful in control flow.
custom_op_abi
parent
fc02c2995e
commit
1957192f05
@ -0,0 +1,72 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
class AssignFunctor {
|
||||
public:
|
||||
AssignFunctor(framework::Variable *out,
|
||||
const platform::DeviceContext &dev_ctx)
|
||||
: out_(out), dev_ctx_(dev_ctx) {}
|
||||
|
||||
void operator()(const framework::LoDTensor &lod_tensor) const {
|
||||
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
|
||||
copy_tensor(lod_tensor, &out_tensor);
|
||||
}
|
||||
|
||||
void operator()(const framework::LoDTensorArray &array) const {
|
||||
auto &out_array = *out_->GetMutable<framework::LoDTensorArray>();
|
||||
out_array.resize(array.size());
|
||||
for (size_t i = 0; i < array.size(); ++i) {
|
||||
copy_tensor(array[i], &out_array[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const framework::SelectedRows &rows) const {
|
||||
framework::SelectedRows &out_rows =
|
||||
*out_->GetMutable<framework::SelectedRows>();
|
||||
out_rows.set_rows(rows.rows());
|
||||
out_rows.set_height(rows.height());
|
||||
auto &t = rows.value();
|
||||
auto *m = out_rows.mutable_value();
|
||||
framework::TensorCopy(t, t.place(), dev_ctx_, m);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void operator()(const T &v) const {
|
||||
PADDLE_THROW("Not support type for assign op %s", typeid(T).name());
|
||||
}
|
||||
|
||||
private:
|
||||
void copy_tensor(const framework::LoDTensor &lod_tensor,
|
||||
framework::LoDTensor *out) const {
|
||||
if (lod_tensor.numel() == 0) return;
|
||||
auto &out_tensor = *out;
|
||||
TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
|
||||
out_tensor.set_lod(lod_tensor.lod());
|
||||
}
|
||||
|
||||
framework::Variable *out_;
|
||||
const platform::DeviceContext &dev_ctx_;
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,118 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
#include "paddle/fluid/operators/assign_op.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/variable.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
// Verifies that AssignFunctor copies the dims and the integer data of a
// CPU LoDTensor into the output variable.
TEST(AssignOp, AssignLoDTensor) {
  paddle::platform::CPUPlace cpu_place;
  paddle::platform::CPUDeviceContext ctx(cpu_place);

  paddle::framework::Variable output;
  paddle::operators::AssignFunctor assign_functor(&output, ctx);

  // Build a 3x4 int tensor holding 0..11.
  paddle::framework::LoDTensor input;
  paddle::framework::DDim in_dims = paddle::framework::make_ddim({3, 4});
  int* in_data = input.mutable_data<int>(in_dims, cpu_place);
  for (int i = 0; i < 12; ++i) {
    in_data[i] = i;
  }

  assign_functor(input);

  auto& out_tensor = output.Get<paddle::framework::LoDTensor>();
  paddle::framework::DDim out_dims = out_tensor.dims();
  // Fatal check: the loop below reads 12 elements, which is only valid when
  // the output really has the input's shape.
  ASSERT_EQ(in_dims, out_dims);
  auto* out_data = out_tensor.data<int>();
  for (int i = 0; i < 12; ++i) {
    EXPECT_EQ(i, out_data[i]);
  }
}
|
||||
|
||||
// Verifies that AssignFunctor copies every tensor of a LoDTensorArray
// (dims and float data) into the output variable.
TEST(AssignOp, AssignLoDTensorArray) {
  paddle::platform::CPUPlace cpu_place;
  paddle::platform::CPUDeviceContext ctx(cpu_place);

  paddle::framework::Variable output;
  paddle::operators::AssignFunctor assign_functor(&output, ctx);

  // Element i is an (i+1) x (i+2) float tensor holding 0, 1, 2, ...
  paddle::framework::LoDTensorArray input;
  for (int i = 0; i < 5; ++i) {
    paddle::framework::DDim in_dims =
        paddle::framework::make_ddim({i + 1, i + 2});
    paddle::framework::LoDTensor lod_tensor;
    float* in_data = lod_tensor.mutable_data<float>(in_dims, cpu_place);
    for (int j = 0; j < (i + 1) * (i + 2); ++j) {
      in_data[j] = static_cast<float>(j);
    }
    input.push_back(lod_tensor);
  }

  assign_functor(input);

  auto& out_array = output.Get<paddle::framework::LoDTensorArray>();
  // Fatal check: the loop below indexes out_array[0..5), so stop right here
  // if the output does not have as many elements as the input.
  ASSERT_EQ(input.size(), out_array.size());
  for (int i = 0; i < 5; ++i) {
    paddle::framework::DDim out_dims = out_array[i].dims();
    // Fatal check: the element count read below is derived from this shape.
    ASSERT_EQ(paddle::framework::make_ddim({i + 1, i + 2}), out_dims);
    const float* out_data = out_array[i].data<float>();
    for (int j = 0; j < (i + 1) * (i + 2); ++j) {
      EXPECT_EQ(static_cast<float>(j), out_data[j]);
    }
  }
}
|
||||
|
||||
// Verifies that AssignFunctor copies the rows, the height and the value
// tensor of a SelectedRows into the output variable.
TEST(AssignOp, AssignSelectedRows) {
  paddle::platform::CPUPlace cpu_place;
  paddle::platform::CPUDeviceContext ctx(cpu_place);

  paddle::framework::Variable output;
  paddle::operators::AssignFunctor assign_functor(&output, ctx);

  std::vector<int64_t> rows{0, 4, 7};
  int64_t height = 10;

  paddle::framework::SelectedRows input(rows, height);
  paddle::framework::Tensor* input_tensor = input.mutable_value();

  // The value tensor is 3x4 ints holding 0..11.
  paddle::framework::DDim in_dims = paddle::framework::make_ddim({3, 4});
  int* in_data = input_tensor->mutable_data<int>(in_dims, cpu_place);
  for (int i = 0; i < 12; ++i) {
    in_data[i] = i;
  }

  assign_functor(input);

  auto& out_selected_row = output.Get<paddle::framework::SelectedRows>();
  const paddle::framework::Vector<int64_t>& out_rows = out_selected_row.rows();
  // Fatal check: the loop below indexes out_rows up to rows.size().
  ASSERT_EQ(rows.size(), out_rows.size());
  for (size_t i = 0; i < rows.size(); ++i) {
    EXPECT_EQ(rows[i], out_rows[i]);
  }
  EXPECT_EQ(height, out_selected_row.height());

  const paddle::framework::Tensor& out_tensor = out_selected_row.value();
  paddle::framework::DDim out_dims = out_tensor.dims();
  // Fatal check: the loop below reads 12 elements from the output tensor.
  ASSERT_EQ(in_dims, out_dims);
  auto* out_data = out_tensor.data<int>();
  for (int i = 0; i < 12; ++i) {
    EXPECT_EQ(i, out_data[i]);
  }
}
|
@ -0,0 +1,114 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/memory/memcpy.h"
|
||||
#include "paddle/fluid/operators/assign_op.h"
|
||||
#include "paddle/fluid/operators/select_op_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// SelectInputOp takes multiple inputs and uses an integer mask to select
|
||||
// one input to output. It is used in control flow.
|
||||
class SelectInputOp : public framework::OperatorBase {
|
||||
public:
|
||||
SelectInputOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope &scope,
|
||||
const platform::Place &dev_place) const override {
|
||||
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
||||
auto &dev_ctx = *pool.Get(dev_place);
|
||||
|
||||
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
|
||||
size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
|
||||
|
||||
const std::vector<std::string> &x_names = Inputs("X");
|
||||
PADDLE_ENFORCE_LT(output_branch, x_names.size(),
|
||||
"Selected branch number is greater than actual branch "
|
||||
"num in SelectInputOp");
|
||||
|
||||
const framework::Variable *selected_x =
|
||||
scope.FindVar(x_names[output_branch]);
|
||||
framework::Variable *out = scope.FindVar(Output("Out"));
|
||||
framework::VisitVarType(*selected_x, AssignFunctor(out, dev_ctx));
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, output and documentation of the select_input op.
class SelectInputOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // X is duplicable: one entry per candidate branch.
    AddInput("X",
             "The input LoDTensors or LoDTensorArray or SelectedRows. All "
             "inputs must have same variable type")
        .AsDuplicable();
    AddInput("Mask",
             "An integer tensor with numel 1 specifying which input to output");
    AddOutput(
        "Out",
        "The merged output. The variable type of output must be same as X");
    // TODO(huihuangzheng): decide whether to add support for lod level
    // Because this op is blocking whole control flow. I am implementing MVP
    // (minimal viable product) here.
    AddComment(R"DOC(
Merge branches of LoDTensor into a single Output with a mask integer
specifying the output branch.
)DOC");
  }
};
|
||||
|
||||
// Shape inference for select_input. It only validates that the X inputs,
// the Mask input and the Out output are present; no output shape is set.
class SelectInputInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    // X is a multi-entry input, hence HasInputs rather than HasInput.
    PADDLE_ENFORCE_EQ(context->HasInputs("X"), true,
                      "SelectInputOp must have input X.");

    PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true,
                      "SelectInputOp must have input Mask.");

    PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
                      "SelectInputOp must have output Out.");
  }
};
|
||||
|
||||
template <typename T>
|
||||
class SelectInputGradMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<T> Apply() const override {
|
||||
auto *grad_op = new T();
|
||||
grad_op->SetType("select_output");
|
||||
grad_op->SetInput("X", this->OutputGrad("Out"));
|
||||
grad_op->SetInput("Mask", this->Input("Mask"));
|
||||
grad_op->SetOutput("Out",
|
||||
this->InputGrad("X", /* drop_empty_grad */ false));
|
||||
grad_op->SetAttrMap(this->Attrs());
|
||||
return std::unique_ptr<T>(grad_op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(select_input, ops::SelectInputOp,
|
||||
ops::SelectInputOpProtoMaker, ops::SelectInputInferShape,
|
||||
ops::SelectInputGradMaker<paddle::framework::OpDesc>,
|
||||
ops::SelectInputGradMaker<paddle::imperative::OpBase>);
|
@ -0,0 +1,47 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
// Functions used in SelectInputOp and SelectOutputOp
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Returns the integer in mask whose numel must be 1. The integer means the
|
||||
// selected branch number.
|
||||
inline int GetBranchNumber(const framework::LoDTensor &mask) {
|
||||
PADDLE_ENFORCE_EQ(mask.numel(), 1,
|
||||
"Mask in SelectOutputOp must have numel 1.");
|
||||
if (platform::is_cpu_place(mask.place())) {
|
||||
return mask.data<int>()[0];
|
||||
}
|
||||
// when platform::is_gpu_place(mask.place()) is ture
|
||||
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
|
||||
#else
|
||||
PADDLE_THROW(
|
||||
"This version of PaddlePaddle doen NOT support GPU but got GPU tensor "
|
||||
"Mask in SelectOutputOp. Please compile WITH_GPU option");
|
||||
#endif
|
||||
return cpu_mask->data<int>()[0];
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,110 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/memory/memcpy.h"
|
||||
#include "paddle/fluid/operators/assign_op.h"
|
||||
#include "paddle/fluid/operators/select_op_helper.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// SelectOutputOp has one input, one integer mask and multiple outputs. It
|
||||
// selects one output specified by the mask and copy the input to it.
|
||||
class SelectOutputOp : public framework::OperatorBase {
|
||||
public:
|
||||
SelectOutputOp(const std::string &type,
|
||||
const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope &scope,
|
||||
const platform::Place &dev_place) const override {
|
||||
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
||||
auto &dev_ctx = *pool.Get(dev_place);
|
||||
|
||||
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
|
||||
size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
|
||||
|
||||
const std::vector<std::string> &out_names = Outputs("Out");
|
||||
PADDLE_ENFORCE_LT(output_branch, out_names.size(),
|
||||
"Selected branch number is greater than actual branch "
|
||||
"num in SelectOutputOp");
|
||||
|
||||
const framework::Variable *x = scope.FindVar(Input("X"));
|
||||
framework::Variable *selected_out = scope.FindVar(out_names[output_branch]);
|
||||
framework::VisitVarType(*x, AssignFunctor(selected_out, dev_ctx));
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, outputs and documentation of the select_output op.
class SelectOutputOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Slot descriptions, kept apart from the registration calls below.
    const char *x_doc =
        "The input LoDTensor or LoDTensorArray or SelectedRows.";
    const char *mask_doc =
        "Tensor with numel 1 specifying which branch to output";
    const char *out_doc =
        "The output can contains multiple variables. "
        "The output of selected branch will be same as input. "
        "We do nothing for variables in other branch";

    AddInput("X", x_doc);
    AddInput("Mask", mask_doc);
    // Out is duplicable: one entry per branch.
    AddOutput("Out", out_doc).AsDuplicable();

    // TODO(huihuangzheng): decide whether to add support for lod level
    // Because this op is blocking whole control flow. I am implementing MVP
    // (minimal viable product) here.
    AddComment(R"DOC(
Split input variable into one output branch. The mask is an integer tensor to
specify which output branch should copy the input.
)DOC");
  }
};
|
||||
|
||||
// Shape inference for select_output. It only validates that the X input,
// the Mask input and the Out outputs are present; no output shape is set.
class SelectOutputInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
                      "SelectOutputOp must have input X.");

    PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true,
                      "SelectOutputOp must have input Mask.");

    // Out is a multi-entry output, hence HasOutputs rather than HasOutput.
    PADDLE_ENFORCE_EQ(context->HasOutputs("Out"), true,
                      "SelectOutputOp must have output Out.");
  }
};
|
||||
|
||||
template <typename T>
|
||||
class SelectOutputGradMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<T> Apply() const override {
|
||||
auto *grad_op = new T();
|
||||
grad_op->SetType("select_input");
|
||||
grad_op->SetInput("Mask", this->Input("Mask"));
|
||||
grad_op->SetInput("X", this->OutputGrad("Out"));
|
||||
grad_op->SetOutput("Out", this->InputGrad("X"));
|
||||
grad_op->SetAttrMap(this->Attrs());
|
||||
return std::unique_ptr<T>(grad_op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(select_output, ops::SelectOutputOp,
|
||||
ops::SelectOutputOpProtoMaker, ops::SelectOutputInferShape,
|
||||
ops::SelectOutputGradMaker<paddle::framework::OpDesc>,
|
||||
ops::SelectOutputGradMaker<paddle::imperative::OpBase>);
|
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
import paddle.fluid.layers as layers
|
||||
from paddle.fluid.backward import append_backward
|
||||
from paddle.fluid.executor import Executor
|
||||
from paddle.fluid.framework import Program, program_guard
|
||||
from paddle.fluid.layers.control_flow import select_input, select_output
|
||||
|
||||
|
||||
class TestSplitMergeSelectedVarOps(unittest.TestCase):
    """Round-trip test for the select_output / select_input pair."""

    def test_forward_backward(self):
        """Route x through each branch in turn and check y and x's gradient.

        For every branch number the program sends ``x`` into one of
        ``branch_num`` variables with ``select_output`` and reads it back
        with ``select_input`` using the same mask. The fetched ``y`` must
        equal the fed ``x``, and the gradient of ``mean(y)`` w.r.t. ``x``
        must be 0.5 for each of the two elements.
        """
        branch_num = 9
        program = Program()
        with program_guard(program):
            x = layers.data(name='x', shape=[2], dtype='float32')
            x.stop_gradient = False  # x's gradient is fetched below
            mask = layers.data(name='mask', shape=[1], dtype='int32')

            # One LoDTensor variable per branch.
            outputs = [
                program.current_block().create_var(
                    dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR)
                for _ in range(branch_num)
            ]

            select_output(x, outputs, mask)
            y = select_input(outputs, mask)
            mean = layers.mean(y)
            append_backward(mean)

        # Run on the GPU when this build has CUDA, otherwise on the CPU.
        if core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = Executor(place)

        feed_x = np.asarray([1.3, -1.4]).astype(np.float32)
        expected_x_grad = np.asarray([0.5, 0.5]).astype(np.float32)
        for branch in range(branch_num):
            feed_mask = np.asarray([branch]).astype(np.int32)
            fetched = exe.run(
                program,
                feed={'x': feed_x,
                      'mask': feed_mask},
                fetch_list=[y.name, x.grad_name])
            self.assertTrue(np.allclose(np.asarray(fetched[0]), feed_x))
            self.assertTrue(
                np.allclose(np.asarray(fetched[1]), expected_x_grad))
|
||||
|
||||
|
||||
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue