commit
6f5e64af17
@ -0,0 +1,26 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/data_transform.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
DataTransformFnMap& DataTransformFnMap::Instance() {
|
||||
static DataTransformFnMap data_transform_map;
|
||||
return data_transform_map;
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,109 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/framework/op_kernel_type.h"
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/framework/variable.h"
|
||||
#include "paddle/platform/device_context.h"
|
||||
#include "paddle/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
using DataTransformFN =
|
||||
std::function<void(const std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out)>;
|
||||
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
|
||||
|
||||
struct KernelTypePairHash {
|
||||
static void HashCombine(const OpKernelType& t, std::size_t* seed) {
|
||||
OpKernelType::Hash kernel_type_hasher;
|
||||
(*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
|
||||
}
|
||||
|
||||
size_t operator()(const KernelTypePair& kernel_pair) const {
|
||||
std::size_t seed = 0;
|
||||
HashCombine(kernel_pair.first, &seed);
|
||||
HashCombine(kernel_pair.second, &seed);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
using DataTransformMap =
|
||||
std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>;
|
||||
|
||||
class DataTransformFnMap {
|
||||
public:
|
||||
static DataTransformFnMap& Instance();
|
||||
|
||||
bool Has(const KernelTypePair& key_pair) const {
|
||||
return map_.find(key_pair) != map_.end();
|
||||
}
|
||||
|
||||
void Insert(const OpKernelType& left, const OpKernelType& right,
|
||||
const DataTransformFN& data_tranform_fn) {
|
||||
Insert(std::make_pair(left, right), data_tranform_fn);
|
||||
}
|
||||
|
||||
void Insert(const KernelTypePair& kernel_type_pair,
|
||||
const DataTransformFN& data_tranform_fn) {
|
||||
PADDLE_ENFORCE(!Has(kernel_type_pair),
|
||||
"KernelTypePair %s has been registered", "");
|
||||
map_.insert({kernel_type_pair, data_tranform_fn});
|
||||
}
|
||||
|
||||
const DataTransformFN& Get(const KernelTypePair& key_pair) const {
|
||||
auto data_transformer = GetNullable(key_pair);
|
||||
PADDLE_ENFORCE_NOT_NULL(data_transformer,
|
||||
"DataTransformFN should not be NULL");
|
||||
return *data_transformer;
|
||||
}
|
||||
|
||||
const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const {
|
||||
auto it = map_.find(key_pair);
|
||||
if (it == map_.end()) {
|
||||
return nullptr;
|
||||
} else {
|
||||
return &(it->second);
|
||||
}
|
||||
}
|
||||
|
||||
const DataTransformMap& Map() const { return map_; }
|
||||
|
||||
private:
|
||||
DataTransformFnMap() = default;
|
||||
DataTransformMap map_;
|
||||
DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
|
||||
};
|
||||
|
||||
// generate unique name with __LINE__
|
||||
// refs https://stackoverflow.com/questions/1597007
|
||||
#define TOKENPASTE(x, y) x##y
|
||||
#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
|
||||
#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
|
||||
static int TOKENPASTE2(fn_, __LINE__)() { \
|
||||
::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
|
||||
return 0; \
|
||||
} \
|
||||
static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
|
||||
TOKENPASTE2(fn_, __LINE__)()
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,78 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/data_transform.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
using namespace platform;
|
||||
|
||||
int test_value = 0;
|
||||
|
||||
OpKernelType kernel_type_1(proto::DataType::FP32, CPUPlace(), DataLayout::kNCHW,
|
||||
LibraryType::kCUDNN);
|
||||
OpKernelType kernel_type_2(proto::DataType::FP32, CUDAPlace(0),
|
||||
DataLayout::kNCHW, LibraryType::kCUDNN);
|
||||
OpKernelType kernel_type_3(proto::DataType::FP16, CUDAPlace(0),
|
||||
DataLayout::kNCHW, LibraryType::kCUDNN);
|
||||
|
||||
void type1_to_type2(std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out) {
|
||||
test_value++;
|
||||
}
|
||||
|
||||
void type2_to_type3(std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out) {
|
||||
test_value--;
|
||||
}
|
||||
|
||||
void type1_to_type3(std::vector<platform::DeviceContext*> ctx,
|
||||
const Variable& in, Variable* out) {
|
||||
test_value += 2;
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
namespace frw = paddle::framework;
|
||||
|
||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2,
|
||||
frw::type1_to_type2);
|
||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_2, frw::kernel_type_3,
|
||||
frw::type2_to_type3);
|
||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_3,
|
||||
frw::type1_to_type3);
|
||||
|
||||
TEST(DataTransform, Register) {
|
||||
using namespace paddle::framework;
|
||||
using namespace paddle::platform;
|
||||
|
||||
auto& instance = DataTransformFnMap::Instance();
|
||||
ASSERT_EQ(instance.Map().size(), 3UL);
|
||||
std::vector<DeviceContext*> ctx;
|
||||
paddle::framework::Variable in;
|
||||
paddle::framework::Variable out;
|
||||
|
||||
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_2))(ctx, in,
|
||||
&out);
|
||||
ASSERT_EQ(test_value, 1);
|
||||
instance.Get(std::make_pair(frw::kernel_type_2, frw::kernel_type_3))(ctx, in,
|
||||
&out);
|
||||
ASSERT_EQ(test_value, 0);
|
||||
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_3))(ctx, in,
|
||||
&out);
|
||||
ASSERT_EQ(test_value, 2);
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue