@ -109,6 +109,40 @@ void GatherAgentLayer::forwardValue(PassType passType) {
}
}
namespace {
// dest[index[i]] <- src[i] for each i
void copyElements ( const IVector & srcVec ,
const IVector & indexVec ,
IVector & destVec ) {
const int * src = srcVec . getData ( ) ;
const int * index = indexVec . getData ( ) ;
int * dest = destVec . getData ( ) ;
int len = indexVec . getSize ( ) ;
CHECK_EQ ( srcVec . getSize ( ) , indexVec . getSize ( ) ) ;
for ( int i = 0 ; i < len ; + + i ) {
dest [ index [ i ] ] = src [ i ] ;
}
}
}
void GatherAgentLayer : : forwardIds ( PassType passType ) {
IVectorPtr realId = realLayers_ [ 0 ] - > getOutputLabel ( ) ;
if ( ! realId ) return ;
IVector : : resizeOrCreate ( output_ . ids , allIds_ - > getSize ( ) , useGpu_ ) ;
IVectorPtr outId = output_ . ids ;
idsVec_ . resize ( idIndex_ . size ( ) ) ;
for ( size_t i = 0 ; i < realLayers_ . size ( ) ; + + i ) {
const IVectorPtr & realId = realLayers_ [ i ] - > getOutputLabel ( ) ;
idsVec_ [ i ] = IVector : : create ( allIds_ - > getData ( ) + idIndex_ [ i ] ,
/* size */ realId - > getSize ( ) ,
useGpu_ ) ;
execViaCpu ( & copyElements , * realId , * idsVec_ [ i ] , * outId ) ;
}
}
void GatherAgentLayer : : backward ( const UpdateCallback & callback ) {
( void ) callback ;
const MatrixPtr & outputGrad = getOutputGrad ( ) ;
@ -174,41 +208,6 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) {
REGISTER_LAYER ( gather_agent , GatherAgentLayer ) ;
REGISTER_LAYER ( scatter_agent , ScatterAgentLayer ) ;
void GatherAgentLayer : : forwardIds ( PassType passType ) {
int height = 0 ;
IVectorPtr idReal = realLayers_ [ 0 ] - > getOutputLabel ( ) ;
if ( ! idReal ) return ;
if ( output_ . subSequenceStartPositions ) {
int * starts = output_ . subSequenceStartPositions - > getMutableData ( false ) ;
// Gather generator.idsVec
// if is beam search generation result. Get first result.
if ( idReal - > getData ( ) [ idReal - > getSize ( ) - 1 ] = = - 1 ) {
for ( size_t i = 0 ; i < realLayers_ . size ( ) ; + + i ) {
// The first element stores first result size
idReal = realLayers_ [ i ] - > getOutputLabel ( ) ;
idReal - > subVecFrom ( * idReal , 1 , idReal - > getData ( ) [ 0 ] ) ;
}
}
for ( size_t i = 0 ; i < realLayers_ . size ( ) ; + + i ) {
CHECK ( realLayers_ [ i ] - > getOutputLabel ( ) ) ;
starts [ i ] = height ;
height + = realLayers_ [ i ] - > getOutputLabel ( ) - > getSize ( ) ;
}
starts [ realLayers_ . size ( ) ] = height ;
output_ . sequenceStartPositions - > getMutableData ( false ) [ 1 ] = height ;
IVector : : resizeOrCreate ( output_ . ids , height , false ) ;
for ( size_t i = 0 ; i < realLayers_ . size ( ) ; + + i ) {
output_ . ids - > subVec ( starts [ i ] , starts [ i + 1 ] - starts [ i ] )
- > copyFrom ( * realLayers_ [ i ] - > getOutputLabel ( ) ) ;
}
} else {
LOG ( FATAL ) < < " Not implemented " ;
}
}
void ScatterAgentLayer : : forwardSequence ( PassType passType ) {
Layer : : forward ( passType ) ;
CHECK_EQ ( realLayer_ - > getDeviceId ( ) , this - > getDeviceId ( ) ) ;