Code example #1
0
void NeuralNetworkSimpleInternal::write(TextFileWriter const& _file) const throw()
{
	// Serialises the network as plain text: a version tag, the global
	// learning parameters, the per-layer node counts, and finally every
	// threshold and weight value, one formatted line at a time.
	const int fieldWidth = 4;	// column width for the integer indices
	const int bufSize = 128;	// scratch space for one formatted line
	char line[bufSize];
	
	// write() mutates the writer's state, so operate on a local copy.
	TextFileWriter writer = _file;
	
	writer.write("NeuralNetworkSimple:v1\n");
	
	snprintf(line, bufSize, "learnRate: %f actFuncOffset: %f\n", getLearnRate(), getActFuncOffset());
	writer.write(line);
	
	snprintf(line, bufSize, "Layers:%*d\n", fieldWidth, getNumLayersIncludingInput());
	writer.write(line);
	
	// Layer 0 is the input layer: it contributes a node count but no weights.
	snprintf(line, bufSize, "Layer %*d:%*d\n", fieldWidth, 0,
									           fieldWidth, numInputs);
	writer.write(line);
	
	const int numHiddenAndOutputLayers = getNumLayersExcludingInput();
	
	for(int layerIndex = 0; layerIndex < numHiddenAndOutputLayers; ++layerIndex)
	{
		snprintf(line, bufSize, "Layer %*d:%*d\n", fieldWidth, layerIndex + 1,
											       fieldWidth, getNumNodesOnLayer(layerIndex));
		writer.write(line);
	}
	
	// For each node emit its threshold first (marked with weight index -1),
	// then each of its incoming weights on its own line.
	for(int layerIndex = 0; layerIndex < numHiddenAndOutputLayers; ++layerIndex)
	{
		const int nodeCount = getNumNodesOnLayer(layerIndex);
		
		for(int nodeIndex = 0; nodeIndex < nodeCount; ++nodeIndex)
		{
			NumericalArray<float> nodeWeights;
			float nodeThreshold;
			get(layerIndex, nodeIndex, &nodeWeights, nodeThreshold);
			
			snprintf(line, bufSize, "%*d %*d %*d   %.16f\n", fieldWidth, layerIndex,
														     fieldWidth, nodeIndex,
														     fieldWidth, -1,
														     nodeThreshold);
			writer.write(line);
			
			const int weightCount = nodeWeights.size();
			
			for(int weightIndex = 0; weightIndex < weightCount; ++weightIndex)
			{
				snprintf(line, bufSize, "%*d %*d %*d   %.16f\n", fieldWidth, layerIndex,
															     fieldWidth, nodeIndex,
															     fieldWidth, weightIndex,
															     nodeWeights[weightIndex]);
				writer.write(line);
			}
		}
	}
}
Code example #2
0
File: SimpleTrainer.cpp  Project: osushkov/vectornn
void SimpleTrainer::Train(
    Network &network, vector<TrainingSample> &trainingSamples, unsigned iterations) {

  // Randomise the sample order once up front and restart the cursor.
  // NOTE(review): std::random_shuffle is deprecated in C++14 and removed in
  // C++17; migrating to std::shuffle would require an engine from <random>.
  random_shuffle(trainingSamples.begin(), trainingSamples.end());
  curSamplesIndex = 0;

  // One stochastic gradient-descent step per iteration.
  for (unsigned iter = 0; iter < iterations; iter++) {
    const float learnRate = getLearnRate(iter, iterations);

    TrainingProvider batch = getStochasticSamples(trainingSamples);
    pair<Tensor, float> gradientAndError = network.ComputeGradient(batch);

    // Move against the gradient, scaled by the iteration's learning rate.
    network.ApplyUpdate(gradientAndError.first * -learnRate);
  }
}