Add data transform fn (#6953)
* init data_transform * complete DataTransform * fix build error * add data_transform_test * add a register test for data_transform_fn * use function to simulate registration macro * add register macro * update test * clean code * restore unrelated code * update data transform test * generate unique name for REGISTER_DATA_TRANSFORM_FN * add const * follow comment * update KernelTypePair hash functiondel_some_in_makelist
parent
19f2475af1
commit
f97f69feec
@ -0,0 +1,26 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/data_transform.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
DataTransformFnMap& DataTransformFnMap::Instance() {
|
||||
static DataTransformFnMap data_transform_map;
|
||||
return data_transform_map;
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,110 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/framework/op_kernel_type.h"
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/framework/variable.h"
|
||||
#include "paddle/platform/device_context.h"
|
||||
#include "paddle/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
using DataTransformFN =
|
||||
std::function<void(const std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out)>;
|
||||
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
|
||||
|
||||
static void hash_combine(std::size_t& seed, const OpKernelType& t) {
|
||||
OpKernelType::Hash kernel_type_hasher;
|
||||
seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
|
||||
struct KernelTypePairHash {
|
||||
size_t operator()(const KernelTypePair& kernel_pair) const {
|
||||
std::size_t seed = 0;
|
||||
hash_combine(seed, kernel_pair.first);
|
||||
hash_combine(seed, kernel_pair.second);
|
||||
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
using DataTransformMap =
|
||||
std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>;
|
||||
|
||||
class DataTransformFnMap {
|
||||
public:
|
||||
static DataTransformFnMap& Instance();
|
||||
|
||||
bool Has(const KernelTypePair& key_pair) const {
|
||||
return map_.find(key_pair) != map_.end();
|
||||
}
|
||||
|
||||
void Insert(const OpKernelType& left, const OpKernelType& right,
|
||||
const DataTransformFN& data_tranform_fn) {
|
||||
Insert(std::make_pair(left, right), data_tranform_fn);
|
||||
}
|
||||
|
||||
void Insert(const KernelTypePair& kernel_type_pair,
|
||||
const DataTransformFN& data_tranform_fn) {
|
||||
PADDLE_ENFORCE(!Has(kernel_type_pair),
|
||||
"KernelTypePair %s has been registered", "");
|
||||
map_.insert({kernel_type_pair, data_tranform_fn});
|
||||
}
|
||||
|
||||
const DataTransformFN& Get(const KernelTypePair& key_pair) const {
|
||||
auto data_transformer = GetNullable(key_pair);
|
||||
PADDLE_ENFORCE_NOT_NULL(data_transformer,
|
||||
"DataTransformFN should not be NULL");
|
||||
return *data_transformer;
|
||||
}
|
||||
|
||||
const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const {
|
||||
auto it = map_.find(key_pair);
|
||||
if (it == map_.end()) {
|
||||
return nullptr;
|
||||
} else {
|
||||
return &(it->second);
|
||||
}
|
||||
}
|
||||
|
||||
const DataTransformMap& Map() const { return map_; }
|
||||
|
||||
private:
|
||||
DataTransformFnMap() = default;
|
||||
DataTransformMap map_;
|
||||
DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
|
||||
};
|
||||
|
||||
// generate unique name with __LINE__
|
||||
// refs https://stackoverflow.com/questions/1597007
|
||||
#define TOKENPASTE(x, y) x##y
|
||||
#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
|
||||
#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
|
||||
static int TOKENPASTE2(fn_, __LINE__)() { \
|
||||
::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
|
||||
return 0; \
|
||||
} \
|
||||
static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
|
||||
TOKENPASTE2(fn_, __LINE__)()
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,78 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/data_transform.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
using namespace platform;
|
||||
|
||||
int test_value = 0;
|
||||
|
||||
OpKernelType kernel_type_1(proto::DataType::FP32, CPUPlace(), DataLayout::kNCHW,
|
||||
LibraryType::kCUDNN);
|
||||
OpKernelType kernel_type_2(proto::DataType::FP32, CUDAPlace(0),
|
||||
DataLayout::kNCHW, LibraryType::kCUDNN);
|
||||
OpKernelType kernel_type_3(proto::DataType::FP16, CUDAPlace(0),
|
||||
DataLayout::kNCHW, LibraryType::kCUDNN);
|
||||
|
||||
void type1_to_type2(std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out) {
|
||||
test_value++;
|
||||
}
|
||||
|
||||
void type2_to_type3(std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out) {
|
||||
test_value--;
|
||||
}
|
||||
|
||||
void type1_to_type3(std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out) {
|
||||
test_value += 2;
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
namespace frw = paddle::framework;
|
||||
|
||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2,
|
||||
frw::type1_to_type2);
|
||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_2, frw::kernel_type_3,
|
||||
frw::type2_to_type3);
|
||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_3,
|
||||
frw::type1_to_type3);
|
||||
|
||||
TEST(DataTransform, Register) {
|
||||
using namespace paddle::framework;
|
||||
using namespace paddle::platform;
|
||||
|
||||
auto& instance = DataTransformFnMap::Instance();
|
||||
ASSERT_EQ(instance.Map().size(), 3UL);
|
||||
std::vector<DeviceContext*> ctx;
|
||||
paddle::framework::Variable in;
|
||||
paddle::framework::Variable out;
|
||||
|
||||
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_2))(ctx, in,
|
||||
&out);
|
||||
ASSERT_EQ(test_value, 1);
|
||||
instance.Get(std::make_pair(frw::kernel_type_2, frw::kernel_type_3))(ctx, in,
|
||||
&out);
|
||||
ASSERT_EQ(test_value, 0);
|
||||
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_3))(ctx, in,
|
||||
&out);
|
||||
ASSERT_EQ(test_value, 2);
|
||||
}
|
Loading…
Reference in new issue