|
|
|
@ -109,6 +109,11 @@ func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
|
|
|
|
|
return ErrUnintialized
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
count := len(grads)
|
|
|
|
|
if count == 0 {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
|
|
|
@ -118,16 +123,25 @@ func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
|
errCh := make(chan error, count)
|
|
|
|
|
for _, g := range grads {
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
go func(p Parameter, g Gradient) {
|
|
|
|
|
s.opt.UpdateParameter(p, g)
|
|
|
|
|
wg.Done()
|
|
|
|
|
err := s.opt.UpdateParameter(p, g)
|
|
|
|
|
errCh <- err
|
|
|
|
|
}(s.paramMap[g.Name], g)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
wg.Wait()
|
|
|
|
|
recv := 0
|
|
|
|
|
for err := range errCh {
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
recv++
|
|
|
|
|
if recv == count {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|