|
|
|
@ -648,6 +648,12 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
|
|
|
|
|
const ir::Graph &graph, const std::string &varname,
|
|
|
|
|
const std::unordered_map<std::string, int> &sharded_var_device) const {
|
|
|
|
|
auto got = sharded_var_device.find(varname);
|
|
|
|
|
if (got == sharded_var_device.end()) {
|
|
|
|
|
auto pos = varname.find(framework::kNewGradSuffix);
|
|
|
|
|
if (pos != std::string::npos) {
|
|
|
|
|
got = sharded_var_device.find(varname.substr(0, pos));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return got == sharded_var_device.end() ? -1 : got->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|