Add data transform fn (#6953)
* init data_transform * complete DataTransform * fix build error * add data_transform_test * add a register test for data_transform_fn * use function to simulate registration macro * add register macro * update test * clean code * restore unrelated code * update data transform test * generate unique name for REGISTER_DATA_TRANSFORM_FN * add const * follow comment * update KernelTypePair hash functiondel_some_in_makelist
parent
19f2475af1
commit
f97f69feec
@ -0,0 +1,26 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/data_transform.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
DataTransformFnMap& DataTransformFnMap::Instance() {
|
||||||
|
static DataTransformFnMap data_transform_map;
|
||||||
|
return data_transform_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,110 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "paddle/framework/op_kernel_type.h"
|
||||||
|
#include "paddle/framework/tensor.h"
|
||||||
|
#include "paddle/framework/variable.h"
|
||||||
|
#include "paddle/platform/device_context.h"
|
||||||
|
#include "paddle/platform/macros.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
using DataTransformFN =
|
||||||
|
std::function<void(const std::vector<platform::DeviceContext*> ctx,
|
||||||
|
const Variable& in, Variable* out)>;
|
||||||
|
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
|
||||||
|
|
||||||
|
static void hash_combine(std::size_t& seed, const OpKernelType& t) {
|
||||||
|
OpKernelType::Hash kernel_type_hasher;
|
||||||
|
seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct KernelTypePairHash {
|
||||||
|
size_t operator()(const KernelTypePair& kernel_pair) const {
|
||||||
|
std::size_t seed = 0;
|
||||||
|
hash_combine(seed, kernel_pair.first);
|
||||||
|
hash_combine(seed, kernel_pair.second);
|
||||||
|
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using DataTransformMap =
|
||||||
|
std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>;
|
||||||
|
|
||||||
|
class DataTransformFnMap {
|
||||||
|
public:
|
||||||
|
static DataTransformFnMap& Instance();
|
||||||
|
|
||||||
|
bool Has(const KernelTypePair& key_pair) const {
|
||||||
|
return map_.find(key_pair) != map_.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Insert(const OpKernelType& left, const OpKernelType& right,
|
||||||
|
const DataTransformFN& data_tranform_fn) {
|
||||||
|
Insert(std::make_pair(left, right), data_tranform_fn);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Insert(const KernelTypePair& kernel_type_pair,
|
||||||
|
const DataTransformFN& data_tranform_fn) {
|
||||||
|
PADDLE_ENFORCE(!Has(kernel_type_pair),
|
||||||
|
"KernelTypePair %s has been registered", "");
|
||||||
|
map_.insert({kernel_type_pair, data_tranform_fn});
|
||||||
|
}
|
||||||
|
|
||||||
|
const DataTransformFN& Get(const KernelTypePair& key_pair) const {
|
||||||
|
auto data_transformer = GetNullable(key_pair);
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(data_transformer,
|
||||||
|
"DataTransformFN should not be NULL");
|
||||||
|
return *data_transformer;
|
||||||
|
}
|
||||||
|
|
||||||
|
const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const {
|
||||||
|
auto it = map_.find(key_pair);
|
||||||
|
if (it == map_.end()) {
|
||||||
|
return nullptr;
|
||||||
|
} else {
|
||||||
|
return &(it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const DataTransformMap& Map() const { return map_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataTransformFnMap() = default;
|
||||||
|
DataTransformMap map_;
|
||||||
|
DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
|
||||||
|
};
|
||||||
|
|
||||||
|
// generate unique name with __LINE__
|
||||||
|
// refs https://stackoverflow.com/questions/1597007
|
||||||
|
#define TOKENPASTE(x, y) x##y
|
||||||
|
#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
|
||||||
|
#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
|
||||||
|
static int TOKENPASTE2(fn_, __LINE__)() { \
|
||||||
|
::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
|
||||||
|
return 0; \
|
||||||
|
} \
|
||||||
|
static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
|
||||||
|
TOKENPASTE2(fn_, __LINE__)()
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,78 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/data_transform.h"
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
using namespace platform;
|
||||||
|
|
||||||
|
int test_value = 0;
|
||||||
|
|
||||||
|
OpKernelType kernel_type_1(proto::DataType::FP32, CPUPlace(), DataLayout::kNCHW,
|
||||||
|
LibraryType::kCUDNN);
|
||||||
|
OpKernelType kernel_type_2(proto::DataType::FP32, CUDAPlace(0),
|
||||||
|
DataLayout::kNCHW, LibraryType::kCUDNN);
|
||||||
|
OpKernelType kernel_type_3(proto::DataType::FP16, CUDAPlace(0),
|
||||||
|
DataLayout::kNCHW, LibraryType::kCUDNN);
|
||||||
|
|
||||||
|
void type1_to_type2(std::vector<platform::DeviceContext*> ctx,
|
||||||
|
const Variable& in, Variable* out) {
|
||||||
|
test_value++;
|
||||||
|
}
|
||||||
|
|
||||||
|
void type2_to_type3(std::vector<platform::DeviceContext*> ctx,
|
||||||
|
const Variable& in, Variable* out) {
|
||||||
|
test_value--;
|
||||||
|
}
|
||||||
|
|
||||||
|
void type1_to_type3(std::vector<platform::DeviceContext*> ctx,
|
||||||
|
const Variable& in, Variable* out) {
|
||||||
|
test_value += 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace frw = paddle::framework;
|
||||||
|
|
||||||
|
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2,
|
||||||
|
frw::type1_to_type2);
|
||||||
|
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_2, frw::kernel_type_3,
|
||||||
|
frw::type2_to_type3);
|
||||||
|
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_3,
|
||||||
|
frw::type1_to_type3);
|
||||||
|
|
||||||
|
TEST(DataTransform, Register) {
|
||||||
|
using namespace paddle::framework;
|
||||||
|
using namespace paddle::platform;
|
||||||
|
|
||||||
|
auto& instance = DataTransformFnMap::Instance();
|
||||||
|
ASSERT_EQ(instance.Map().size(), 3UL);
|
||||||
|
std::vector<DeviceContext*> ctx;
|
||||||
|
paddle::framework::Variable in;
|
||||||
|
paddle::framework::Variable out;
|
||||||
|
|
||||||
|
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_2))(ctx, in,
|
||||||
|
&out);
|
||||||
|
ASSERT_EQ(test_value, 1);
|
||||||
|
instance.Get(std::make_pair(frw::kernel_type_2, frw::kernel_type_3))(ctx, in,
|
||||||
|
&out);
|
||||||
|
ASSERT_EQ(test_value, 0);
|
||||||
|
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_3))(ctx, in,
|
||||||
|
&out);
|
||||||
|
ASSERT_EQ(test_value, 2);
|
||||||
|
}
|
Loading…
Reference in new issue