commit fd559b3a7e

paddle/operators/cond_op.cc
@@ -0,0 +1,218 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cond_op.h"

#include <cstring>
#include <sstream>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/scatter.h"

namespace paddle {
namespace operators {

using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

void CondOp::CreateScope(const Scope& scope) const {
  auto sub_scopes_var = scope.FindVar("SubScopes");
  PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
                          "Output(SubScopes) of CondOp should not be null.");
  auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
  auto& sub_scope = scope.NewScope();
  sub_scopes->push_back(&sub_scope);
}

void CondOp::CreateIndexTensor(const Scope& scope) const {
  auto index_tensors_var = scope.FindVar("IndexTensors");
  PADDLE_ENFORCE_NOT_NULL(
      index_tensors_var, "Output(IndexTensors) of CondOp should not be null.");
  auto& index_tensors =
      *index_tensors_var->GetMutable<std::vector<LoDTensor>>();
  index_tensors.push_back(LoDTensor());
}

void CondOp::InferShape(const Scope& scope) const {
  auto sub_scopes_var = scope.FindVar("SubScopes");
  PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
                          "Output(SubScopes) of CondOp should not be null.");
  auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>();

  PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs(Xs) of CondOp can't be empty.");

  for (int i = 0; i < 2; ++i) {
    // Create two sub-scopes: sub_scopes[0] for the true branch and
    // sub_scopes[1] for the false branch.
    CreateScope(scope);

    // Create two index tensors: index_tensors[0] for the true branch and
    // index_tensors[1] for the false branch.
    CreateIndexTensor(scope);

    for (auto& input : Inputs("Xs")) {
      // Create a new tensor in the sub-scope for each input tensor.
      Variable* v = sub_scopes[i]->NewVar(input);
      LoDTensor* sub_input = v->GetMutable<LoDTensor>();
      sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims());
    }

    for (auto& output : (*sub_net_op_[i]).Outputs()) {
      for (auto& var_name : output.second) {
        sub_scopes[i]->NewVar(var_name);
      }
    }

    // Let each subnet infer the shapes of its own variables.
    sub_net_op_[i]->InferShape(*sub_scopes[i]);
  }

  for (auto& output : Outputs("Outs")) {
    LoDTensor* tensor_t_out =
        sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>();
    PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
    LoDTensor* tensor_f_out =
        sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>();
    PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");

    auto* tensor_out_var = scope.FindVar(output);
    PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found");
    LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>();

    // The two branches must produce outputs of the same shape.
    PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(),
                      "Outputs not of the same shape");
    tensor_out->Resize(tensor_t_out->dims());
    tensor_out->mutable_data<float>(platform::CPUPlace());
  }
}

void CondOp::Run(const Scope& scope,
                 const platform::DeviceContext& dev_ctx) const {
  auto* sub_scopes_var = scope.FindVar("SubScopes");
  auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
  auto* index_tensors_var = scope.FindVar("IndexTensors");
  auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();

  std::string cond_name = Input("Cond");
  Variable* cond_var = scope.FindVar(cond_name);
  PADDLE_ENFORCE_NOT_NULL(cond_var);
  const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();

  // Step 1: get the true/false indices at runtime.
  // index_[0]: vector<int>, contains all i with cond[i] == true
  // index_[1]: vector<int>, contains all i with cond[i] == false
  for (int i = 0; i < 2; ++i) index_[i].clear();

  const int* cond_data = cond->data<int>();
  for (int i = 0; i < cond->dims()[0]; ++i) {
    if (cond_data[i])
      index_[0].push_back(i);
    else
      index_[1].push_back(i);
  }

  // Put index_[0] and index_[1] into two tensors:
  // index_tensors[0] and index_tensors[1].
  DDim dim = paddle::framework::make_ddim({0});
  for (int i = 0; i < 2; ++i) {
    dim[0] = index_[i].size();
    int* tmp_ptr =
        index_tensors[i].mutable_data<int>(dim, platform::CPUPlace());
    index_tensors[i].Resize(dim);
    memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int));
  }

  // Step 2: collect each branch's input rows by calling gather.
  for (int i = 0; i < 2; ++i) {
    // i == 0 is the true branch, i == 1 the false branch.
    for (auto& input : Inputs("Xs")) {
      // Find the parent tensor.
      Variable* v = scope.FindVar(input);
      PADDLE_ENFORCE_NOT_NULL(v);
      LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();

      v = sub_scopes[i]->FindVar(input);
      PADDLE_ENFORCE_NOT_NULL(v);
      LoDTensor* tensor_child = v->GetMutable<LoDTensor>();

      // Resize the child tensor to hold only this branch's rows.
      DDim dim = tensor_child->dims();
      dim[0] = index_[i].size();
      tensor_child->Resize(dim);
      tensor_child->mutable_data<float>(dim, platform::CPUPlace());

      Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
                    tensor_child);
    }
  }

  // Step 3: run each subnet on its own sub-scope.
  for (int i = 0; i < 2; ++i) {
    sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
  }

  // Step 4: merge the output results by scattering rows back.
  for (int i = 0; i < 2; ++i) {
    // i == 0 is the true branch, i == 1 the false branch.
    for (auto& output : Outputs("Outs")) {
      // Find the parent tensor.
      Variable* v = scope.FindVar(output);
      PADDLE_ENFORCE_NOT_NULL(v);
      LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();

      v = sub_scopes[i]->FindVar(output);
      PADDLE_ENFORCE_NOT_NULL(v);
      LoDTensor* tensor_child = v->GetMutable<LoDTensor>();

      ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
                           tensor_parent);
    }
  }
}
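
// Worked example (hypothetical values, for orientation only): with
// cond = [1, 0, 1, 0], Step 1 yields index_[0] = {0, 2} and
// index_[1] = {1, 3}. Step 2 gathers rows {0, 2} of each input into the
// true-branch sub-scope and rows {1, 3} into the false-branch sub-scope.
// After both subnets run (Step 3), Step 4 scatters each branch's output
// rows back to their original positions, so Outs interleaves the two
// subnet outputs by row index.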

class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CondOpProtoAndCheckerMaker(framework::OpProto* proto,
                             framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Cond", "The condition, which is a bool vector");
    AddInput("Xs", "Inputs of the subnets").AsDuplicable();
    AddOutput("Outs", "Outputs of CondOp after the merge").AsDuplicable();

    AddOutput("SubScopes", "Sub-scopes for the true and false branches");
    AddOutput("IndexTensors", "Index tensors holding the true/false indices");

    AddComment(R"DOC(
Sample-dependent Cond Operator.
Given Cond[i] as a 1/0 vector to indicate true/false, the equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_f[i], if Cond[i] == false
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp,
                             paddle::operators::CondOpProtoAndCheckerMaker);
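
For orientation, a minimal sketch of how CondOp is meant to be driven from C++, mirroring the Python test at the end of this commit. The NetOp construction and the input-tensor setup are elided; CPUDeviceContext and the brace-initialized VariableNameMaps are assumptions about this tree's framework API, not a verified program:

#include "paddle/framework/scope.h"
#include "paddle/operators/cond_op.h"
#include "paddle/platform/device_context.h"

// Hypothetical driver; true_net/false_net are NetOps wrapping e.g. a
// registered "scale" op, as in the Python test below.
void RunCondOp(std::unique_ptr<paddle::framework::OperatorBase> true_net,
               std::unique_ptr<paddle::framework::OperatorBase> false_net) {
  paddle::framework::Scope scope;
  scope.NewVar("SubScopes");
  scope.NewVar("IndexTensors");
  scope.NewVar("Out");
  // ... fill "cond" (int32, shape {10, 1}) and "X" (float, {10, 1}) here ...

  paddle::operators::CondOp cond_op(
      "cond", {{"Cond", {"cond"}}, {"Xs", {"X"}}},
      {{"Outs", {"Out"}},
       {"SubScopes", {"SubScopes"}},
       {"IndexTensors", {"IndexTensors"}}},
      {});
  cond_op.set_truenet(std::move(true_net));
  cond_op.set_falsenet(std::move(false_net));

  paddle::platform::CPUDeviceContext ctx;
  cond_op.InferShape(scope);  // must precede Run (see cond_op.h)
  cond_op.Run(scope, ctx);
}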

paddle/operators/cond_op.h
@@ -0,0 +1,91 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include "glog/logging.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/net_op.h"

namespace paddle {
namespace operators {

/*
 * @brief CondOp is a dynamic if-else operator.
 *
 * It has an input tensor named cond that indicates which subnet each
 * instance will run.
 *
 * If cond == 1, it runs true_net, which is a NetOp.
 *
 * If cond == 0, it runs false_net, which is another NetOp.
 */
class CondOp : public framework::OperatorBase {
 public:
  CondOp(const std::string& type, const framework::VariableNameMap& inputs,
         const framework::VariableNameMap& outputs,
         const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {
    index_.resize(2);
    sub_net_op_.resize(2);
  }

  CondOp(const CondOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    // TODO(yuyang18): Implement copy ctor well.
    PADDLE_THROW("Not implemented");
  }

  void CreateScope(const framework::Scope& scope) const;

  void CreateIndexTensor(const framework::Scope& scope) const;

  /*
   * InferShape must be called before Run.
   */
  void InferShape(const framework::Scope& scope) const override;

  /*
   * Set the subnet of the true branch.
   */
  void set_truenet(std::unique_ptr<OperatorBase>&& net) {
    sub_net_op_[0] = std::move(net);
  }

  /*
   * Set the subnet of the false branch.
   */
  void set_falsenet(std::unique_ptr<OperatorBase>&& net) {
    sub_net_op_[1] = std::move(net);
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override;

 private:
  // sub_net_op_[0]: subnet_t, run on the instances with cond == true
  // sub_net_op_[1]: subnet_f, run on the instances with cond == false
  std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_;

  // index_[0]: indices of the instances with cond == true
  // index_[1]: indices of the instances with cond == false
  mutable std::vector<std::vector<int>> index_;
};

}  // namespace operators
}  // namespace paddle

paddle/platform/details/device_ptr_cast.h
@@ -0,0 +1,56 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifndef __NVCC__
#error device_ptr_cast must be included by a .cu file
#endif

#include <thrust/device_ptr.h>

#include <type_traits>  // std::is_pointer, std::remove_pointer

namespace paddle {
namespace platform {
namespace details {
template <typename T, bool is_ptr>
struct DevicePtrCast;

// Pointer case: wrap the raw device pointer in a thrust::device_ptr.
template <typename T>
struct DevicePtrCast<T, true> {
  using ELEM = typename std::remove_pointer<T>::type;
  using RTYPE = thrust::device_ptr<ELEM>;

  inline thrust::device_ptr<ELEM> operator()(ELEM* ele) const {
    return thrust::device_pointer_cast(ele);
  }
};

// Non-pointer case: pass the value (e.g., an iterator) through unchanged.
template <typename T>
struct DevicePtrCast<T, false> {
  using RTYPE = T;
  inline RTYPE operator()(RTYPE it) const { return it; }
};

// Cast T to thrust::device_ptr if T is a pointer.
// Otherwise (e.g., T is an iterator), return T itself.
template <typename T>
auto DevPtrCast(T t) ->
    typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
  DevicePtrCast<T, std::is_pointer<T>::value> cast;
  return cast(t);
}

}  // namespace details
}  // namespace platform
}  // namespace paddle
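
A minimal usage sketch (assumes compilation under nvcc, since the header refuses to build otherwise; `raw` stands in for a real device allocation): DevPtrCast wraps a raw device pointer in a thrust::device_ptr so thrust algorithms dispatch to the device, while thrust iterators pass through unchanged.

#include <thrust/device_vector.h>
#include "paddle/platform/details/device_ptr_cast.h"

void DevPtrCastExample() {
  using paddle::platform::details::DevPtrCast;
  float* raw = nullptr;             // hypothetical pointer from cudaMalloc
  auto dp = DevPtrCast(raw);        // -> thrust::device_ptr<float>
  thrust::device_vector<float> v(4);
  auto it = DevPtrCast(v.begin());  // iterator case: returned unchanged
  (void)dp;
  (void)it;
}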

paddle/platform/transform.h
@@ -0,0 +1,66 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/platform/enforce.h"
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/place.h"

#include <algorithm>
#include <type_traits>
#ifdef __NVCC__
#include <thrust/transform.h>
#include "paddle/platform/details/device_ptr_cast.h"
#endif

namespace paddle {
namespace platform {
// Transform on host or device. It provides the same API as the std library.
template <typename Place, typename InputIter, typename OutputIter,
          typename UnaryOperation>
void Transform(Place place, InputIter first, InputIter last, OutputIter result,
               UnaryOperation op) {
  if (is_cpu_place(place)) {
    std::transform(first, last, result, op);
  } else {
#ifdef __NVCC__
    using namespace details;
    thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result),
                      op);
#else
    PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif
  }
}

template <typename Place, typename InputIter1, typename InputIter2,
          typename OutputIter, typename BinaryOperation>
void Transform(Place place, InputIter1 first1, InputIter1 last1,
               InputIter2 first2, OutputIter result, BinaryOperation op) {
  if (is_cpu_place(place)) {
    std::transform(first1, last1, first2, result, op);
  } else {
#ifdef __NVCC__
    using namespace details;
    thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2),
                      DevPtrCast(result), op);
#else
    PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif
  }
}

}  // namespace platform
}  // namespace paddle
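
Since dispatch is on the Place argument, host-only callers can use Transform exactly like std::transform; a minimal CPU-side sketch (the buffer and lambda are illustrative only — the test file below exercises both the CPU and GPU paths):

#include "paddle/platform/transform.h"

void TransformCpuExample() {
  float x[3] = {1.f, 2.f, 3.f};
  // Applies 2*v + 1 in place, as std::transform would on the CPU path.
  paddle::platform::Transform(paddle::platform::CPUPlace(), x, x + 3, x,
                              [](float v) { return 2.f * v + 1.f; });
}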

paddle/platform/transform_test.cu
@@ -0,0 +1,84 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/transform.h"

template <typename T>
class Scale {
 public:
  explicit Scale(const T& scale) : scale_(scale) {}

  HOSTDEVICE T operator()(const T& a) const { return a * scale_; }

 private:
  T scale_;
};

template <typename T>
class Multiply {
 public:
  HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};

TEST(Transform, CPUUnary) {
  using namespace paddle::platform;
  float buf[4] = {0.1, 0.2, 0.3, 0.4};
  Transform(CPUPlace(), buf, buf + 4, buf, Scale<float>(10));
  for (int i = 0; i < 4; ++i) {
    ASSERT_NEAR(buf[i], static_cast<float>(i + 1), 1e-5);
  }
}

TEST(Transform, GPUUnary) {
  using namespace paddle::platform;
  using namespace paddle::memory;
  GPUPlace gpu0(0);
  float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4};
  float* gpu_buf = static_cast<float*>(Alloc(gpu0, sizeof(float) * 4));
  Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf));
  Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
  Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf));
  Free(gpu0, gpu_buf);
  for (int i = 0; i < 4; ++i) {
    ASSERT_NEAR(cpu_buf[i], static_cast<float>(i + 1), 1e-5);
  }
}

TEST(Transform, CPUBinary) {
  using namespace paddle::platform;
  int buf[4] = {1, 2, 3, 4};
  Transform(CPUPlace(), buf, buf + 4, buf, buf, Multiply<int>());
  for (int i = 0; i < 4; ++i) {
    ASSERT_EQ((i + 1) * (i + 1), buf[i]);
  }
}

TEST(Transform, GPUBinary) {
  using namespace paddle::platform;
  using namespace paddle::memory;
  int buf[4] = {1, 2, 3, 4};
  GPUPlace gpu0(0);
  int* gpu_buf = static_cast<int*>(Alloc(gpu0, sizeof(buf)));
  Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf));
  Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
  Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf));
  Free(gpu0, gpu_buf);
  for (int i = 0; i < 4; ++i) {
    ASSERT_EQ((i + 1) * (i + 1), buf[i]);
  }
}

python/paddle/v2/framework/tests/test_cond_op.py
@@ -0,0 +1,116 @@
import paddle.v2.framework.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator, CondOp


class PySimpleCond(object):
    '''
    A simple numpy-based reference implementation of dynamic if-else.
    '''

    def __init__(self):
        array = [1] * 10
        for i in range(1, 10, 2):
            array[i] = 0
        self.cond = np.array(array)
        self.x = np.ones(shape=(10, 1))

    def forward(self):
        self.index_t = np.where(self.cond == 1)
        self.index_f = np.where(self.cond == 0)
        y_t = self.x[self.index_t]
        y_f = self.x[self.index_f]
        y_t = y_t * 2.
        y_f = y_f * (-2.)
        output = np.zeros(shape=(10, 1))
        output[self.index_t] = y_t
        output[self.index_f] = y_f
        return output


class PySimpleCondTest(unittest.TestCase):
    def setUp(self):
        self.condnn = PySimpleCond()

    def test_forward(self):
        output = self.condnn.forward()
        self.assertEqual(output.shape, (10, 1))


def create_tensor(scope, name, shape, np_data):
    tensor = scope.new_var(name).get_tensor()
    tensor.set_dims(shape)
    tensor.set(np_data, core.CPUPlace())
    return tensor


class TestCondOp(unittest.TestCase):
    '''
    Test CondOp

    equation:
        cond = [True, False, True, False, ...]
        y[index_t] = x[index_t] * 2.
        y[index_f] = x[index_f] * -2.
    outputs:
        y
    '''

    def setUp(self):
        self.py_cond = PySimpleCond()

    def forward(self):
        self.scope = core.Scope()
        self.create_global_variables()
        self.create_cond_op()
        self.create_sub_net()
        ctx = core.DeviceContext.create(core.CPUPlace())
        self.condop.infer_shape(self.scope)
        self.condop.run(self.scope, ctx)
        return np.array(self.scope.find_var("Out").get_tensor())

    def create_global_variables(self):
        x_np_data = self.py_cond.x
        create_tensor(self.scope, "X", [10, 1], x_np_data)
        cond_np_data = self.py_cond.cond.astype("int32")
        create_tensor(self.scope, "cond", [10, 1], cond_np_data)
        self.scope.new_var("SubScopes")
        self.scope.new_var("IndexTensors")
        self.scope.new_var("Out")

    def create_cond_op(self):
        self.condop = CondOp(
            Cond="cond",
            Xs=["X"],
            Outs=["Out"],
            SubScopes="SubScopes",
            IndexTensors="IndexTensors")

    def create_sub_net(self):
        truenet = core.Net.create()
        scale_op_t = Operator("scale", X='X', Out='Out', scale=2.)
        truenet.append_op(scale_op_t)
        truenet.complete_add_op(True)
        self.condop.set_truenet(truenet)

        falsenet = core.Net.create()
        scale_op_f = Operator("scale", X='X', Out='Out', scale=-2.)
        falsenet.append_op(scale_op_f)
        falsenet.complete_add_op(True)
        self.condop.set_falsenet(falsenet)

    def test_forward(self):
        pd_output = self.forward()
        py_output = self.py_cond.forward()
        # Only the shapes are compared here; the per-row values are computed
        # by the numpy reference in PySimpleCond.
        self.assertEqual(pd_output.shape, py_output.shape)


if __name__ == "__main__":
    unittest.main()