Tree conv op (#15217)
* refactor tree2col operator with new memory mechanism test=develop
* test=develop
* test=develop
* Modified API according to panyx0718 test=develop
* fix API change according to heavengate test=develop
* Modify API comment test=develop
* recover_files
parent 8f522c15ed
commit e2ba9668b4
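A minimal usage sketch from the Python side, assuming the companion fluid.layers.tree_conv wrapper added alongside this operator (the argument names and defaults shown here are illustrative, not authoritative):

    import paddle.fluid as fluid

    # NodesVector: per-sample [max_tree_node_size, feature_size] node embeddings.
    nodes_vector = fluid.layers.data(
        name='nodes_vector', shape=[10, 5], dtype='float32')
    # EdgeSet: per-sample [max_tree_node_size, 2] directed (parent, child) pairs,
    # 1-indexed and padded with (0, 0).
    edge_set = fluid.layers.data(
        name='edge_set', shape=[10, 2], dtype='int32')
    # Out: [batch, max_tree_node_size, output_size, num_filters].
    out = fluid.layers.tree_conv(
        nodes_vector=nodes_vector, edge_set=edge_set,
        output_size=6, num_filters=1, max_depth=2)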
paddle/fluid/operators/math/tree2col.cc
@@ -0,0 +1,197 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/math/tree2col.h"
#include <deque>
#include <stack>

namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
std::vector<TreeNode> Tree2ColUtil::construct_patch(
    size_t root, int max_depth, const std::vector<std::vector<int>> &tr) {
  // Iterative DFS that collects the sub-tree rooted at `root`, truncated
  // at max_depth, as one convolution patch.
  std::stack<TreeNode, std::deque<TreeNode>> stack;
  std::unordered_map<int, bool> visited;
  std::vector<TreeNode> patch;

  stack.push(TreeNode(root, 1, 1, 0));
  patch.emplace_back(TreeNode(root, 1, 1, 0));
  visited[root] = true;

  while (!stack.empty()) {
    TreeNode &u = stack.top();
    bool end = true;
    size_t node = u.get_node(), sz = tr[node].size();
    visited[node] = true;
    for (size_t i = 0; i < sz; i++) {
      size_t v = tr[node][i];
      if (!visited[v] && static_cast<int>(u.get_depth()) + 1 < max_depth) {
        visited[v] = true;
        stack.push(TreeNode(v, i, sz, u.get_depth() + 1));
        patch.push_back(TreeNode(v, i + 1, sz, u.get_depth() + 1));
        end = false;
      }
    }
    if (end) {
      stack.pop();
    }
  }
  return patch;
}

void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet,
                                  std::vector<std::vector<int>> *tr,
                                  size_t *node_count) {
  auto edge_set_dims = EdgeSet.dims();
  PADDLE_ENFORCE_EQ(edge_set_dims[1], 2);
  int64_t edge_count = EdgeSet.numel();

  const int *edge_data = EdgeSet.data<int>();

  // Each pair of ints is a directed (parent, child) edge; nodes are
  // 1-indexed and a zero in either slot marks padding.
  for (int64_t i = 0; i < edge_count; i += 2) {
    int u = edge_data[i], v = edge_data[i + 1];
    if (u != 0 && v != 0) (*node_count)++;
  }
  (*node_count)++;

  tr->resize(static_cast<size_t>(*node_count + 1));

  for (int64_t i = 0; i < edge_count; i += 2) {
    int u = edge_data[i], v = edge_data[i + 1];
    if (u != 0 && v != 0) {
      tr->at(u).push_back(v);
    } else {
      break;
    }
  }
}

template <typename T>
class Tree2ColFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext &context,
                  const framework::Tensor &EdgeSet,
                  const framework::Tensor &node_features,
                  framework::Tensor *patch, int max_depth) {
    std::vector<std::vector<int>> tr;
    auto feature_dims = node_features.dims();
    auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
    math::SetConstant<platform::CPUDeviceContext, T> constant;
    int64_t feature_size = feature_dims[1];
    size_t patch_elem_size = 3 * static_cast<size_t>(feature_size);
    size_t node_count = 0, patch_count = 0, patch_size;
    Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count);
    std::vector<std::vector<TreeNode>> processing_list;
    for (size_t u = 1; u <= node_count; u++) {
      std::vector<TreeNode> temp_patch =
          Tree2ColUtil::construct_patch(u, max_depth, tr);
      if (!temp_patch.empty()) {
        processing_list.emplace_back(temp_patch);
      }
    }
    patch_size = processing_list.size();

    T *patch_data =
        patch->mutable_data<T>({static_cast<int64_t>(patch_size),
                                static_cast<int64_t>(patch_elem_size)},
                               cpu_place);
    constant(context, patch, 0);
    const T *features = node_features.data<T>();

    for (auto &patch_item : processing_list) {
      size_t pointer_base = patch_count * patch_elem_size;
      for (auto &v : patch_item) {
        T eta_l = v.eta_l<T>(max_depth), eta_r = v.eta_r<T>(max_depth),
          eta_t = v.eta_t<T>(max_depth);
        size_t id = v.get_node() - 1;
        for (int i = 0; i < feature_size; i++) {
          patch_data[pointer_base + i * 3] +=
              eta_l * features[id * feature_size + i];
          patch_data[pointer_base + i * 3 + 1] +=
              eta_r * features[id * feature_size + i];
          patch_data[pointer_base + i * 3 + 2] +=
              eta_t * features[id * feature_size + i];
        }
      }
      patch_count++;
    }
    patch->Resize({static_cast<int64_t>(patch_count),
                   static_cast<int64_t>(patch_elem_size)});
  }
};
template <typename T>
class Col2TreeFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext &context,
                  const framework::Tensor &EdgeSet,
                  const framework::Tensor &out_grad, framework::Tensor *in_grad,
                  int max_depth) {
    std::vector<std::vector<int>> tr;
    auto output_dims = out_grad.dims();
    auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
    math::SetConstant<platform::CPUDeviceContext, T> constant;
    int64_t output_size = output_dims[1];
    size_t grad_elem_size = 3 * static_cast<size_t>(output_size);
    size_t node_count = 0, grad_count = 0;
    Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count);
    std::vector<std::vector<TreeNode>> processing_list;
    std::vector<std::vector<TreeNode>> grad_list;
    grad_list.resize(node_count);
    for (size_t u = 1; u <= node_count; u++) {
      std::vector<TreeNode> tmp =
          Tree2ColUtil::construct_patch(u, max_depth, tr);
      if (!tmp.empty()) {
        processing_list.push_back(tmp);
      }
    }
    for (size_t patch_id = 0; patch_id < processing_list.size(); patch_id++) {
      for (auto v : processing_list[patch_id]) {
        grad_list[v.get_node() - 1].push_back(v.change_node(patch_id + 1));
      }
    }
    T *grad_data =
        in_grad->mutable_data<T>({static_cast<int64_t>(node_count),
                                  static_cast<int64_t>(grad_elem_size)},
                                 cpu_place);

    constant(context, in_grad, 0);
    const T *out_g = out_grad.data<T>();
    for (auto &patch_item : grad_list) {
      size_t pointer_base = grad_count * grad_elem_size;
      for (auto &v : patch_item) {
        T eta_l = v.eta_l<T>(max_depth), eta_r = v.eta_r<T>(max_depth),
          eta_t = v.eta_t<T>(max_depth);
        size_t id = v.get_node() - 1;
        for (int i = 0; i < output_size; i++) {
          grad_data[pointer_base + i * 3] +=
              eta_l * out_g[id * output_size + i];
          grad_data[pointer_base + i * 3 + 1] +=
              eta_r * out_g[id * output_size + i];
          grad_data[pointer_base + i * 3 + 2] +=
              eta_t * out_g[id * output_size + i];
        }
      }
      grad_count++;
    }
  }
};

template class Tree2ColFunctor<platform::CPUDeviceContext, float>;
template class Tree2ColFunctor<platform::CPUDeviceContext, double>;
template class Col2TreeFunctor<platform::CPUDeviceContext, float>;
template class Col2TreeFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
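A note on the edge encoding construct_tree expects, with a small Python sketch of the equivalent adjacency build (build_tree is a hypothetical helper for illustration, not part of this PR): nodes are numbered from 1, each EdgeSet row is a directed (parent, child) pair, and the scan stops at the first row containing a zero, so padding rows must come last.

    def build_tree(edge_set):
        # Mirrors construct_tree above: stop at the first padded (zero) row.
        tree = {}
        for u, v in edge_set:
            if u == 0 or v == 0:
                break
            tree.setdefault(u, []).append(v)
        return tree

    build_tree([(1, 2), (1, 3), (2, 4), (0, 0)])  # -> {1: [2, 3], 2: [4]}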
paddle/fluid/operators/math/tree2col.cu
@@ -0,0 +1,208 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stack>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/tree2col.h"

namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
using Node = paddle::operators::math::TreeNode;
template <typename T>
__global__ void tree2col(const T* eta, const int* node, const int* index,
                         const T* vectors, T* result, int feature_size, int n) {
  const int thread_id =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  const int patch_id = thread_id / feature_size;
  const int j = thread_id % feature_size;
  if (patch_id < n) {
    const int begin_o = patch_id * 3 * feature_size;
    const int begin = index[patch_id * 2], end = index[patch_id * 2 + 1];
    T res_l = 0, res_r = 0, res_t = 0;
    for (int i = begin; i < end; i++) {
      const int id = node[i];
      const T vec = vectors[id * feature_size + j];
      res_l += eta[i * 3] * vec;
      res_r += eta[i * 3 + 1] * vec;
      res_t += eta[i * 3 + 2] * vec;
    }
    result[begin_o + j * 3] = res_l;
    result[begin_o + j * 3 + 1] = res_r;
    result[begin_o + j * 3 + 2] = res_t;
  }
}
template <typename T>
class Tree2ColFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const paddle::platform::CUDADeviceContext& context,
                  const framework::Tensor& EdgeSet,
                  const framework::Tensor& node_features,
                  framework::Tensor* patch, int max_depth) {
    std::vector<std::vector<int>> tr;
    auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
    auto cpu_place = platform::CPUPlace();
    auto stream = context.stream();
    auto feature_dims = node_features.dims();
    math::SetConstant<platform::CUDADeviceContext, T> constant;

    Tensor EdgeSet_cpu;
    framework::TensorCopy(EdgeSet, cpu_place, &EdgeSet_cpu);
    int64_t feature_size = feature_dims[1];
    size_t patch_elem_size = 3 * static_cast<size_t>(feature_size);
    size_t node_count = 0, patch_count = 0, total_size = 0;
    size_t max_size = feature_dims[0];
    Tree2ColUtil::construct_tree(EdgeSet_cpu, &tr, &node_count);

    std::vector<std::vector<Node>> processing_list;
    for (size_t u = 1; u <= node_count; u++) {
      std::vector<Node> tmp = Tree2ColUtil::construct_patch(u, max_depth, tr);
      if (!tmp.empty()) {
        processing_list.push_back(tmp);
        total_size += tmp.size();
      }
    }

    size_t patch_size = processing_list.size();
    Tensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu;
    int* node = node_cpu.mutable_data<int>({static_cast<int64_t>(total_size)},
                                           cpu_place);
    T* eta = eta_cpu.mutable_data<T>({static_cast<int64_t>(total_size * 3)},
                                     cpu_place);
    int* index = index_cpu.mutable_data<int>(
        {static_cast<int64_t>(patch_size * 2)}, cpu_place);

    int idx = 0, index_idx = 0;
    for (auto& tmp : processing_list) {
      index[index_idx++] = idx;
      for (auto& v : tmp) {
        node[idx] = static_cast<int>(v.node - 1);
        eta[idx * 3] = v.eta_l<T>(max_depth);
        eta[idx * 3 + 1] = v.eta_r<T>(max_depth);
        eta[idx * 3 + 2] = v.eta_t<T>(max_depth);
        idx++;
      }
      index[index_idx++] = idx;
    }
    framework::TensorCopy(node_cpu, gpu_place, context, &node_gpu);
    framework::TensorCopy(eta_cpu, gpu_place, context, &eta_gpu);
    framework::TensorCopy(index_cpu, gpu_place, context, &index_gpu);

    int elem_size = patch_size * feature_size;
    int blocks = (elem_size + 1024 - 1) / 1024;
    int block_x = 512;
    int block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);

    patch->mutable_data<T>(
        {static_cast<int64_t>(max_size), static_cast<int64_t>(patch_elem_size)},
        gpu_place);
    constant(context, patch, 0);
    tree2col<T><<<grid, threads, 0, stream>>>(
        eta_gpu.data<T>(), node_gpu.data<int>(), index_gpu.data<int>(),
        node_features.data<T>(), patch->data<T>(), feature_size, patch_size);
  }
};
template <typename T>
class Col2TreeFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& EdgeSet,
                  const framework::Tensor& patch_grad,
                  framework::Tensor* embedding_grad, int max_depth) {
    std::vector<std::vector<int>> tr;
    auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
    auto cpu_place = platform::CPUPlace();
    auto stream = context.stream();
    auto output_dims = patch_grad.dims();
    math::SetConstant<platform::CUDADeviceContext, T> constant;

    Tensor EdgeSet_cpu;
    framework::TensorCopy(EdgeSet, cpu_place, &EdgeSet_cpu);
    int64_t output_size = output_dims[1];
    size_t patch_elem_size = 3 * static_cast<size_t>(output_size);
    size_t node_count = 0, patch_count = 0;
    size_t max_size = output_dims[0];
    Tree2ColUtil::construct_tree(EdgeSet_cpu, &tr, &node_count);
    std::vector<std::vector<Node>> processing_list;
    std::vector<std::vector<Node>> grad_list;
    grad_list.resize(node_count);
    size_t total_size = 0, grad_size = node_count;
    for (size_t u = 1; u <= node_count; u++) {
      std::vector<Node> tmp = Tree2ColUtil::construct_patch(u, max_depth, tr);
      if (!tmp.empty()) {
        processing_list.push_back(tmp);
      }
    }
    for (size_t patch_id = 0; patch_id < processing_list.size(); patch_id++) {
      for (auto v : processing_list[patch_id]) {
        grad_list[v.get_node() - 1].push_back(v.change_node(patch_id + 1));
      }
    }
    for (auto& tmp : grad_list) {
      total_size += tmp.size();
    }

    Tensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu;
    int* node = node_cpu.mutable_data<int>({static_cast<int64_t>(total_size)},
                                           cpu_place);
    T* eta = eta_cpu.mutable_data<T>({static_cast<int64_t>(total_size * 3)},
                                     cpu_place);
    int* index = index_cpu.mutable_data<int>(
        {static_cast<int64_t>(grad_size * 2)}, cpu_place);

    size_t idx = 0, index_idx = 0;
    for (auto& tmp : grad_list) {
      index[index_idx++] = idx;
      for (auto& v : tmp) {
        node[idx] = static_cast<int>(v.node - 1);
        eta[idx * 3] = v.eta_l<T>(max_depth);
        eta[idx * 3 + 1] = v.eta_r<T>(max_depth);
        eta[idx * 3 + 2] = v.eta_t<T>(max_depth);
        idx++;
      }
      index[index_idx++] = idx;
    }
    framework::TensorCopy(node_cpu, gpu_place, &node_gpu);
    framework::TensorCopy(eta_cpu, gpu_place, &eta_gpu);
    framework::TensorCopy(index_cpu, gpu_place, &index_gpu);

    int elem_size = output_size * grad_size;
    int blocks = (elem_size + 1024 - 1) / 1024;
    int block_x = 512;
    int block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);

    embedding_grad->mutable_data<T>(
        {static_cast<int64_t>(max_size), static_cast<int64_t>(patch_elem_size)},
        gpu_place);

    constant(context, embedding_grad, 0);
    tree2col<T><<<grid, threads, 0, stream>>>(
        eta_gpu.data<T>(), node_gpu.data<int>(), index_gpu.data<int>(),
        patch_grad.data<T>(), embedding_grad->data<T>(), output_size,
        grad_size);
  }
};

template class Tree2ColFunctor<platform::CUDADeviceContext, float>;
template class Tree2ColFunctor<platform::CUDADeviceContext, double>;
template class Col2TreeFunctor<platform::CUDADeviceContext, float>;
template class Col2TreeFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
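The tree2col kernel above assigns one thread per (patch, feature) pair; each patch's node range comes from the flattened (eta, node, index) buffers packed on the CPU. A NumPy reference of the same contraction (tree2col_ref is a hypothetical helper for illustration, mirroring the packing loop above):

    import numpy as np

    def tree2col_ref(eta, node, index, vectors, feature_size, n_patches):
        # Patch p owns entries node[index[2p]:index[2p+1]]; each node's feature
        # vector is accumulated under its three eta weights (l, r, t).
        out = np.zeros((n_patches, feature_size, 3))
        for p in range(n_patches):
            for i in range(index[2 * p], index[2 * p + 1]):
                vec = vectors[node[i]]
                out[p, :, 0] += eta[3 * i] * vec
                out[p, :, 1] += eta[3 * i + 1] * vec
                out[p, :, 2] += eta[3 * i + 2] * vec
        return out.reshape(n_patches, 3 * feature_size)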
paddle/fluid/operators/math/tree2col.h
@@ -0,0 +1,90 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <array>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
namespace operators {
namespace math {
class TreeNode {
 public:
  size_t node;
  explicit TreeNode(size_t node = 0, size_t index = 0, size_t pclen = 0,
                    size_t depth = 0)
      : node(node), index(index), pclen(pclen), depth(depth) {}
  template <typename T>
  T eta_t(T filter_depth) {
    return ((filter_depth - this->depth) / filter_depth);
  }
  template <typename T>
  T eta_l(T filter_depth) {
    T temp;
    if (this->pclen == 1) {
      temp = 0.5;
    } else {
      temp = (this->index - 1.0) / (this->pclen - 1.0);
    }
    return (1.0 - this->eta_t<T>(filter_depth)) * temp;
  }
  template <typename T>
  T eta_r(T filter_depth) {
    return (1.0 - this->eta_t<T>(filter_depth)) *
           (1.0 - this->eta_l<T>(filter_depth));
  }
  TreeNode change_node(size_t v) {
    return TreeNode(v, this->index, this->pclen, this->depth);
  }
  size_t get_node() { return this->node; }
  size_t get_depth() { return this->depth; }

 private:
  size_t index, pclen, depth;
};
class Tree2ColUtil {
 public:
  static std::vector<TreeNode> construct_patch(
      size_t root, int max_depth, const std::vector<std::vector<int>> &tr);

  static void construct_tree(const Tensor &EdgeSet,
                             std::vector<std::vector<int>> *tr,
                             size_t *node_count);
};

template <typename DeviceContext, typename T>
class Tree2ColFunctor {
 public:
  void operator()(const DeviceContext &context,
                  const framework::Tensor &EdgeSet,
                  const framework::Tensor &node_features,
                  framework::Tensor *patch, int max_depth);
};
template <typename DeviceContext, typename T>
class Col2TreeFunctor {
 public:
  void operator()(const DeviceContext &context,
                  const framework::Tensor &EdgeSet,
                  const framework::Tensor &out_grad, framework::Tensor *in_grad,
                  int max_depth);
};
} // namespace math
} // namespace operators
} // namespace paddle
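The eta_t/eta_l/eta_r methods above implement the continuous binary tree weights of the TBCNN paper (https://arxiv.org/abs/1409.5718). With d the node depth, d_max the filter depth (max_depth), i the 1-based index among siblings, and p the sibling count (pclen), they compute:

    \eta_t = \frac{d_{\max} - d}{d_{\max}}, \qquad
    \eta_l = (1 - \eta_t) \cdot \frac{i - 1}{p - 1} \quad (\text{or } 0.5 \text{ when } p = 1), \qquad
    \eta_r = (1 - \eta_t)(1 - \eta_l)

The naive reference in the Python test below evaluates the same three weights.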
paddle/fluid/operators/tree_conv_op.cc
@@ -0,0 +1,129 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/tree_conv_op.h"
#include <string>

namespace paddle {
namespace operators {
class TreeConvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("NodesVector",
             "(Tensor) The feature vector of every node on the tree. "
             "The shape of the feature vector must be "
             "[max_tree_node_size, feature_size].");
    AddInput("EdgeSet",
             "(Tensor) The edges of the tree. Every edge must be directional. "
             "The shape of the edge set must be [max_tree_node_size, 2].");
    AddInput("Filter",
             "(Tensor) The feature detector. "
             "The shape of the filter is "
             "[feature_size, 3, output_size, num_filters].");
    AddOutput("Out",
              "(Tensor) The feature vector of subtrees. "
              "The shape of the output tensor is [max_tree_node_size, "
              "output_size, num_filters]. "
              "The output tensor can serve as a new feature "
              "vector for the next tree convolution layer.");
    AddAttr<int>("max_depth",
                 "(int, default: 2) The depth of the feature detector.")
        .SetDefault(2)
        .GreaterThan(1);
    AddComment(R"DOC(
**Tree-Based Convolution Operator**

Tree-Based Convolution is a kind of convolution based on tree structure.
It is a component of the Tree-Based Convolutional Neural Network (TBCNN),
which is used to classify tree structures, such as Abstract Syntax Trees.
TBCNN introduces a data structure called the continuous binary tree,
which regards a multiway tree as a binary tree.
The paper describing the Tree-Based Convolution Operator is available at:
https://arxiv.org/abs/1409.5718v1
)DOC");
  }
};
class TreeConvOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasOutput("Out"));
    auto edge_dims = ctx->GetInputDim("EdgeSet");
    auto vector_dims = ctx->GetInputDim("NodesVector");
    auto filter_dims = ctx->GetInputDim("Filter");
    PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2");
    PADDLE_ENFORCE_EQ(edge_dims.size(), 3,
                      "The dimension of EdgeSet Tensor should be 3");
    PADDLE_ENFORCE_EQ(vector_dims.size(), 3,
                      "The dimension of NodesVector Tensor should be 3");
    PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
                      "The dimension of Filter Tensor should be 4");
    PADDLE_ENFORCE_EQ(filter_dims[1], 3, "Input(Filter) dim[1] should be 3");
    PADDLE_ENFORCE_EQ(
        filter_dims[0], vector_dims[2],
        "Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]");
    auto output_dims = framework::make_ddim(
        {vector_dims[0], vector_dims[1], filter_dims[2], filter_dims[3]});
    ctx->SetOutputDim("Out", output_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("NodesVector")->type(),
                                   ctx.device_context());
  }
};

class TreeConvGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto vectors_dims = ctx->GetInputDim("NodesVector");
    auto filter_dims = ctx->GetInputDim("Filter");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "the gradient of Output(Out) must not be null");
    if (ctx->HasOutput(framework::GradVarName("Filter"))) {
      ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
    }
    if (ctx->HasOutput(framework::GradVarName("NodesVector"))) {
      ctx->SetOutputDim(framework::GradVarName("NodesVector"), vectors_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("NodesVector")->type(),
                                   ctx.device_context());
  }
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(tree_conv, ops::TreeConvOp, ops::TreeConvOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OPERATOR(tree_conv_grad, ops::TreeConvGradOp);

REGISTER_OP_CPU_KERNEL(
    tree_conv, ops::TreeConvKernel<paddle::platform::CPUDeviceContext, float>,
    ops::TreeConvKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    tree_conv_grad,
    ops::TreeConvGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::TreeConvGradKernel<paddle::platform::CPUDeviceContext, double>);
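A shape walk-through of the InferShape checks above, using the names from TreeConvOpMaker:

    EdgeSet      [batch_size, max_tree_node_size, 2]                       (int32; edge_dims[2] == 2)
    NodesVector  [batch_size, max_tree_node_size, feature_size]
    Filter       [feature_size, 3, output_size, num_filters]               (filter_dims[1] == 3; filter_dims[0] == feature_size)
    Out          [batch_size, max_tree_node_size, output_size, num_filters]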
paddle/fluid/operators/tree_conv_op.cu
@@ -0,0 +1,24 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/tree_conv_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    tree_conv, ops::TreeConvKernel<paddle::platform::CUDADeviceContext, float>,
    ops::TreeConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    tree_conv_grad,
    ops::TreeConvGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::TreeConvGradKernel<paddle::platform::CUDADeviceContext, double>);
paddle/fluid/operators/tree_conv_op.h
@@ -0,0 +1,146 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iostream>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/tree2col.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class TreeConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    math::Tree2ColFunctor<DeviceContext, T> tree2col;
    math::SetConstant<DeviceContext, T> constant;

    auto *Edges = ctx.Input<Tensor>("EdgeSet");
    auto *Embeddings = ctx.Input<Tensor>("NodesVector");
    auto *Filter = ctx.Input<Tensor>("Filter");
    auto *output_emb = ctx.Output<Tensor>("Out");
    int max_depth = ctx.Attr<int>("max_depth");

    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

    Tensor W;
    W.ShareDataWith(*Filter);
    W.Resize(framework::flatten_to_2d(Filter->dims(), 2));

    int batch_size = static_cast<int>(Edges->dims()[0]);
    int n = static_cast<int>(Embeddings->dims()[1]);
    int out_size = static_cast<int>(Filter->dims()[2]);
    int num_filters = static_cast<int>(Filter->dims()[3]);
    output_emb->mutable_data<T>({batch_size, n, out_size, num_filters},
                                ctx.GetPlace());

    auto edge_set_slicedim = framework::slice_ddim(
        Edges->dims(), 1, static_cast<int>(Edges->dims().size()));

    auto embedding_slicedim = framework::slice_ddim(
        Embeddings->dims(), 1, static_cast<int>(Embeddings->dims().size()));

    auto output_slicedim = framework::slice_ddim(
        output_emb->dims(), 1, static_cast<int>(output_emb->dims().size()));

    output_slicedim = framework::flatten_to_2d(output_slicedim, 1);

    for (int idx = 0; idx < batch_size; idx++) {
      auto edge_set = Edges->Slice(idx, idx + 1).Resize(edge_set_slicedim);
      auto embeddings =
          Embeddings->Slice(idx, idx + 1).Resize(embedding_slicedim);
      auto out_vec = output_emb->Slice(idx, idx + 1).Resize(output_slicedim);
      Tensor patch;
      tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth);
      constant(dev_ctx, &out_vec, 0);
      blas.MatMul(patch, W, &out_vec);
    }
  }
};
template <typename DeviceContext, typename T>
class TreeConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *out_g = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *in_g = ctx.Output<Tensor>(framework::GradVarName("NodesVector"));
    auto *filter_g = ctx.Output<Tensor>(framework::GradVarName("Filter"));
    int max_depth = ctx.Attr<int>("max_depth");
    auto *Embeddings = ctx.Input<Tensor>("NodesVector");
    auto *edges = ctx.Input<Tensor>("EdgeSet");
    auto *Filter = ctx.Input<Tensor>("Filter");
    math::Tree2ColFunctor<DeviceContext, T> tree2col;
    math::Col2TreeFunctor<DeviceContext, T> col2tree;
    math::SetConstant<DeviceContext, T> constant;
    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

    Tensor W;
    W.ShareDataWith(*Filter);
    W.Resize(framework::flatten_to_2d(Filter->dims(), 1));

    int batch_size = static_cast<int>(Embeddings->dims()[0]);

    auto edge_set_slicedim = framework::slice_ddim(
        edges->dims(), 1, static_cast<int>(edges->dims().size()));

    auto embedding_slicedim = framework::slice_ddim(
        Embeddings->dims(), 1, static_cast<int>(Embeddings->dims().size()));

    auto out_grad_dims = framework::slice_ddim(
        out_g->dims(), 1, static_cast<int>(out_g->dims().size()));
    out_grad_dims = framework::flatten_to_2d(out_grad_dims, 1);
    if (filter_g) {
      filter_g->mutable_data<T>(Filter->dims(), ctx.GetPlace());
      Tensor f_g;
      f_g.ShareDataWith(*filter_g);
      f_g.Resize(framework::flatten_to_2d(Filter->dims(), 2));
      constant(dev_ctx, filter_g, 0);
      for (int batch_id = 0; batch_id < batch_size; batch_id++) {
        auto edge_set =
            edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim);
        auto embeddings = Embeddings->Slice(batch_id, batch_id + 1)
                              .Resize(embedding_slicedim);
        auto out_grad =
            out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims);
        Tensor patch;
        tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth);
        blas.MatMul(patch, true, out_grad, false, T(1.0), &f_g, T(1.0));
      }
    }
    if (in_g) {
      auto input_grad_dims = framework::slice_ddim(
          in_g->dims(), 1, static_cast<int>(in_g->dims().size()));
      in_g->mutable_data<T>(Embeddings->dims(), ctx.GetPlace());
      constant(dev_ctx, in_g, 0);
      for (int batch_id = 0; batch_id < batch_size; batch_id++) {
        auto edge_set =
            edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim);
        auto out_grad =
            out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims);
        auto in_grad =
            in_g->Slice(batch_id, batch_id + 1).Resize(input_grad_dims);
        Tensor in_grad_temp;
        col2tree(dev_ctx, edge_set, out_grad, &in_grad_temp, max_depth);
        blas.MatMul(in_grad_temp, false, W, true, &in_grad);
      }
    }
  }
};
} // namespace operators
} // namespace paddle
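Per sample, the forward pass in TreeConvKernel reduces to one GEMM: the tree2col patch matrix times the filter flattened to two dimensions. A NumPy sketch of that product, under the shapes assumed above:

    import numpy as np

    F, O, M, P = 4, 6, 2, 5            # feature_size, output_size, num_filters, patches
    patch = np.random.rand(P, 3 * F)   # columns interleaved: [f0_l, f0_r, f0_t, f1_l, ...]
    W = np.random.rand(F, 3, O, M)     # Filter
    out = patch @ W.reshape(3 * F, O * M)  # [P, O*M], viewed as [P, O, M]

The row-major reshape of Filter keeps the (feature, eta) ordering aligned with the interleaved patch columns, which is why flatten_to_2d(Filter->dims(), 2) pairs with the i * 3 + k patch layout.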
python/paddle/fluid/tests/unittests/test_tree_conv_op.py
@@ -0,0 +1,120 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from op_test import OpTest


def collect_node_patch(og, max_depth):
    """
    The naive method to construct patches
    :param og: original graph
    :param max_depth: the depth of convolution filters
    :return: convolution patches
    """

    def gen(node, max_depth):
        collected = [(node, 1, 1, 0, max_depth)]

        def recurse_helper(node, depth):
            if depth > max_depth:
                return
            l = len(og[node])
            for idx, c in enumerate(og[node], 1):
                if depth + 1 < max_depth:
                    collected.append((c, idx, l, depth + 1, max_depth))
                    recurse_helper(c, depth + 1)

        recurse_helper(node, 0)
        return collected

    res = []
    for u in range(1, len(og)):
        lis = gen(u, max_depth)
        if len(lis) > 0:
            res.append(lis)
    return res


class TestTreeConvOp(OpTest):
    def setUp(self):
        self.n = 17
        self.fea_size = 3
        self.output_size = 1
        self.max_depth = 2
        self.batch_size = 1
        self.num_filters = 1
        adj_array = [
            1, 2, 1, 3, 1, 4, 1, 5, 2, 6, 2, 7, 2, 8, 4, 9, 4, 10, 5, 11, 6, 12,
            6, 13, 9, 14, 9, 15, 9, 16, 9, 17
        ]
        adj = np.array(adj_array).reshape((1, self.n - 1, 2)).astype('int32')
        adj = np.tile(adj, (self.batch_size, 1, 1))
        self.op_type = 'tree_conv'
        vectors = np.random.random(
            (self.batch_size, self.n, self.fea_size)).astype('float32')
        self.inputs = {
            'EdgeSet': adj,
            'NodesVector': vectors,
            'Filter': np.random.random((self.fea_size, 3, self.output_size,
                                        self.num_filters)).astype('float32')
        }
        self.attrs = {'max_depth': self.max_depth}
        vectors = []
        for i in range(self.batch_size):
            vector = self.get_output_naive(i)
            vectors.append(vector)
        self.outputs = {'Out': np.array(vectors).astype('float32'), }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ['NodesVector', 'Filter'], 'Out', max_relative_error=0.5)

    def get_output_naive(self, batch_id):
        og = [[] for i in range(1, self.n + 2)]
        st = np.array(self.inputs['EdgeSet'][batch_id]).tolist()
        for e in st:
            og[e[0]].append(e[1])
        patches = collect_node_patch(og, self.max_depth)
        W = np.array(self.inputs['Filter']).astype('float32')
        W = np.transpose(W, axes=[1, 0, 2, 3])
        vec = []
        for i, patch in enumerate(patches, 1):
            result = np.zeros((1, W.shape[2], W.shape[3]))
            for v in patch:
                eta_t = float(v[4] - v[3]) / float(v[4])
                eta_l = (1.0 - eta_t) * (0.5 if v[2] == 1 else
                                         float(v[1] - 1.0) / float(v[2] - 1.0))
                eta_r = (1.0 - eta_t) * (1.0 - eta_l)
                x = self.inputs['NodesVector'][batch_id][v[0] - 1]
                eta = np.array([eta_l, eta_r, eta_t]).reshape(
                    (3, 1)).astype('float32')
                Wconvi = np.tensordot(eta, W, axes=([0], [0]))
                x = np.array(x).reshape((1, 1, self.fea_size))
                res = np.tensordot(x, Wconvi, axes=2)
                result = result + res
            vec.append(result)
        vec = np.concatenate(vec, axis=0)
        vec = np.concatenate(
            [
                vec, np.zeros(
                    (self.n - vec.shape[0], W.shape[2], W.shape[3]),
                    dtype='float32')
            ],
            axis=0)
        return vec