remove ContextMap

wangkuiyi-patch-2
chengduoZH 7 years ago
parent 6db96ec23c
commit 124c93081d

@ -7,16 +7,12 @@ if(WITH_GPU)
dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
endif()
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
if(WITH_GPU)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
nv_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
device_context broadcast_op_handle)
else()
set(multi_devices_graph_builder_deps)
endif()
@ -25,3 +21,6 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
device_context broadcast_op_handle)

@ -29,13 +29,8 @@ Tensor *GetTensorFromVar(Variable *in_var) {
return nullptr;
}
BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::ContextMap &ctxs)
: local_scopes_(local_scopes), places_(places), ctxs_(ctxs) {
for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs_.DevCtx(p);
}
}
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
void BroadcastOpHandle::RunImpl() {
PADDLE_ENFORCE_EQ(this->inputs_.size(), 1);
@ -47,26 +42,18 @@ void BroadcastOpHandle::RunImpl() {
if (inputs_[0]->generated_op_)
inputs_[0]->generated_op_->Wait(dev_ctxes_[in_place]);
auto iter = std::find(places_.begin(), places_.end(), in_place);
if (iter == places_.end()) {
PADDLE_THROW("The input of BCast is not in the places_.");
}
int offset = iter - places_.begin();
auto in_var = local_scopes_[offset]->FindVar(in_var_handle->name_);
auto in_scope_idx = in_var_handle->scope_idx_;
PADDLE_ENFORCE_LT(in_scope_idx, local_scopes_.size(), "");
auto in_var = local_scopes_[in_scope_idx]->FindVar(in_var_handle->name_);
Tensor *in_tensor = GetTensorFromVar(in_var);
for (auto *out : outputs_) {
auto out_handle = static_cast<VarHandle *>(out);
auto &out_p = out_handle->place_;
auto iter = std::find(places_.begin(), places_.end(), out_p);
if (iter == places_.end()) {
PADDLE_THROW("The output of BCast is not in the places_.");
}
int offset = iter - places_.begin();
auto *s = local_scopes_[offset];
auto out_scope_idx = out_handle->scope_idx_;
PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(), "");
auto *s = local_scopes_[out_scope_idx];
auto out_var = s->FindVar(out_handle->name_);
PADDLE_ENFORCE_EQ(out_var->Type(), in_var->Type(), "");

@ -35,11 +35,10 @@ namespace details {
struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
const platform::ContextMap &ctxs_;
// const platform::ContextMap &ctxs_;
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::ContextMap &ctxs);
const std::vector<platform::Place> &places);
std::string Name() const override;

File diff suppressed because it is too large Load Diff

@ -50,6 +50,7 @@ struct VarHandle : public VarHandleBase {
// version field currently is not used, however, just store the version to
// debug easily.
size_t version_;
size_t scope_idx_;
std::string name_;
platform::Place place_;
};

@ -2,21 +2,19 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
@ -140,45 +138,6 @@ template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
using TYPE = CUDAPinnedDeviceContext;
};
class ContextMap {
public:
explicit ContextMap(const std::vector<platform::Place>& places) {
order_.reserve(places.size());
for (auto& p : places) {
auto dev = boost::get<CUDAPlace>(p);
int dev_id = dev.device;
order_.emplace_back(dev_id);
contexts_[dev_id].reset(new CUDADeviceContext(dev));
}
PADDLE_ENFORCE_EQ(
order_.size(), contexts_.size(),
"Context Map does not support contain two or more same device");
}
DeviceContext* DevCtx(int dev_id) const { return at(dev_id); }
DeviceContext* DevCtx(platform::Place p) const {
return DevCtx(boost::get<CUDAPlace>(p).device);
}
DeviceContext* at(platform::Place p) const {
return this->at(boost::get<CUDAPlace>(p).device);
}
DeviceContext* at(int dev_id) const { return contexts_.at(dev_id).get(); }
void WaitAll() {
for (auto& p : contexts_) {
p.second->Wait();
}
}
private:
std::unordered_map<int, std::unique_ptr<DeviceContext>> contexts_;
std::vector<int> order_;
};
#endif
#ifdef PADDLE_WITH_MKLDNN

Loading…
Cancel
Save