|
|
@ -339,11 +339,15 @@ private:
|
|
|
|
* clear all grad
|
|
|
|
* clear all grad
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
void clearGrads() {
|
|
|
|
void clearGrads() {
|
|
|
|
|
|
|
|
if (output_.grad) {
|
|
|
|
output_.grad->zeroMem();
|
|
|
|
output_.grad->zeroMem();
|
|
|
|
|
|
|
|
}
|
|
|
|
for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
|
|
|
|
for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
|
|
|
|
|
|
|
|
if (outputOtherDevice_[i].grad) {
|
|
|
|
outputOtherDevice_[i].grad->zeroMem();
|
|
|
|
outputOtherDevice_[i].grad->zeroMem();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
/**
|
|
|
|
* Set deviceId of the params used in this layer.
|
|
|
|
* Set deviceId of the params used in this layer.
|
|
|
|