// Runs the iterative message/multiplier update loop of this inference object
// and reports the best labelling encountered along the way.
//
// The per-stage formulas quoted in the body are carried over from the original
// annotations. TRIM and the pre-built operations (_updateLambdaOp,
// _updateDeltaBarOp, _updateMuOp, _decodeOp) are defined elsewhere; what they do
// is described here only as far as those annotations state it.
//
// Parameters:
//   mapAssignment  output. Cleared, resized to _numNodes entries, and then
//                  overwritten with the lowest-energy assignment decoded in any
//                  iteration (it keeps its reset state if no decoded assignment
//                  has strictly lower energy).
// Returns:
//   (lowest primal energy found, negated dual objective of the last completed
//   iteration). The second value stays at -DRWN_DBL_MAX if no iteration runs.
// Side effects:
//   Rewrites the message, multiplier and scratch tables held by the object and
//   rescales PENALTY_PARAMETER in place; the adjusted value is still in effect
//   after the call returns.
pair<double, double> drwnADLPInference::inference(drwnFullAssignment& mapAssignment)
{
    drwnTableFactorStorage storage, storage_2;
    drwnVarUniversePtr universe(_graph.getUniverse());

    // Penalty adaptation: when one residual norm exceeds the other by more than
    // imbalanceRatio, the penalty is multiplied or divided by penaltyStep.
    const double imbalanceRatio = 100;
    const double penaltyStep = 1.25;

    mapAssignment.clear();
    mapAssignment.resize(_numNodes);
    double bestEnergy = _graph.getEnergy(mapAssignment);

    // NOTE(review): despite the name, this is reassigned on every iteration
    // (see below) rather than tracked as a running maximum.
    double bestDualEnergy = -DRWN_DBL_MAX;

    // Both accumulators are reset inside the loop before they are read.
    double sumResidual, sumPrimalUpdate;

    int iteration = 0;
    bool converged = false;

    // for t = 1 to T do
    while ((iteration < MAX_ITERATIONS) && !converged) {

        //---------------------------------------------------------------------
        // Stage 1 -- update delta: for all i = 1, ..., n
        for (int node = 0; node < _numNodes; ++node) {
            const int numCliques = _message_unary[node].size();
            const int numStates = _unary_bar[node]->entries();

            // theta_bar_i = theta_i + sum_{c : i in c} (delta_bar_{ci} - 1/p * gamma_{ci})
            for (int s = 0; s < numStates; ++s) {
                (*_unary_bar[node])[s] = -(*_unary[node])[s];
            }
            for (int c = 0; c < numCliques; ++c) {
                // gamma stays divided by the penalty until the multiplier
                // update at the end of this iteration multiplies it back.
                _gamma[node][c]->scale(1.0 / PENALTY_PARAMETER);
                for (int s = 0; s < numStates; ++s) {
                    (*_unary_bar[node])[s] +=
                        (*_message_unary_bar[node][c])[s] - (*_gamma[node][c])[s];
                }
            }

            const double theta =
                TRIM(_unary_bar[node], static_cast<double>(numCliques) / PENALTY_PARAMETER);

            // Keep only the part above the threshold, shared evenly between
            // the cliques that contain this node.
            for (int s = 0; s < numStates; ++s) {
                const double shifted = (*_unary_bar[node])[s];
                if (shifted > theta) {
                    (*_unary_bar[node])[s] =
                        (shifted - theta) / static_cast<double>(numCliques);
                } else {
                    (*_unary_bar[node])[s] = 0.0;
                }
            }

            // delta_{ci} = delta_bar_{ci} - 1/p * gamma_{ci} - q, for all c : i in c.
            // Plain index-wise arithmetic is valid because each of these tables
            // is over a single variable.
            for (int c = 0; c < numCliques; ++c) {
                for (int s = 0; s < numStates; ++s) {
                    (*_message_unary[node][c])[s] = (*_message_unary_bar[node][c])[s]
                        - (*_gamma[node][c])[s] - (*_unary_bar[node])[s];
                }
            }
        }

        //---------------------------------------------------------------------
        // Stage 2 -- update lambda: for all c in C
        for (int c = 0; c < _cliqueSize; ++c) {
            const int numEntries = _lambda[c]->entries();

            // theta_bar_c = theta_c - sum_{i : i in c} delta_bar_{ci} + 1/p * mu_c
            // mu stays divided by the penalty until the multiplier update below.
            _mu[c]->scale(1.0 / PENALTY_PARAMETER);

            // Index-wise access relies on mu, lambda and clique_bar sharing the
            // clique's variable order.
            for (int e = 0; e < numEntries; ++e) {
                (*_clique_bar[c])[e] = (*_mu[c])[e] - (*_clique[c])[e];
            }

            const int numOps = _updateLambdaOp[c].size();
            for (int op = 0; op < numOps; ++op) {
                _updateLambdaOp[c][op]->execute(); // theta_bar_c - delta_bar
            }

            const double theta = TRIM(_clique_bar[c], 1.0 / PENALTY_PARAMETER);

            for (int e = 0; e < numEntries; ++e) {
                const double reparam = (*_clique_bar[c])[e];
                if (reparam > theta) {
                    (*_lambda[c])[e] = -(*_clique[c])[e] - theta;
                } else {
                    (*_lambda[c])[e] = -(*_clique[c])[e] - reparam;
                }
            }
        }

        //---------------------------------------------------------------------
        // Stage 3 -- update delta_bar: for all c in C, i : i in c, x_i
        sumPrimalUpdate = 0.0;
        for (int c = 0; c < _cliqueSize; ++c) {
            const int numVars = _message_clique[c].size();

            // v_{ci} = delta_{ci} + 1/p * gamma_{ci} + sum_{x_c\i} lambda_c
            //          + 1/p * sum_{x_c\i} mu_c
            vector<vector<double> > v(numVars);
            vector<double> rowTotal(numVars);
            double weightedTotal = 0.0;
            int sumCard = 0;

            for (int j = 0; j < numVars; ++j) {
                _updateDeltaBarOp[c][j][0]->execute(); // Marginalize Lambda
                _updateDeltaBarOp[c][j][1]->execute(); // Marginalize Mu

                const int numStates = _message_clique[c][j]->entries();
                v[j].assign(numStates, 0.0);

                double running = 0.0;
                for (int s = 0; s < numStates; ++s) {
                    v[j][s] = (*_message_clique[c][j])[s] + (*_gamma_clique[c][j])[s]
                        + (*_margin_result_lambda[c][j])[s] + (*_margin_result_mu[c][j])[s];
                    running += v[j][s];
                }
                rowTotal[j] = running;

                // |X_{c\j}|: clique table size divided by this variable's cardinality.
                sumCard += _clique[c]->entries() / universe->varCardinality(_clique[c]->varId(j));
                weightedTotal += running
                    * static_cast<double>(_clique[c]->entries() / universe->varCardinality(_clique[c]->varId(j)));
            }

            // v_bar_c = 1 / (1 + sum_{k in c} |X_{c\k}|)
            //           * sum_{k in c} |X_{c\k}| * sum_{x_k} v_{ck}(x_k)
            const double vBar = (1.0 / static_cast<double>(1 + sumCard)) * weightedTotal;

            // delta_bar_{ci} = 1 / (1 + |X_{c\i}|)
            //   * [v_{ci} - sum_{j in c, j != i} |X_{c\ji}| (sum_{x_j} v_{cj}(x_j) - v_bar_c)]
            for (int j = 0; j < numVars; ++j) {
                double correction = 0.0;
                for (int k = 0; k < numVars; ++k) {
                    if (k == j) {
                        continue;
                    }
                    correction += static_cast<double>(_clique[c]->entries()
                            / universe->varCardinality(_clique[c]->varId(j))
                            / universe->varCardinality(_clique[c]->varId(k)))
                        * (rowTotal[k] - vBar);
                }

                const int numStates = _message_clique_bar[c][j]->entries();
                const double denominator = static_cast<double>(
                    1 + _clique[c]->entries() / universe->varCardinality(_clique[c]->varId(j)));

                for (int s = 0; s < numStates; ++s) {
                    const double updated = (v[j][s] - correction) / denominator;
                    // Squared change of delta_bar, summed over this sweep.
                    sumPrimalUpdate += pow(updated - (*_message_clique_bar[c][j])[s], 2);
                    (*_message_clique_bar[c][j])[s] = updated;
                }
            }
        }

        //---------------------------------------------------------------------
        // Stage 4 -- update the multipliers
        // gamma_{ci} = gamma_{ci} + p * (delta_{ci} - delta_bar_{ci})
        //   for all c in C, i : i in c, x_i
        sumResidual = 0.0;
        for (int node = 0; node < _numNodes; ++node) {
            const int numCliques = _message_unary[node].size();
            for (int c = 0; c < numCliques; ++c) {
                const int numStates = _message_unary[node][c]->entries();
                for (int s = 0; s < numStates; ++s) {
                    // Index-wise difference is valid: single-variable tables.
                    const double gap =
                        (*_message_unary[node][c])[s] - (*_message_unary_bar[node][c])[s];
                    sumResidual += gap * gap;
                    // Multiplying by the penalty reverses the scaling from stage 1.
                    (*_gamma[node][c])[s] =
                        (*_gamma[node][c])[s] * PENALTY_PARAMETER + gap * PENALTY_PARAMETER;
                }
            }
        }

        // mu_c = mu_c + p * (lambda_c - sum_{i : i in c} delta_bar_{ci})
        //   for all c in C, x_c
        for (int c = 0; c < _cliqueSize; ++c) {
            const int numEntries = _lambda[c]->entries();

            // NOTE(review): the element-wise copy followed by dataCompareAndCopy
            // looks redundant -- confirm dataCompareAndCopy's semantics before
            // removing either one.
            for (int e = 0; e < numEntries; ++e) {
                (*_tempMu[c])[e] = (*_lambda[c])[e];
            }
            _tempMu[c]->dataCompareAndCopy(*_lambda[c]);

            const int numOps = _updateMuOp[c].size();
            for (int op = 0; op < numOps; ++op) {
                _updateMuOp[c][op]->execute(); // tempMu - delta_bar
            }

            for (int e = 0; e < numEntries; ++e) {
                sumResidual += std::pow((*_tempMu[c])[e], 2);
                // Multiplying by the penalty reverses the scaling from stage 2.
                (*_mu[c])[e] =
                    (*_mu[c])[e] * PENALTY_PARAMETER + (*_tempMu[c])[e] * PENALTY_PARAMETER;
            }
        }

        //---------------------------------------------------------------------
        // Stage 5 -- decode and check for convergence (based on the author's
        // code in SVL). Two candidate labellings are decoded per iteration:
        // one from the delta messages and one from the delta_bar messages.
        double dualObjDelta = 0.0, dualObjDeltaBar = 0.0;
        drwnFullAssignment deltaAssignment(_numNodes), deltaBarAssignment(_numNodes);
        // NOTE(review): dashAssignment is never read in this function.
        drwnFullAssignment dashAssignment(_numNodes);

        for (int node = 0; node < _numNodes; ++node) {
            drwnTableFactor fromDelta(universe, &storage);
            fromDelta.addVariable(node);
            drwnTableFactor fromDeltaBar(universe, &storage_2);
            fromDeltaBar.addVariable(node);

            const int numStates = fromDelta.entries();
            for (int s = 0; s < numStates; ++s) {
                fromDelta[s] = 0.0;
                fromDeltaBar[s] = 0.0;
            }

            // Sum the incoming messages of each kind (single-variable tables,
            // so index-wise addition is valid).
            const int numCliques = _message_unary[node].size();
            for (int c = 0; c < numCliques; ++c) {
                for (int s = 0; s < numStates; ++s) {
                    fromDelta[s] += (*_message_unary[node][c])[s];
                    fromDeltaBar[s] += (*_message_unary_bar[node][c])[s];
                }
            }

            // The node's own potential is subtracted only when _flag[node] is non-zero.
            if (_flag[node] != 0) {
                for (int s = 0; s < numStates; ++s) {
                    fromDelta[s] -= (*_unary[node])[s];
                    fromDeltaBar[s] -= (*_unary[node])[s];
                }
            }

            deltaAssignment[node] = fromDelta.valueOf(node, fromDelta.indexOfMax());
            dualObjDelta += fromDelta[fromDelta.indexOfMax()];
            deltaBarAssignment[node] = fromDeltaBar.valueOf(node, fromDeltaBar.indexOfMax());
            dualObjDeltaBar += fromDeltaBar[fromDeltaBar.indexOfMax()];
        }

        // Clique terms: _clique_bar and _tempMu are reused as scratch space here.
        for (int c = 0; c < _cliqueSize; ++c) {
            const int numEntries = _clique[c]->entries();
            for (int e = 0; e < numEntries; ++e) {
                (*_clique_bar[c])[e] = -(*_clique[c])[e];
                (*_tempMu[c])[e] = -(*_clique[c])[e];
            }

            const int numOps = _decodeOp[c].size();
            for (int op = 0; op < numOps; ++op) {
                _decodeOp[c][op]->execute(); // clique - delta & clique - delta_bar
            }

            dualObjDelta += (*_clique_bar[c])[_clique_bar[c]->indexOfMax()];
            dualObjDeltaBar += (*_tempMu[c])[_tempMu[c]->indexOfMax()];
        }

        // Keep whichever decoded labelling has the lowest energy so far.
        const double deltaEnergy = _graph.getEnergy(deltaAssignment);
        const double deltaBarEnergy = _graph.getEnergy(deltaBarAssignment);
        if (deltaEnergy < bestEnergy) {
            bestEnergy = deltaEnergy;
            mapAssignment = deltaAssignment;
        }
        if (deltaBarEnergy < bestEnergy) {
            bestEnergy = deltaBarEnergy;
            mapAssignment = deltaBarAssignment;
        }

        const double residualNorm = sqrt(sumResidual);
        const double updateNorm = sqrt(sumPrimalUpdate);

        // Stop when either dual objective has come down to -bestEnergy, or when
        // both the residual and the delta_bar change fall below EPSILON.
        const bool gapClosed =
            (dualObjDelta <= -bestEnergy) || (dualObjDeltaBar <= -bestEnergy);
        const bool stalled = (residualNorm < EPSILON) && (updateNorm < EPSILON);
        if (gapClosed || stalled) {
            converged = true;
        }

        bestDualEnergy = -dualObjDelta;
        DRWN_LOG_VERBOSE("...iteration " << iteration << "; dual objective " << bestDualEnergy << "; best energy " << bestEnergy);
        iteration++;

        // Adapt the penalty. This runs even on the final iteration, so the
        // adjusted PENALTY_PARAMETER persists after the function returns.
        if (residualNorm > imbalanceRatio * updateNorm) {
            PENALTY_PARAMETER *= penaltyStep; // residual is large -> increase penalty
        } else if (updateNorm > imbalanceRatio * residualNorm) {
            PENALTY_PARAMETER /= penaltyStep; // residual is small -> decrease penalty
        }
    } // end of main iteration loop

    return make_pair(bestEnergy, bestDualEnergy);
}