|  |  |  | @ -109,6 +109,40 @@ void GatherAgentLayer::forwardValue(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | namespace { | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // dest[index[i]] <- src[i] for each i
 | 
			
		
	
		
			
				
					|  |  |  |  | void copyElements(const IVector& srcVec, | 
			
		
	
		
			
				
					|  |  |  |  |                   const IVector& indexVec, | 
			
		
	
		
			
				
					|  |  |  |  |                   IVector& destVec) { | 
			
		
	
		
			
				
					|  |  |  |  |   const int* src = srcVec.getData(); | 
			
		
	
		
			
				
					|  |  |  |  |   const int* index = indexVec.getData(); | 
			
		
	
		
			
				
					|  |  |  |  |   int* dest = destVec.getData(); | 
			
		
	
		
			
				
					|  |  |  |  |   int len = indexVec.getSize(); | 
			
		
	
		
			
				
					|  |  |  |  |   CHECK_EQ(srcVec.getSize(), indexVec.getSize()); | 
			
		
	
		
			
				
					|  |  |  |  |   for (int i = 0; i < len; ++i) { | 
			
		
	
		
			
				
					|  |  |  |  |     dest[index[i]] = src[i]; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void GatherAgentLayer::forwardIds(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  |   IVectorPtr realId = realLayers_[0]->getOutputLabel(); | 
			
		
	
		
			
				
					|  |  |  |  |   if (!realId) return; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   IVector::resizeOrCreate(output_.ids, allIds_->getSize(), useGpu_); | 
			
		
	
		
			
				
					|  |  |  |  |   IVectorPtr outId = output_.ids; | 
			
		
	
		
			
				
					|  |  |  |  |   idsVec_.resize(idIndex_.size()); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   for (size_t i = 0; i < realLayers_.size(); ++i) { | 
			
		
	
		
			
				
					|  |  |  |  |     const IVectorPtr& realId = realLayers_[i]->getOutputLabel(); | 
			
		
	
		
			
				
					|  |  |  |  |     idsVec_[i] = IVector::create(allIds_->getData() + idIndex_[i], | 
			
		
	
		
			
				
					|  |  |  |  |                                  /* size */ realId->getSize(), | 
			
		
	
		
			
				
					|  |  |  |  |                                  useGpu_); | 
			
		
	
		
			
				
					|  |  |  |  |     execViaCpu(©Elements, *realId, *idsVec_[i], *outId); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void GatherAgentLayer::backward(const UpdateCallback& callback) { | 
			
		
	
		
			
				
					|  |  |  |  |   (void)callback; | 
			
		
	
		
			
				
					|  |  |  |  |   const MatrixPtr& outputGrad = getOutputGrad(); | 
			
		
	
	
		
			
				
					|  |  |  | @ -136,23 +170,22 @@ void ScatterAgentLayer::forward(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  |   CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   int width = this->getSize(); | 
			
		
	
		
			
				
					|  |  |  |  |   if (realOutArg_.hasSeq()) { | 
			
		
	
		
			
				
					|  |  |  |  |     forwardSequence(passType); | 
			
		
	
		
			
				
					|  |  |  |  |   } else if (realOutArg_.value || realOutArg_.ids) { | 
			
		
	
		
			
				
					|  |  |  |  |     output_.subArgFrom( | 
			
		
	
		
			
				
					|  |  |  |  |         realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_); | 
			
		
	
		
			
				
					|  |  |  |  |   } else {  // used in generation
 | 
			
		
	
		
			
				
					|  |  |  |  |     if (realLayer_->getOutput().ids) { | 
			
		
	
		
			
				
					|  |  |  |  |       IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_); | 
			
		
	
		
			
				
					|  |  |  |  |       output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     if (realLayer_->getOutput().value) { | 
			
		
	
		
			
				
					|  |  |  |  |       int height = ids_->getSize(); | 
			
		
	
		
			
				
					|  |  |  |  |       resetOutput(height, width); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |       const MatrixPtr& outV = getOutputValue(); | 
			
		
	
		
			
				
					|  |  |  |  |       const MatrixPtr& realV = realLayer_->getOutputValue(); | 
			
		
	
		
			
				
					|  |  |  |  |       outV->selectRows(*realV, *ids_); | 
			
		
	
		
			
				
					|  |  |  |  |   if (selectionMode_) { | 
			
		
	
		
			
				
					|  |  |  |  |     forwardWithSelection(passType); | 
			
		
	
		
			
				
					|  |  |  |  |   } else { | 
			
		
	
		
			
				
					|  |  |  |  |     if (realOutArg_.hasSeq()) { | 
			
		
	
		
			
				
					|  |  |  |  |       output_.subArgFrom(realOutArg_, | 
			
		
	
		
			
				
					|  |  |  |  |                          /* offset */ idIndex_, | 
			
		
	
		
			
				
					|  |  |  |  |                          idSize_, | 
			
		
	
		
			
				
					|  |  |  |  |                          width, | 
			
		
	
		
			
				
					|  |  |  |  |                          useGpu_, | 
			
		
	
		
			
				
					|  |  |  |  |                          /* trans */ false, | 
			
		
	
		
			
				
					|  |  |  |  |                          /* seqFlag */ true, | 
			
		
	
		
			
				
					|  |  |  |  |                          /* seqStart */ seqStartPosIndex_, | 
			
		
	
		
			
				
					|  |  |  |  |                          /* seqSize */ numSequences_); | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       output_.subArgFrom( | 
			
		
	
		
			
				
					|  |  |  |  |           realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
	
		
			
				
					|  |  |  | @ -160,6 +193,8 @@ void ScatterAgentLayer::forward(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  | void ScatterAgentLayer::backward(const UpdateCallback& callback) { | 
			
		
	
		
			
				
					|  |  |  |  |   (void)callback; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   CHECK(!selectionMode_); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   const MatrixPtr& outputGrad = realOutArg_.grad; | 
			
		
	
		
			
				
					|  |  |  |  |   const MatrixPtr& realGrad = realLayer_->getOutputGrad(); | 
			
		
	
		
			
				
					|  |  |  |  |   if (realGrad) { | 
			
		
	
	
		
			
				
					|  |  |  | @ -174,42 +209,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) { | 
			
		
	
		
			
				
					|  |  |  |  | REGISTER_LAYER(gather_agent, GatherAgentLayer); | 
			
		
	
		
			
				
					|  |  |  |  | REGISTER_LAYER(scatter_agent, ScatterAgentLayer); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void GatherAgentLayer::forwardIds(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  |   int height = 0; | 
			
		
	
		
			
				
					|  |  |  |  |   IVectorPtr idReal = realLayers_[0]->getOutputLabel(); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   if (!idReal) return; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   if (output_.subSequenceStartPositions) { | 
			
		
	
		
			
				
					|  |  |  |  |     int* starts = output_.subSequenceStartPositions->getMutableData(false); | 
			
		
	
		
			
				
					|  |  |  |  |     // Gather generator.idsVec
 | 
			
		
	
		
			
				
					|  |  |  |  |     // if is beam search generation result. Get first result.
 | 
			
		
	
		
			
				
					|  |  |  |  |     if (idReal->getData()[idReal->getSize() - 1] == -1) { | 
			
		
	
		
			
				
					|  |  |  |  |       for (size_t i = 0; i < realLayers_.size(); ++i) { | 
			
		
	
		
			
				
					|  |  |  |  |         // The first element stores first result size
 | 
			
		
	
		
			
				
					|  |  |  |  |         idReal = realLayers_[i]->getOutputLabel(); | 
			
		
	
		
			
				
					|  |  |  |  |         idReal->subVecFrom(*idReal, 1, idReal->getData()[0]); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     for (size_t i = 0; i < realLayers_.size(); ++i) { | 
			
		
	
		
			
				
					|  |  |  |  |       CHECK(realLayers_[i]->getOutputLabel()); | 
			
		
	
		
			
				
					|  |  |  |  |       starts[i] = height; | 
			
		
	
		
			
				
					|  |  |  |  |       height += realLayers_[i]->getOutputLabel()->getSize(); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     starts[realLayers_.size()] = height; | 
			
		
	
		
			
				
					|  |  |  |  |     output_.sequenceStartPositions->getMutableData(false)[1] = height; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     IVector::resizeOrCreate(output_.ids, height, false); | 
			
		
	
		
			
				
					|  |  |  |  |     for (size_t i = 0; i < realLayers_.size(); ++i) { | 
			
		
	
		
			
				
					|  |  |  |  |       output_.ids->subVec(starts[i], starts[i + 1] - starts[i]) | 
			
		
	
		
			
				
					|  |  |  |  |           ->copyFrom(*realLayers_[i]->getOutputLabel()); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } else { | 
			
		
	
		
			
				
					|  |  |  |  |     LOG(FATAL) << "Not implemented"; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void ScatterAgentLayer::forwardSequence(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  | void ScatterAgentLayer::forwardWithSelection(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  |   Layer::forward(passType); | 
			
		
	
		
			
				
					|  |  |  |  |   CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -220,17 +220,19 @@ void ScatterAgentLayer::forwardSequence(PassType passType) { | 
			
		
	
		
			
				
					|  |  |  |  |   AsyncGpuBlock asyncGpuBlock; | 
			
		
	
		
			
				
					|  |  |  |  |   REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str()); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   if (realOutArg_.value || realOutArg_.ids) { | 
			
		
	
		
			
				
					|  |  |  |  |     CHECK(realOutArg_.sequenceStartPositions); | 
			
		
	
		
			
				
					|  |  |  |  |     output_.subArgFrom(realOutArg_, | 
			
		
	
		
			
				
					|  |  |  |  |                        /* offset */ idIndex_, | 
			
		
	
		
			
				
					|  |  |  |  |                        idSize_, | 
			
		
	
		
			
				
					|  |  |  |  |                        width, | 
			
		
	
		
			
				
					|  |  |  |  |                        useGpu_, | 
			
		
	
		
			
				
					|  |  |  |  |                        /* trans */ false, | 
			
		
	
		
			
				
					|  |  |  |  |                        /* seqFlag */ true, | 
			
		
	
		
			
				
					|  |  |  |  |                        /* seqStart */ seqStartPosIndex_, | 
			
		
	
		
			
				
					|  |  |  |  |                        /* seqSize */ numSequences_); | 
			
		
	
		
			
				
					|  |  |  |  |   if (!input.hasSeq()) { | 
			
		
	
		
			
				
					|  |  |  |  |     if (realLayer_->getOutput().ids) { | 
			
		
	
		
			
				
					|  |  |  |  |       IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_); | 
			
		
	
		
			
				
					|  |  |  |  |       output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     if (realLayer_->getOutput().value) { | 
			
		
	
		
			
				
					|  |  |  |  |       int height = ids_->getSize(); | 
			
		
	
		
			
				
					|  |  |  |  |       resetOutput(height, width); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |       const MatrixPtr& outV = getOutputValue(); | 
			
		
	
		
			
				
					|  |  |  |  |       const MatrixPtr& realV = realLayer_->getOutputValue(); | 
			
		
	
		
			
				
					|  |  |  |  |       outV->selectRows(*realV, *ids_); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } else { | 
			
		
	
		
			
				
					|  |  |  |  |     // Putting the generation logic here is really an ugly hack!
 | 
			
		
	
		
			
				
					|  |  |  |  |     // used in generation
 | 
			
		
	
	
		
			
				
					|  |  |  | 
 |