Example #1
ArrayXXd CMT::BlobNonlinearity::gradient(const ArrayXXd& inputs) const {
	if(inputs.rows() != 1)
		throw Exception("Data has to be stored in one row.");
	
	// diff(k, i) = inputs(0, i) - mMeans(k): each input centered at every component mean
	ArrayXXd diff = ArrayXXd::Zero(mNumComponents, inputs.cols());
	diff.rowwise() += inputs.row(0);
	diff.colwise() -= mMeans;

	ArrayXXd diffSq = diff.square();
	ArrayXd precisions = mLogPrecisions.exp();
	ArrayXd weights = mLogWeights.exp();

	ArrayXXd negEnergy = diffSq.colwise() * (-precisions / 2.);
	ArrayXXd negEnergyExp = negEnergy.exp();

	ArrayXXd gradient(3 * mNumComponents, inputs.cols());

	// gradient of mean
	gradient.topRows(mNumComponents) = (diff * negEnergyExp).colwise() * (weights * precisions);

	// gradient of log-precisions
	gradient.middleRows(mNumComponents, mNumComponents) = (diffSq / 2. * negEnergyExp).colwise() * (-weights * precisions);

	// gradient of log-weights
	gradient.bottomRows(mNumComponents) = negEnergyExp.colwise() * weights;

	return gradient;
}
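The returned matrix stacks three gradient blocks per input column: the first mNumComponents rows hold the gradient with respect to the means, the next mNumComponents rows the gradient with respect to the log-precisions, and the last mNumComponents rows the gradient with respect to the log-weights. A minimal sketch of how a caller could unpack that layout; the struct and function names below are illustrative and not part of the library:

// Hypothetical helper for unpacking the stacked gradient returned above.
// 'grad' is the (3 * K) x N result of BlobNonlinearity::gradient; numComponents
// must match the nonlinearity's component count. Names are illustrative only.
struct BlobParamGrads {
	ArrayXXd means;          // rows 0 .. K-1:   d phi / d mu_k
	ArrayXXd logPrecisions;  // rows K .. 2K-1:  d phi / d log p_k
	ArrayXXd logWeights;     // rows 2K .. 3K-1: d phi / d log w_k
};

BlobParamGrads splitBlobGradient(const ArrayXXd& grad, int numComponents) {
	BlobParamGrads out;
	out.means = grad.topRows(numComponents);
	out.logPrecisions = grad.middleRows(numComponents, numComponents);
	out.logWeights = grad.bottomRows(numComponents);
	return out;
}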
Example #2
ArrayXXd CMT::BlobNonlinearity::operator()(const ArrayXXd& inputs) const {
	if(inputs.rows() != 1)
		throw Exception("Data has to be stored in one row.");

	ArrayXXd diff = ArrayXXd::Zero(mNumComponents, inputs.cols());
	diff.rowwise() += inputs.row(0);
	diff.colwise() -= mMeans;

	// negEnergy(k, i) = -precision_k * (inputs(0, i) - mean_k)^2 / 2
	ArrayXXd negEnergy = diff.square().colwise() * (-mLogPrecisions.exp() / 2.);
	// phi(x_i) = epsilon + sum_k weight_k * exp(negEnergy(k, i))
	return (mLogWeights.exp().transpose().matrix() * negEnergy.exp().matrix()).array() + mEpsilon;
}
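The matrix product on the last line collapses the component dimension: the 1 x K row of exponentiated log-weights times the K x N array of exponentials yields one value per input column. The same return value could also be written with array broadcasting only, as in the sketch below; this is illustrative code that would have to sit inside operator() where negEnergy and the member variables are in scope, not library code:

// Equivalent array-only formulation of the return expression (illustrative):
// weight each component's exponential with w_k, then sum over components per column.
ArrayXXd weighted = negEnergy.exp().colwise() * mLogWeights.exp();
return weighted.colwise().sum() + mEpsilon;   // 1 x inputs.cols()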
Example #3
ArrayXXd CMT::BlobNonlinearity::derivative(const ArrayXXd& inputs) const {
	if(inputs.rows() != 1)
		throw Exception("Data has to be stored in one row.");

	// note the flipped sign relative to operator(): diff(k, i) = mMeans(k) - inputs(0, i)
	ArrayXXd diff = ArrayXXd::Zero(mNumComponents, inputs.cols());
	diff.rowwise() -= inputs.row(0);
	diff.colwise() += mMeans;

	ArrayXd precisions = mLogPrecisions.exp();

	ArrayXXd negEnergy = diff.square().colwise() * (-precisions / 2.);

	// dphi/dx_i = sum_k weight_k * precision_k * (mean_k - x_i) * exp(negEnergy(k, i))
	return (mLogWeights.exp() * precisions).transpose().matrix() * (diff * negEnergy.exp()).matrix();
}
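derivative() evaluates dphi/dx_i = sum_k w_k p_k (mu_k - x_i) exp(-p_k (x_i - mu_k)^2 / 2) for every input column. A self-contained sketch that checks this formula against a central finite difference; all parameter values below are made up for illustration:

// Standalone check of the blob-nonlinearity derivative formula (illustrative values).
#include <Eigen/Dense>
#include <iostream>

int main() {
	Eigen::ArrayXd means(2), precisions(2), weights(2);
	means << -1., 2.;
	precisions << 1., 0.5;
	weights << 0.3, 0.7;
	double epsilon = 1e-12;

	// phi(x) = epsilon + sum_k w_k * exp(-p_k * (x - mu_k)^2 / 2)
	auto phi = [&](double x) {
		return epsilon + (weights * (-precisions * (means - x).square() / 2.).exp()).sum();
	};

	// analytic derivative, same formula as CMT::BlobNonlinearity::derivative
	auto dphi = [&](double x) {
		return (weights * precisions * (means - x)
			* (-precisions * (means - x).square() / 2.).exp()).sum();
	};

	double x = 0.5, h = 1e-5;
	std::cout << "analytic:    " << dphi(x) << std::endl;
	std::cout << "finite diff: " << (phi(x + h) - phi(x - h)) / (2. * h) << std::endl;

	return 0;
}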
Example #4
File: mcbm.cpp  Project: cajal/cmt
double CMT::MCBM::parameterGradient(
	const MatrixXd& inputCompl,
	const MatrixXd& outputCompl,
	const lbfgsfloatval_t* x,
	lbfgsfloatval_t* g,
	const Trainable::Parameters& params_) const
{
	const Parameters& params = dynamic_cast<const Parameters&>(params_);

	// average log-likelihood
	double logLik = 0.;

	// interpret memory for parameters and gradients
	lbfgsfloatval_t* y = const_cast<lbfgsfloatval_t*>(x);

	int offset = 0;

	VectorLBFGS priors(params.trainPriors ? y : const_cast<double*>(mPriors.data()), mNumComponents);
	VectorLBFGS priorsGrad(g, mNumComponents);
	if(params.trainPriors)
		offset += priors.size();

	MatrixLBFGS weights(params.trainWeights ? y + offset : const_cast<double*>(mWeights.data()), mNumComponents, mNumFeatures);
	MatrixLBFGS weightsGrad(g + offset, mNumComponents, mNumFeatures);
	if(params.trainWeights)
		offset += weights.size();

	MatrixLBFGS features(params.trainFeatures ? y + offset : const_cast<double*>(mFeatures.data()), mDimIn, mNumFeatures);
	MatrixLBFGS featuresGrad(g + offset, mDimIn, mNumFeatures);
	if(params.trainFeatures)
		offset += features.size();

	MatrixLBFGS predictors(params.trainPredictors ? y + offset : const_cast<double*>(mPredictors.data()), mNumComponents, mDimIn);
	MatrixLBFGS predictorsGrad(g + offset, mNumComponents, mDimIn);
	if(params.trainPredictors)
		offset += predictors.size();

	MatrixLBFGS inputBias(params.trainInputBias ? y + offset : const_cast<double*>(mInputBias.data()), mDimIn, mNumComponents);
	MatrixLBFGS inputBiasGrad(g + offset, mDimIn, mNumComponents);
	if(params.trainInputBias)
		offset += inputBias.size();

	VectorLBFGS outputBias(params.trainOutputBias ? y + offset : const_cast<double*>(mOutputBias.data()), mNumComponents);
	VectorLBFGS outputBiasGrad(g + offset, mNumComponents);
	if(params.trainOutputBias)
		offset += outputBias.size();

	if(g) {
		// initialize gradients
		if(params.trainPriors)
			priorsGrad.setZero();
		if(params.trainWeights)
			weightsGrad.setZero();
		if(params.trainFeatures)
			featuresGrad.setZero();
		if(params.trainPredictors)
			predictorsGrad.setZero();
		if(params.trainInputBias)
			inputBiasGrad.setZero();
		if(params.trainOutputBias)
			outputBiasGrad.setZero();
	}

	// split data into batches for better performance
	int numData = static_cast<int>(inputCompl.cols());
	int batchSize = min(max(params.batchSize, 10), numData);

	#pragma omp parallel for
	for(int b = 0; b < inputCompl.cols(); b += batchSize) {
		const MatrixXd& input = inputCompl.middleCols(b, min(batchSize, numData - b));
		const MatrixXd& output = outputCompl.middleCols(b, min(batchSize, numData - b));

		ArrayXXd featureOutput = features.transpose() * input;
		MatrixXd featureOutputSq = featureOutput.square();
		MatrixXd weightsOutput = weights * featureOutputSq;
		ArrayXXd predictorOutput = predictors * input;

		// unnormalized posteriors over components for both possible outputs
		ArrayXXd logPost0 = (weightsOutput + inputBias.transpose() * input).colwise() + priors;
		ArrayXXd logPost1 = (logPost0 + predictorOutput).colwise() + outputBias.array();

		// sum over components to get unnormalized probabilities of outputs
		Array<double, 1, Dynamic> logProb0 = logSumExp(logPost0);
		Array<double, 1, Dynamic> logProb1 = logSumExp(logPost1);
	
		// normalize posteriors over components
		logPost0.rowwise() -= logProb0;
		logPost1.rowwise() -= logProb1;

		// stack row vectors
		ArrayXXd logProb01(2, input.cols());
		logProb01 << logProb0, logProb1; 

		// normalize log-probabilities
		Array<double, 1, Dynamic> logNorm = logSumExp(logProb01);
		logProb1 -= logNorm;
		logProb0 -= logNorm;

		double logLikBatch = (output.array() * logProb1 + (1. - output.array()) * logProb0).sum();

		#pragma omp critical
		logLik += logLikBatch;

		if(!g)
			// don't compute gradients
			continue;

		Array<double, 1, Dynamic> tmp = output.array() * logProb0.exp() - (1. - output.array()) * logProb1.exp();

		ArrayXXd post0Tmp = logPost0.exp().rowwise() * tmp;
		ArrayXXd post1Tmp = logPost1.exp().rowwise() * tmp;
		ArrayXXd postDiffTmp = post1Tmp - post0Tmp;

		// update gradients
		if(params.trainPriors)
			#pragma omp critical
			priorsGrad -= postDiffTmp.rowwise().sum().matrix();

		if(params.trainWeights)
			#pragma omp critical
			weightsGrad -= postDiffTmp.matrix() * featureOutputSq.transpose();

		if(params.trainFeatures) {
			ArrayXXd tmp2 = weights.transpose() * postDiffTmp.matrix() * 2.;
			MatrixXd tmp3 = featureOutput * tmp2;
			#pragma omp critical
			featuresGrad -= input * tmp3.transpose();
		}

		if(params.trainPredictors)
			#pragma omp critical
			predictorsGrad -= post1Tmp.matrix() * input.transpose();

		if(params.trainInputBias)
			#pragma omp critical
			inputBiasGrad -= input * postDiffTmp.matrix().transpose();

		if(params.trainOutputBias)
			#pragma omp critical
			outputBiasGrad -= post1Tmp.rowwise().sum().matrix();
	}

	double normConst = inputCompl.cols() * log(2.) * dimOut();

	if(g) {
		for(int i = 0; i < offset; ++i)
			g[i] /= normConst;

		if(params.trainFeatures)
			featuresGrad += params.regularizeFeatures.gradient(features);

		if(params.trainPredictors)
			predictorsGrad += params.regularizePredictors.gradient(predictors.transpose()).transpose();

		if(params.trainWeights)
			weightsGrad += params.regularizeWeights.gradient(weights);
	}

	double value = -logLik / normConst;

	if(params.trainFeatures)
		value += params.regularizeFeatures.evaluate(features);

	if(params.trainPredictors)
		value += params.regularizePredictors.evaluate(predictors.transpose());

	if(params.trainWeights)
		value += params.regularizeWeights.evaluate(weights);

	return value;
}
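Both parameterGradient implementations rely on logSumExp to normalize log-posteriors over components; the library ships its own version, which is not shown in these examples. A minimal sketch of a numerically stable, column-wise log-sum-exp with the same assumed semantics (reduce over rows, one value per column):

// Sketch of a numerically stable column-wise log-sum-exp (illustrative, not the
// library's implementation). Subtracting the per-column maximum before
// exponentiating avoids overflow for large energies.
Array<double, 1, Dynamic> logSumExpSketch(const ArrayXXd& x) {
	Array<double, 1, Dynamic> xMax = x.colwise().maxCoeff();
	return xMax + (x.rowwise() - xMax).exp().colwise().sum().log();
}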
Example #5
double CMT::STM::parameterGradient(
	const MatrixXd& inputCompl,
	const MatrixXd& outputCompl,
	const lbfgsfloatval_t* x,
	lbfgsfloatval_t* g,
	const Trainable::Parameters& params_) const
{
	// check if nonlinearity is differentiable
	DifferentiableNonlinearity* nonlinearity = dynamic_cast<DifferentiableNonlinearity*>(mNonlinearity);

	if(!nonlinearity)
		throw Exception("Nonlinearity has to be differentiable for training.");

	const Parameters& params = dynamic_cast<const Parameters&>(params_);

	// average log-likelihood
	double logLik = 0.;

	lbfgsfloatval_t* y = const_cast<lbfgsfloatval_t*>(x);
	int offset = 0;

	VectorLBFGS biases(params.trainBiases ? y : const_cast<double*>(mBiases.data()), mNumComponents);
	VectorLBFGS biasesGrad(g, mNumComponents);
	if(params.trainBiases)
		offset += biases.size();

	MatrixLBFGS weights(params.trainWeights ? y + offset :
		const_cast<double*>(mWeights.data()), mNumComponents, mNumFeatures);
	MatrixLBFGS weightsGrad(g + offset, mNumComponents, mNumFeatures);
	if(params.trainWeights)
		offset += weights.size();

	MatrixLBFGS features(params.trainFeatures ? y + offset :
		const_cast<double*>(mFeatures.data()), dimInNonlinear(), mNumFeatures);
	MatrixLBFGS featuresGrad(g + offset, dimInNonlinear(), mNumFeatures);
	if(params.trainFeatures)
		offset += features.size();

	MatrixLBFGS predictors(params.trainPredictors ? y + offset :
		const_cast<double*>(mPredictors.data()), mNumComponents, dimInNonlinear());
	MatrixLBFGS predictorsGrad(g + offset, mNumComponents, dimInNonlinear());
	if(params.trainPredictors)
		offset += predictors.size();

	VectorLBFGS linearPredictor(params.trainLinearPredictor ? y + offset :
		const_cast<double*>(mLinearPredictor.data()), dimInLinear());
	VectorLBFGS linearPredictorGrad(g + offset, dimInLinear());
	if(params.trainLinearPredictor)
		offset += linearPredictor.size();

	double sharpness = params.trainSharpness ? y[offset++] : mSharpness;
	double sharpnessGrad = 0.;

	if(g) {
		// initialize gradients
		if(params.trainBiases)
			biasesGrad.setZero();
		if(params.trainWeights)
			weightsGrad.setZero();
		if(params.trainFeatures)
			featuresGrad.setZero();
		if(params.trainPredictors)
			predictorsGrad.setZero();
		if(params.trainLinearPredictor)
			linearPredictorGrad.setZero();
	}

	// split data into batches for better performance
	int numData = static_cast<int>(inputCompl.cols());
	int batchSize = min(max(params.batchSize, 10), numData);

	#pragma omp parallel for
	for(int b = 0; b < inputCompl.cols(); b += batchSize) {
		int width = min(batchSize, numData - b);
		const MatrixXd& inputNonlinear = inputCompl.block(0, b, dimInNonlinear(), width);
		const MatrixXd& inputLinear = inputCompl.block(dimInNonlinear(), b, dimInLinear(), width);
		const MatrixXd& output = outputCompl.middleCols(b, width);

		ArrayXXd featureOutput;
		MatrixXd featureOutputSq;
		MatrixXd jointEnergy;

		if(numFeatures() > 0) {
			featureOutput = features.transpose() * inputNonlinear;
			featureOutputSq = featureOutput.square();
			jointEnergy = weights * featureOutputSq + predictors * inputNonlinear;
		} else {
			jointEnergy = predictors * inputNonlinear;
		}

		jointEnergy.colwise() += biases;
		MatrixXd jointEnergyScaled = jointEnergy * sharpness;

		Matrix<double, 1, Dynamic> response = logSumExp(jointEnergyScaled);

		// posterior over components for each data point
		MatrixXd posterior = (jointEnergyScaled.rowwise() - response).array().exp();

		response /= sharpness;

		MatrixXd nonlinearResponse;
		if(params.trainSharpness)
			// make copy of nonlinear response
			nonlinearResponse = response;

		if(dimInLinear())
			response += linearPredictor.transpose() * inputLinear;

		// update log-likelihood
		double logLikBatch = mDistribution->logLikelihood(
			output,
			nonlinearity->operator()(response)).sum();

		#pragma omp critical
		logLik += logLikBatch;

		if(!g)
			// don't compute gradients
			continue;

		Array<double, 1, Dynamic> tmp = -mDistribution->gradient(output, nonlinearity->operator()(response))
			* nonlinearity->derivative(response);

		MatrixXd postTmp = posterior.array().rowwise() * tmp;

		if(params.trainBiases)
			#pragma omp critical
			biasesGrad -= postTmp.rowwise().sum();

		if(numFeatures() > 0) {
			if(params.trainWeights)
				#pragma omp critical
				weightsGrad -= postTmp * featureOutputSq.transpose();

			if(params.trainFeatures) {
				ArrayXXd tmp2 = 2. * weights.transpose() * postTmp;
				MatrixXd tmp3 = featureOutput * tmp2;
				#pragma omp critical
				featuresGrad -= inputNonlinear * tmp3.transpose();
			}
		}

		if(params.trainPredictors)
			#pragma omp critical
			predictorsGrad -= postTmp * inputNonlinear.transpose();

		if(params.trainLinearPredictor && dimInLinear() > 0)
			#pragma omp critical
			linearPredictorGrad -= inputLinear * tmp.transpose().matrix();

		if(params.trainSharpness) {
			double tmp2 = ((jointEnergy.array() * posterior.array()).colwise().sum() * tmp).sum() / sharpness;
			double tmp3 = (nonlinearResponse.array() * tmp).sum() / sharpness;
			#pragma omp critical
			sharpnessGrad -= tmp2 - tmp3;
		}
	}

	double normConst = inputCompl.cols() * log(2.) * dimOut();

	if(g) {
		if(params.trainSharpness)
			g[offset - 1] = sharpnessGrad;

		for(int i = 0; i < offset; ++i)
			g[i] /= normConst;

		if(params.trainBiases)
			biasesGrad += params.regularizeBiases.gradient(biases);
		if(params.trainFeatures)
			featuresGrad += params.regularizeFeatures.gradient(features);
		if(params.trainPredictors)
			predictorsGrad += params.regularizePredictors.gradient(predictors.transpose()).transpose();
		if(params.trainWeights)
			weightsGrad += params.regularizeWeights.gradient(weights.transpose()).transpose();
		if(params.trainLinearPredictor)
			linearPredictorGrad += params.regularizeLinearPredictor.gradient(linearPredictor);
	}

	double value = -logLik / normConst;

	value += params.regularizeBiases.evaluate(biases);
	value += params.regularizeFeatures.evaluate(features);
	value += params.regularizePredictors.evaluate(predictors.transpose());
	value += params.regularizeWeights.evaluate(weights.transpose());
	value += params.regularizeLinearPredictor.evaluate(linearPredictor);

	if(value != value)
		// value is NaN; line search probably went into bad region of parameter space
		value = numeric_limits<double>::max();

	return value;
}
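VectorLBFGS and MatrixLBFGS are not defined in these examples; they act as Eigen maps that let the trainer read and write parameters directly in the flat L-BFGS buffer, with offset tracking how far into the buffer each block starts. A minimal sketch of that pattern, assuming lbfgsfloatval_t is double; the real typedefs may use a different storage order:

// Sketch of the flat-buffer mapping pattern used by parameterGradient above
// (illustrative; the actual MatrixLBFGS / VectorLBFGS typedefs are not shown here).
#include <Eigen/Dense>

typedef double lbfgsfloatval_t;
typedef Eigen::Map<Eigen::VectorXd> VectorMapSketch;
typedef Eigen::Map<Eigen::MatrixXd> MatrixMapSketch;

void mapParameters(lbfgsfloatval_t* x, int numComponents, int numFeatures) {
	int offset = 0;

	// first block of the buffer: one prior per component
	VectorMapSketch priors(x + offset, numComponents);
	offset += priors.size();

	// next block: components-by-features weight matrix
	MatrixMapSketch weights(x + offset, numComponents, numFeatures);
	offset += weights.size();

	// writes through the maps go straight into the L-BFGS buffer
	priors.setZero();
	weights.setOnes();
}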