/*
 * Translate a pattern tree into grammar symbols, interning as it goes.
 *
 * PN_WORD: returns the terminal whose string in ar_words equals
 * node->text, appending a strdup'd copy first if none does.
 * PN_SEQ / PN_ALT / PN_OPT: recursively translates the children into a
 * new rule set (SEQ: one rule of two symbols; ALT: two one-symbol rules;
 * OPT: an empty rule plus a one-symbol rule) and sorts it with
 * ruleset_sort().  If a rule set in ar_grammar compares equal, the new
 * one is destroyed and the existing index is used; otherwise the new set
 * is appended.  The returned non-terminal refers to that index.
 *
 * Allocation failures are treated as fatal via assert().
 */
SymbolRef
pattern_to_grammar(PatternNode *node)
{
	/* Built by the composite cases below; stays NULL otherwise. */
	GrammarRuleSet *ruleset = NULL;

	switch (node->type) {
	case PN_WORD: {
		int nwords = (int)AR_size(&ar_words);
		const char **words = AR_data(&ar_words);

		/* Try to find an existing equal terminal symbol. */
		SymbolRef res = { SYM_TERMINAL, 0 };
		for (res.index = 0; res.index < nwords; ++res.index)
			if (strcmp(words[res.index], node->text) == 0)
				break;

		/* Add a new terminal symbol if none existed. */
		if (res.index == nwords) {
			char *str = strdup(node->text);
			/*
			 * Checked like every other allocation here: a NULL entry
			 * in ar_words would make later strcmp() calls undefined.
			 */
			assert(str != NULL);
			AR_push(&ar_words, &str);
		}
		return res;
	}
	case PN_SEQ:
		ruleset = ruleset_create(1);
		assert(ruleset != NULL);
		ruleset->rules[0] = symrefs_create(2);
		assert(ruleset->rules[0] != NULL);
		ruleset->rules[0]->refs[0] = pattern_to_grammar(node->left);
		ruleset->rules[0]->refs[1] = pattern_to_grammar(node->right);
		break;
	case PN_ALT:
		ruleset = ruleset_create(2);
		assert(ruleset != NULL);
		ruleset->rules[0] = symrefs_create(1);
		assert(ruleset->rules[0] != NULL);
		ruleset->rules[1] = symrefs_create(1);
		assert(ruleset->rules[1] != NULL);
		ruleset->rules[0]->refs[0] = pattern_to_grammar(node->left);
		ruleset->rules[1]->refs[0] = pattern_to_grammar(node->right);
		break;
	case PN_OPT:
		/* Rule 0 is empty (the "absent" alternative). */
		ruleset = ruleset_create(2);
		assert(ruleset != NULL);
		ruleset->rules[0] = symrefs_create(0);
		assert(ruleset->rules[0] != NULL);
		ruleset->rules[1] = symrefs_create(1);
		assert(ruleset->rules[1] != NULL);
		ruleset->rules[1]->refs[0] = pattern_to_grammar(node->left);
		break;
	default:
		/*
		 * NOTE(review): with NDEBUG this assert is compiled out and the
		 * NULL ruleset reaches ruleset_sort() below.
		 */
		assert(false);
		break;
	}

	ruleset_sort(ruleset);

	/* See if the rule set matches an existing symbol's rule set. */
	GrammarRuleSet **rulesets = AR_data(&ar_grammar);
	size_t nruleset = AR_size(&ar_grammar), n;
	for (n = 0; n < nruleset; ++n)
		if (ruleset_cmp(rulesets[n], ruleset) == 0)
			break;

	if (n < nruleset)
		ruleset_destroy(ruleset);
	else
		AR_push(&ar_grammar, &ruleset);	/* lands at index n */

	SymbolRef res = { SYM_NONTERMINAL, (int)n };
	return res;
}
/*
 * Search for a high-posterior rule list by simulated annealing.
 *
 * Starts from a random ruleset of init_size rules (create_random_ruleset),
 * then follows a fixed cooling schedule.  Boundaries are b[0] = 1 and
 * b[i] = b[i-1] + exp(0.25 * (i + 1)) for i = 1..27.  Stage i covers
 * (int)b[i] - (int)b[i-1] time points, all at temperature 1/(i+1).
 * Each time point performs iters_per_step (200) proposals through
 * propose() with sa_accepts.  The best ruleset seen is remembered as an
 * id array (ruleset_backup) and rebuilt at the end.  Note that the
 * `iters` argument is not used by this function.
 *
 * Returns the rebuilt best ruleset (caller owns it), or NULL on failure.
 */
ruleset_t *
run_simulated_annealing(int iters, int init_size, int nsamples, int nrules,
    rule_t *rules, rule_t *labels, params_t *params)
{
	enum { N_BOUNDS = 28 };		/* schedule boundaries b[0..27] */
	ruleset_t *rs = NULL;
	double jump_prob, tk;
	double log_post_rs, max_log_posterior, prefix_bound = 0.0;
	double bound[N_BOUNDS];
	int dummy = 0;			/* propose() increments it; value unused */
	int i, j, iter, iters_per_step, len, ntimepoints;
	int *rs_idarray = NULL;

	iters_per_step = 200;

	/* Initialize random number generator for some distributions. */
	init_gsl_rand_gen();

	/* Initialize the ruleset. */
	if (create_random_ruleset(init_size, nsamples, nrules, rules, &rs) != 0)
		return (NULL);
	log_post_rs = compute_log_posterior(rs, rules, nrules, labels, params,
	    0, -1, &prefix_bound);
	if (ruleset_backup(rs, &rs_idarray) != 0)
		goto err;
	max_log_posterior = log_post_rs;
	len = rs->n_rules;

	if (debug > 10) {
		printf("Initial ruleset: \n");
		ruleset_print(rs, rules, (debug > 100));
	}

	/*
	 * Pre-compute the cooling schedule boundaries.  A time point's
	 * temperature depends only on its stage, so no per-time-point table
	 * is kept.  The schedule has a few thousand time points, far too many
	 * to tabulate on the stack.
	 */
	ntimepoints = 0;
	bound[0] = 1;
	for (i = 1; i < N_BOUNDS; i++) {
		bound[i] = bound[i - 1] + exp(0.25 * (i + 1));
		ntimepoints += (int)bound[i] - (int)bound[i - 1];
	}
	if (debug > 0)
		printf("iters_per_step = %d, #timepoints = %d\n",
		    iters_per_step, ntimepoints);

	for (i = 1; i < N_BOUNDS; i++) {
		for (j = (int)bound[i - 1]; j < (int)bound[i]; j++) {
			/*
			 * sa_accepts gets the temperature by pointer; reset it at
			 * every time point.
			 */
			tk = 1.0 / (i + 1);
			for (iter = 0; iter < iters_per_step; iter++) {
				/* propose() consumes rs; on NULL it was freed. */
				if ((rs = propose(rs, rules, labels, nrules, &jump_prob,
				    &log_post_rs, max_log_posterior, &dummy, &tk,
				    params, sa_accepts)) == NULL)
					goto err;
				if (log_post_rs > max_log_posterior) {
					if (ruleset_backup(rs, &rs_idarray) != 0)
						goto err;
					max_log_posterior = log_post_rs;
					len = rs->n_rules;
				}
			}
		}
	}

	/* Regenerate the best rule list. */
	ruleset_destroy(rs);
	rs = NULL;			/* keep the err path from freeing it again */
	printf("\n\n/*----The best rule list is: */\n");
	if (ruleset_init(len, nsamples, rs_idarray, rules, &rs) != 0)
		goto err;
	printf("max_log_posterior = %6f\n\n", max_log_posterior);
	printf("max_log_posterior = %6f\n\n",
	    compute_log_posterior(rs, rules, nrules, labels, params, 1, -1,
	    &prefix_bound));
	free(rs_idarray);
	ruleset_print(rs, rules, (debug > 100));
	return (rs);

err:
	if (rs != NULL)
		ruleset_destroy(rs);
	free(rs_idarray);
	return (NULL);
}
/*
 * Train a rule-list model.
 *
 * First builds the global tables used by the sampler (compute_pmf,
 * compute_cardinality, compute_log_gammas, permute_rules).  Then scores
 * the ruleset that contains only the default rule.  It then runs
 * params->nchain MCMC chains and keeps the ruleset with the highest log
 * posterior; a later chain wins ties.  Chains for which run_mcmc()
 * returns NULL are skipped.  The global tables are freed and reset to
 * NULL on every exit path.  `initialization` and `method` are not used
 * here.
 *
 * Returns a calloc'd pred_model_t that owns the chosen ruleset and its
 * theta, or NULL on failure.
 */
pred_model_t *
train(data_t *train_data, int initialization, int method, params_t *params)
{
	int chain, default_rule;
	pred_model_t *pred_model = NULL;
	ruleset_t *rs = NULL, *rs_temp;
	double max_pos, pos_temp, null_bound;

	if (compute_pmf(train_data->nrules, params) != 0)
		goto err;
	compute_cardinality(train_data->rules, train_data->nrules);
	if (compute_log_gammas(train_data->nsamples, params) != 0)
		goto err;
	if ((pred_model = calloc(1, sizeof(pred_model_t))) == NULL)
		goto err;

	/* Baseline: the ruleset holding only the default rule (id 0). */
	default_rule = 0;
	if (ruleset_init(1, train_data->nsamples, &default_rule,
	    train_data->rules, &rs) != 0)
		goto err;
	max_pos = compute_log_posterior(rs, train_data->rules,
	    train_data->nrules, train_data->labels, params, 1, -1, &null_bound);
	if (permute_rules(train_data->nrules) != 0)
		goto err;

	for (chain = 0; chain < params->nchain; chain++) {
		rs_temp = run_mcmc(params->iters, train_data->nsamples,
		    train_data->nrules, train_data->rules, train_data->labels,
		    params, max_pos);
		/*
		 * run_mcmc() returns NULL on error or when it gives up finding
		 * a starting ruleset; such a chain contributes nothing.
		 */
		if (rs_temp == NULL)
			continue;
		pos_temp = compute_log_posterior(rs_temp, train_data->rules,
		    train_data->nrules, train_data->labels, params, 1, -1,
		    &null_bound);
		if (pos_temp >= max_pos) {
			ruleset_destroy(rs);
			rs = rs_temp;
			max_pos = pos_temp;
		} else {
			ruleset_destroy(rs_temp);
		}
	}

	pred_model->theta = get_theta(rs, train_data->rules,
	    train_data->labels, params);
	pred_model->rs = rs;
	rs = NULL;			/* ownership moved into pred_model */
	goto cleanup;

err:
	free(pred_model);
	pred_model = NULL;		/* never hand a freed model to the caller */

cleanup:
	/* Free the globals and reset them so a later call cannot double-free. */
	free(log_lambda_pmf);
	log_lambda_pmf = NULL;
	free(log_eta_pmf);
	log_eta_pmf = NULL;
	free(rule_permutation);
	rule_permutation = NULL;
	free(log_gammas);
	log_gammas = NULL;
	if (rs != NULL)
		ruleset_destroy(rs);
	return (pred_model);
}
/*
 * Run one MCMC chain and return the best ruleset it visits.
 *
 * Seeding: builds two-rule candidates, each made of the next rule from
 * the global rule_permutation followed by the default rule (id 0).  It
 * stops at the first candidate whose prefix bound reaches v_star.  After
 * nrules - 1 rejected candidates the chain is abandoned and NULL is
 * returned.  The seed is then refined with `iters` calls to propose()
 * using mcmc_accepts, tracking the best log posterior seen.  Resets the
 * global n_add / n_delete / n_swap counters.
 *
 * Returns the rebuilt best ruleset (caller owns it), or NULL on failure
 * or abandonment.
 */
ruleset_t *
run_mcmc(int iters, int nsamples, int nrules, rule_t *rules, rule_t *labels,
    params_t *params, double v_star)
{
	ruleset_t *rs = NULL;
	double jump_prob, log_post_rs = 0.0;
	double max_log_posterior, prefix_bound;
	int *rs_idarray = NULL, len, nsuccessful_rej = 0;
	int i, rarray[2], count = 0;

	/*
	 * Only printed before the seeding loop; the do/while below always
	 * runs at least once, so the value does not gate the loop.
	 */
	prefix_bound = -1e10;
	n_add = n_delete = n_swap = 0;

	/* Initialize random number generator for some distributions. */
	init_gsl_rand_gen();

	if (debug > 10)
		printf("Prefix bound = %10f v_star = %f\n", prefix_bound, v_star);

	/*
	 * Construct rulesets with exactly 2 rules -- one drawn from
	 * the permutation and the default rule.
	 */
	rarray[1] = 0;
	do {
		// TODO Gather some stats on how much we loop in here.
		if (rs != NULL) {
			ruleset_destroy(rs);
			rs = NULL;
			count++;
			/* >= also terminates for degenerate nrules <= 1. */
			if (count >= nrules - 1)
				return (NULL);
		}
		rarray[0] = rule_permutation[permute_ndx++].ndx;
		if (permute_ndx >= nrules)
			permute_ndx = 1;
		if (ruleset_init(2, nsamples, rarray, rules, &rs) != 0)
			goto err;
		log_post_rs = compute_log_posterior(rs, rules, nrules, labels,
		    params, 0, 1, &prefix_bound);
		if (debug > 10) {
			printf("Initial random ruleset\n");
			ruleset_print(rs, rules, 1);
			printf("Prefix bound = %f v_star = %f\n", prefix_bound, v_star);
		}
	} while (prefix_bound < v_star);

	/*
	 * The initial ruleset is our best ruleset so far, so keep a
	 * list of the rules it contains.
	 */
	if (ruleset_backup(rs, &rs_idarray) != 0)
		goto err;
	max_log_posterior = log_post_rs;
	len = rs->n_rules;

	for (i = 0; i < iters; i++) {
		/*
		 * jump_prob is also passed as mcmc_accepts' extra argument.
		 * propose() consumes rs; on NULL it was freed.
		 */
		if ((rs = propose(rs, rules, labels, nrules, &jump_prob,
		    &log_post_rs, max_log_posterior, &nsuccessful_rej,
		    &jump_prob, params, mcmc_accepts)) == NULL)
			goto err;
		if (log_post_rs > max_log_posterior) {
			if (ruleset_backup(rs, &rs_idarray) != 0)
				goto err;
			max_log_posterior = log_post_rs;
			len = rs->n_rules;
		}
	}

	/* Regenerate the best rule list. */
	ruleset_destroy(rs);
	rs = NULL;			/* keep the err path from freeing it again */
	if (ruleset_init(len, nsamples, rs_idarray, rules, &rs) != 0)
		goto err;
	free(rs_idarray);

	if (debug) {
		printf("\n%s%d #add=%d #delete=%d #swap=%d):\n",
		    "The best rule list is (#reject=", nsuccessful_rej,
		    n_add, n_delete, n_swap);
		printf("max_log_posterior = %6f\n", max_log_posterior);
		printf("max_log_posterior = %6f\n",
		    compute_log_posterior(rs, rules, nrules, labels, params, 1,
		    -1, &prefix_bound));
		ruleset_print(rs, rules, (debug > 100));
	}
	return (rs);

err:
	if (rs != NULL)
		ruleset_destroy(rs);
	free(rs_idarray);
	return (NULL);
}
/*
 * Generate and evaluate one neighbouring ruleset; shared by simulated
 * annealing and MCMC.
 *
 * The function copies rs and lets ruleset_proposal() pick a move:
 * 'A' (add), 'D' (delete) or 'S' (swap).  It applies the move to the copy
 * and scores the copy with compute_log_posterior(), starting at the
 * first position the move affected.  accept_func then decides whether
 * the copy replaces rs.  *cnt is incremented whenever the copy's prefix
 * bound falls below max_log_post, regardless of the accept decision.
 *
 * Ownership: rs is always consumed.
 *  - accepted: *ret_log_post is updated, rs is destroyed, the copy is
 *    returned;
 *  - rejected: the copy is destroyed, rs is returned unchanged;
 *  - error:    both are destroyed and NULL is returned.
 */
ruleset_t *
propose(ruleset_t *rs, rule_t *rules, rule_t *labels, int nrules,
    double *jump_prob, double *ret_log_post, double max_log_post,
    int *cnt, double *extra, params_t *params,
    int (*accept_func)(double, double, double, double, double *))
{
	ruleset_t *candidate = NULL;
	char op;
	int pos_a, pos_b, first_changed;
	double cand_log_post, cand_bound;

	if (ruleset_copy(&candidate, rs) != 0 ||
	    ruleset_proposal(candidate, nrules, &pos_a, &pos_b, &op,
	    jump_prob) != 0)
		goto fail;

	if (debug > 10) {
		printf("Given ruleset: \n");
		ruleset_print(rs, rules, (debug > 100));
		printf("Operation %c(%d)(%d) produced proposal:\n",
		    op, pos_a, pos_b);
	}

	if (op == 'A') {
		/* Insert the rule whose id is pos_a at position pos_b. */
		if (ruleset_add(rules, nrules, &candidate, pos_a, pos_b) != 0)
			goto fail;
		first_changed = pos_b + 1;
		n_add++;
	} else if (op == 'D') {
		/* Remove the rule sitting at position pos_a. */
		first_changed = pos_a;
		ruleset_delete(rules, nrules, candidate, pos_a);
		n_delete++;
	} else if (op == 'S') {
		/* Exchange positions pos_a and pos_b. */
		ruleset_swap_any(candidate, pos_a, pos_b, rules);
		first_changed = 1 + ((pos_a > pos_b) ? pos_a : pos_b);
		n_swap++;
	} else {
		goto fail;
	}

	cand_log_post = compute_log_posterior(candidate, rules, nrules, labels,
	    params, 0, first_changed, &cand_bound);
	if (debug > 10) {
		ruleset_print(candidate, rules, (debug > 100));
		printf("With new log_posterior = %0.6f\n", cand_log_post);
	}

	if (cand_bound < max_log_post)
		(*cnt)++;

	if (!accept_func(cand_log_post, *ret_log_post, cand_bound,
	    max_log_post, extra)) {
		if (debug > 10)
			printf("Rejected\n");
		ruleset_destroy(candidate);
		return (rs);
	}

	if (debug > 10)
		printf("Accepted\n");
	*ret_log_post = cand_log_post;
	ruleset_destroy(rs);
	return (candidate);

fail:
	if (candidate != NULL)
		ruleset_destroy(candidate);
	ruleset_destroy(rs);
	return (NULL);
}