示例#1
0
/*
 * pc_compute_inside_2: reads a goal ID from call argument 1, sorts the
 * explanation graph rooted at that goal, runs the inside computation
 * (the log-exp scaled variant when log_scale is set, the unscaled variant
 * otherwise) and unifies call argument 2 with the root node's `inside'
 * field.
 *
 * Returns the result of bpx_unify(); a failing sort or inside computation
 * leaves early through RET_ON_ERR.
 */
int pc_compute_inside_2(void)
{
    int goal_id = bpx_get_integer(bpx_get_call_arg(1,2));
    EG_NODE_PTR root;

    initialize_egraph_index();
    alloc_sorted_egraph(1);
    RET_ON_ERR(sort_one_egraph(goal_id, 0, 1));

    if (verb_graph)
        print_egraph(0, PRINT_NEUTRAL);

    /* Root node whose inside value is reported back to the caller. */
    root = expl_graph[goal_id];

    if (log_scale) {
        RET_ON_ERR(compute_inside_scaling_log_exp());
    }
    else {
        RET_ON_ERR(compute_inside_scaling_none());
    }

    return bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(root->inside));
}
示例#2
0
 // Refreshes the conversion setup of speller `m`: calls m->reload_conv()
 // and then add_filters() on both the to-internal and from-internal
 // converters, using the speller's current configuration.  The first
 // failing step is returned through RET_ON_ERR; otherwise no_err.
 PosibErr<void> reload_filters(Speller * m)
 {
   RET_ON_ERR(m->reload_conv());

   // Add encoder and decoder filters, if any.
   RET_ON_ERR(m->to_internal_->add_filters(m->config(),
                                           true, false, false));
   RET_ON_ERR(m->from_internal_->add_filters(m->config(),
                                             false, false, true));

   return no_err;
 }
int pc_prism_vbvt_2(void)
{
    struct VBVT_Engine vbvt_eng;

    RET_ON_ERR(check_smooth_vb());
    RET_ON_ERR(run_vbvt(&vbvt_eng));

    return
        bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vbvt_eng.iterate)) &&
        bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vbvt_eng.free_energy));
}
示例#4
0
  // Rebuilds the filter chains of speller `m`: both the to-internal and
  // from-internal filter lists are cleared and then repopulated by
  // setup_filter() from the speller's configuration.  The first failing
  // setup_filter() call is returned through RET_ON_ERR; otherwise no_err.
  PosibErr<void> reload_filters(Speller * m)
  {
    // Start from empty filter lists on both sides.
    m->to_internal_->filter.clear();
    m->from_internal_->filter.clear();

    // Add encoder and decoder filters, if any.
    RET_ON_ERR(setup_filter(m->to_internal_->filter,
                            m->config(), true, false, false));
    RET_ON_ERR(setup_filter(m->from_internal_->filter,
                            m->config(), false, false, true));

    return no_err;
  }
示例#5
0
  // Creates a speller for the configuration `c0`.
  //
  // Steps: initialise gettext, resolve the word list with find_word_list()
  // (which yields the Config `c`), instantiate the speller class for that
  // config, call setup() on it and finally reload_filters().  Any failing
  // step is returned to the caller; on success the speller pointer is
  // released from its StackPtr and returned.
  PosibErr<Speller *> new_speller(Config * c0)
  {
    aspell_gettext_init();

    RET_ON_ERR_SET(find_word_list(c0), Config *, c);

    StackPtr<Speller> speller(get_speller_class(c));
    RET_ON_ERR(speller->setup(c));
    RET_ON_ERR(reload_filters(speller));

    return speller.release();
  }
示例#6
0
int pc_prism_vbem_2(void)
{
    struct VBEM_Engine vb_eng;

    RET_ON_ERR(check_smooth_vb());
    RET_ON_ERR(run_vbem(&vb_eng));
    release_num_sw_vals();

    return
        bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) &&
        bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy));
}
示例#7
0
    // Converts `size` bytes of `in` into `out`, reporting errors.
    // Does NOT pass the data through filters and does NOT use an internal
    // state.
    //
    // When a direct converter (conv_) is available it does all the work;
    // otherwise the input is decoded into the scratch buffer `buf` and
    // then encoded from there into `out`.  `orig` is forwarded unchanged
    // to every conversion call.
    PosibErr<void> convert_ec(const char * in, int size, CharVector & out, 
                              ConvertBuffer & buf, ParmStr orig) const
    {
      if (conv_) {
        RET_ON_ERR(conv_->convert_ec(in, size, out, orig));
        return no_err;
      }

      // Two-step path: decode into buf, then encode buf into out.
      buf.clear();
      RET_ON_ERR(decode_->decode_ec(in, size, buf, orig));
      RET_ON_ERR(encode_->encode_ec(buf.pbegin(), buf.pend(), out, orig));
      return no_err;
    }
int pc_prism_vt_4(void)
{
    struct VT_Engine vt_eng;

    RET_ON_ERR(check_smooth(&vt_eng.smooth));
    RET_ON_ERR(run_vt(&vt_eng));

    return
        bpx_unify(bpx_get_call_arg(1,4), bpx_build_integer(vt_eng.iterate   )) &&
        bpx_unify(bpx_get_call_arg(2,4), bpx_build_float  (vt_eng.lambda    )) &&
        bpx_unify(bpx_get_call_arg(3,4), bpx_build_float  (vt_eng.likelihood)) &&
        bpx_unify(bpx_get_call_arg(4,4), bpx_build_integer(vt_eng.smooth    )) ;
}
示例#9
0
  // Creates a speller for the configuration `c0`.
  //
  // Steps: initialise gettext, resolve the word list with find_word_list()
  // (which yields the Config `c`), instantiate the speller class for that
  // config, call setup() on it and add filters to both converters.  Any
  // failing step is returned to the caller; on success the speller pointer
  // is released from its StackPtr and returned.
  PosibErr<Speller *> new_speller(Config * c0)
  {
    aspell_gettext_init();

    RET_ON_ERR_SET(find_word_list(c0), Config *, c);

    StackPtr<Speller> speller(get_speller_class(c));
    RET_ON_ERR(speller->setup(c));

    // Add encoder and decoder filters, if any.
    RET_ON_ERR(speller->to_internal_->add_filters(speller->config(),
                                                  true, false, false));
    RET_ON_ERR(speller->from_internal_->add_filters(speller->config(),
                                                    false, false, true));

    return speller.release();
  }
示例#10
0
int pc_prism_em_6(void)
{
    struct EM_Engine em_eng;

    RET_ON_ERR(check_smooth(&em_eng.smooth));
    RET_ON_ERR(run_em(&em_eng));
    release_num_sw_vals();

    return
        bpx_unify(bpx_get_call_arg(1,6), bpx_build_integer(em_eng.iterate   )) &&
        bpx_unify(bpx_get_call_arg(2,6), bpx_build_float  (em_eng.lambda    )) &&
        bpx_unify(bpx_get_call_arg(3,6), bpx_build_float  (em_eng.likelihood)) &&
        bpx_unify(bpx_get_call_arg(4,6), bpx_build_float  (em_eng.bic       )) &&
        bpx_unify(bpx_get_call_arg(5,6), bpx_build_float  (em_eng.cs        )) &&
        bpx_unify(bpx_get_call_arg(6,6), bpx_build_integer(em_eng.smooth    )) ;
}
示例#11
0
文件: feature.c 项目: edechter/PRISM
/*
 * pc_crf_prepare_4: prepares the explanation graphs for CRF processing.
 *
 * Call arguments: (1) the fact list term, (2) the size passed to
 * alloc_sorted_egraph(), (3) the value stored in num_goals, (4) the value
 * stored in failure_root_index (-1 meaning no failure root).
 *
 * Sets the globals num_goals, failure_root_index, failure_observed and,
 * when a failure root is given, failure_subgoal_id.  Then sorts the CRF
 * explanation graphs and allocates the switch tables.
 *
 * Returns BP_TRUE on success.  A missing subgoal ID for `failure' is
 * reported as an internal error; a failing sort leaves through RET_ON_ERR.
 */
int pc_crf_prepare_4(void) {
	TERM  fact_list_term = bpx_get_call_arg(1,4);
	int   graph_size     = bpx_get_integer(bpx_get_call_arg(2,4));

	num_goals          = bpx_get_integer(bpx_get_call_arg(3,4));
	failure_root_index = bpx_get_integer(bpx_get_call_arg(4,4));
	failure_observed   = (failure_root_index != -1);

	if (failure_root_index != -1) {
		failure_subgoal_id = prism_goal_id_get(failure_atom);
		if (failure_subgoal_id == -1) {
			emit_internal_error("no subgoal ID allocated to `failure'");
			RET_INTERNAL_ERR;
		}
	}

	initialize_egraph_index();
	alloc_sorted_egraph(graph_size);
	RET_ON_ERR(sort_crf_egraphs(fact_list_term));

#ifndef MPI
	if (verb_graph)
		print_egraph(0, PRINT_NEUTRAL);
#endif /* !(MPI) */

	alloc_occ_switches();
	alloc_num_sw_vals();

	return BP_TRUE;
}
示例#12
0
/*
 * pc_compute_n_viterbi_rerank_4: reads n, l and a goal ID from call
 * arguments 1..3, sorts the explanation graph rooted at that goal, and
 * unifies call argument 4 with the list produced by
 * get_n_most_likely_path_rerank().
 *
 * Note: parameters are always refreshed in advance by $pc_export_sw_info/1,
 *       so it causes no problem to overwrite them temporarily
 */
int pc_compute_n_viterbi_rerank_4(void)
{
    TERM result_list;
    int n       = bpx_get_integer(bpx_get_call_arg(1,4));
    int l       = bpx_get_integer(bpx_get_call_arg(2,4));
    int goal_id = bpx_get_integer(bpx_get_call_arg(3,4));

    initialize_egraph_index();
    alloc_sorted_egraph(1);
    RET_ON_ERR(sort_one_egraph(goal_id,0,1));
    if (verb_graph) {
        print_egraph(0,PRINT_NEUTRAL);
    }

    /* The switch table lives only for the duration of this call. */
    alloc_occ_switches();
    transfer_hyperparams_prolog();
    get_param_means();

    compute_n_max(l);
    get_n_most_likely_path_rerank(n,l,goal_id,&result_list);

    release_occ_switches();

    return bpx_unify(bpx_get_call_arg(4,4),result_list);
}
示例#13
0
/*
 * Depth-first visit of the explanation-graph node `node_id'.
 *
 * The node's `visited' flag is 2 while its children are being explored and
 * becomes 1 once they are done.  Every child whose flag is still 0 is
 * visited recursively and then appended to sorted_expl_graph, so children
 * are stored before the nodes that reference them.  The `shared' counter
 * of a child is incremented once per reference found.
 *
 * Returns BP_TRUE on success.  When error_on_cycle is set, reaching a child
 * that is still in progress (flag 2) raises err_cycle_detected.
 */
static int topological_sort(int node_id)
{
    EG_PATH_PTR path;
    EG_NODE_PTR child;
    int i;

    expl_graph[node_id]->visited = 2;
    UPDATE_MIN_MAX_NODE_NOS(node_id);

    for (path = expl_graph[node_id]->path_ptr; path != NULL; path = path->next) {
        EG_NODE_PTR *kids = path->children;
        int n_kids = path->children_len;

        for (i = 0; i < n_kids; i++) {
            child = kids[i];

            if (child->visited == 2 && error_on_cycle) {
                RET_ERR(err_cycle_detected);
            }

            if (child->visited == 0) {
                RET_ON_ERR(topological_sort(child->id));
                expand_sorted_egraph(index_to_sort + 1);
                sorted_expl_graph[index_to_sort] = child;
                index_to_sort++;
            }

            child->shared++;
        }
    }

    expl_graph[node_id]->visited = 1;
    return BP_TRUE;
}
示例#14
0
int sort_one_egraph(int root_id, int root_index, int count)
{
    roots[root_index] = (ROOT)MALLOC(sizeof(struct ObservedFactNode));
    roots[root_index]->id = root_id;
    roots[root_index]->count = count;

    if (expl_graph[root_id]->visited == 1) {
        /*
         * This top-goal is also a sub-goal of another top-goal.  This
         * should occur only when INIT_VISITED_FLAGS is suppressed
         * (i.e. we have more than one observed goal in learning).
         */
        if (suppress_init_flags) return BP_TRUE;
    }

    if (expl_graph[root_id]->visited != 0) RET_INTERNAL_ERR;

    RET_ON_ERR(topological_sort(root_id));

    expand_sorted_egraph(index_to_sort + 1);
    sorted_expl_graph[index_to_sort] = expl_graph[root_id];

    index_to_sort++;
    sorted_egraph_size = index_to_sort;

    /* initialize flags after use */
    if (!suppress_init_flags) INIT_VISITED_FLAGS;

    return BP_TRUE;
}
示例#15
0
/*
 * pc_compute_viterbi_5: reads a goal ID from call argument 1, sorts the
 * explanation graph rooted at it, runs compute_max() and unifies call
 * arguments 2..5 with the goal path, the sub-path goals, the sub-path
 * switches and the Viterbi probability returned by get_most_likely_path().
 *
 * Returns 1 when all four unifications succeed and 0 otherwise.
 *
 * [Note] node copying is not required here even in computation without
 * inter-goal sharing, but we need to declare it explicitly.
 */
int pc_compute_viterbi_5(void)
{
    TERM goal_path, subpath_goals, subpath_sws;
    double best_prob;
    int goal_id = bpx_get_integer(bpx_get_call_arg(1,5));

    initialize_egraph_index();
    alloc_sorted_egraph(1);
    RET_ON_ERR(sort_one_egraph(goal_id,0,1));
    if (verb_graph) {
        print_egraph(0,PRINT_NEUTRAL);
    }

    compute_max();

    if (debug_level) {
        print_egraph(1,PRINT_VITERBI);
    }

    get_most_likely_path(goal_id,&goal_path,&subpath_goals,
                         &subpath_sws,&best_prob);

    if (!bpx_unify(bpx_get_call_arg(2,5), goal_path)) {
        return 0;
    }
    if (!bpx_unify(bpx_get_call_arg(3,5), subpath_goals)) {
        return 0;
    }
    if (!bpx_unify(bpx_get_call_arg(4,5), subpath_sws)) {
        return 0;
    }

    return bpx_unify(bpx_get_call_arg(5,5), bpx_build_float(best_prob)) ? 1 : 0;
}
示例#16
0
文件: feature.c 项目: edechter/PRISM
int pc_prism_grd_2(void) {
	struct CRF_Engine crf_eng;

	RET_ON_ERR(run_grd(&crf_eng));

	return
	    bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(crf_eng.iterate)) &&
	    bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(crf_eng.likelihood));
}
示例#17
0
 // Converts `str` through `conv` when a converter is set, returning a
 // pointer into the internal buffer `buf` (which is cleared first).
 // Without a converter the input string is returned as is.  Conversion
 // errors are returned through RET_ON_ERR.
 PosibErr<const char *> operator() (ParmStr str)
 {
   if (conv) {
     buf.clear();
     RET_ON_ERR(conv->convert_ec(str, -1, buf, buf0, str));
     return buf.mstr();
   }

   // No converter configured: pass the input through unchanged.
   return str.str();
 }
示例#18
0
 // Converts the first `sz` bytes of `str` through `conv` when a converter
 // is set, returning a pointer into the internal buffer `buf` (which is
 // cleared first).  Without a converter `str` itself is returned.
 // Conversion errors are returned through RET_ON_ERR.
 PosibErr<char *> operator() (char * str, size_t sz)
 {
   if (conv) {
     buf.clear();
     RET_ON_ERR(conv->convert_ec(str, sz, buf, buf0, str));
     return buf.mstr();
   }

   // No converter configured: pass the input through unchanged.
   return str;
 }
示例#19
0
 // Creates a read-only dictionary from the entries in `els`, using the
 // language obtained from new_language(config).
 //
 // The language is held in a CachePtr for the duration of the call and has
 // its defaults applied to `config` before create() is invoked.  An error
 // from new_language() or create() is returned to the caller.
 PosibErr<void> create_default_readonly_dict(StringEnumeration * els,
                                             Config & config)
 {
   PosibErr<Language *> res = new_language(config);
   if (res.has_err())
     return res;

   CachePtr<Language> lang;
   lang.reset(res.data);
   lang->set_lang_defaults(config);

   RET_ON_ERR(create(els, *lang, config));
   return no_err;
 }
示例#20
0
  // Applies the item list encoded in `s` to the container `d`.
  //
  // Each item produced by ItemizeTokenizer carries an action character:
  //   '+'  add the named item,
  //   '-'  remove the named item,
  //   '!'  clear the container.
  // Any other action aborts the program.  Iteration stops at the first
  // item whose name is null.  A failing container operation is returned
  // through RET_ON_ERR; otherwise no_err.
  PosibErr<void> itemize (ParmString s, MutableContainer & d) {
    ItemizeTokenizer tokens(s);
    for (ItemizeItem item = tokens.next(); item.name != 0; item = tokens.next()) {
      switch (item.action) {
      case '+':
        RET_ON_ERR(d.add(item.name));
        break;
      case '-':
        RET_ON_ERR(d.remove(item.name));
        break;
      case '!':
        RET_ON_ERR(d.clear());
        break;
      default:
        abort();
      }
    }
    return no_err;
  }
示例#21
0
/*
 * pc_alloc_sort_egraph_1: sorts the explanation graph rooted at the goal
 * ID given in call argument 1, such that no node sorted_expl_graph[i]
 * calls node sorted_expl_graph[j] if i < j.
 *
 * This function is used only for probf/1-2, so scaling does not have to
 * be considered here.
 *
 * Returns BP_TRUE on success; a failing sort leaves through RET_ON_ERR.
 */
int pc_alloc_sort_egraph_1(void)
{
    int root_id = bpx_get_integer(bpx_get_call_arg(1,1));

    /* Start filling the sorted graph from its first slot. */
    index_to_sort = 0;
    alloc_sorted_egraph(1);

    RET_ON_ERR(sort_one_egraph(root_id,0,1));
    return BP_TRUE;
}
示例#22
0
文件: feature.c 项目: edechter/PRISM
/*
 * pc_compute_fprobf_1: runs the feature computation and, unless the mode
 * in call argument 1 equals 1, also the expectation and outside
 * computations.  The log-exp scaled variants are used when log_scale is
 * set, the unscaled ones otherwise.
 *
 * Resets failure_root_index to -1 and re-initialises the weights first.
 * Returns BP_TRUE on success; any failing step leaves through RET_ON_ERR.
 */
int pc_compute_fprobf_1(void) {
	int prmode = bpx_get_integer(bpx_get_call_arg(1,1));

	failure_root_index = -1;

	initialize_weights();

	/* [31 Mar 2008, by yuizumi]
	 * compute_outside_scaling_*() needs to be called because
	 * eg_ptr->outside computed by compute_expectation_scaling_*()
	 * is different from the outside probability.
	 */
	if (log_scale) {
		RET_ON_ERR(compute_feature_scaling_log_exp());
		if (prmode == 1) {
			return BP_TRUE;
		}
		RET_ON_ERR(compute_expectation_scaling_log_exp());
		RET_ON_ERR(compute_outside_scaling_log_exp());
		return BP_TRUE;
	}

	RET_ON_ERR(compute_feature_scaling_none());
	if (prmode == 1) {
		return BP_TRUE;
	}
	RET_ON_ERR(compute_expectation_scaling_none());
	RET_ON_ERR(compute_outside_scaling_none());
	return BP_TRUE;
}
示例#23
0
文件: feature.c 项目: edechter/PRISM
/*
 * Evaluates the negated likelihood after stepping every non-fixed switch
 * instance by `alpha' along its L-BFGS direction:
 *     inside = current_inside - alpha * LBFGS_q
 *
 * A switch chain is skipped entirely when the `fixed' field of its first
 * instance is positive; inside a chain only instances with fixed == 0 are
 * updated.  The engine's compute_feature(), compute_crf_probs() and
 * compute_likelihood() callbacks are then run, in that order.
 *
 * NOTE(review): RET_ON_ERR is used in a function returning double, so a
 * failing compute_feature() presumably surfaces as an error code converted
 * to double -- confirm that callers can tell this apart from a likelihood.
 */
static double compute_phi_alpha_LBFGS(CRF_ENG_PTR crf_ptr, double alpha) {
	int idx;
	SW_INS_PTR ins;

	for (idx = 0; idx < occ_switch_tab_size; idx++) {
		ins = occ_switches[idx];
		if (ins->fixed > 0) {
			continue;
		}
		for (; ins != NULL; ins = ins->next) {
			if (ins->fixed == 0) {
				ins->inside = ins->current_inside - ( alpha * ins->LBFGS_q );
			}
		}
	}

	RET_ON_ERR(crf_ptr->compute_feature());
	crf_ptr->compute_crf_probs();

	return crf_ptr->compute_likelihood() * (-1);
}
示例#24
0
/*
 * pc_add_egraph_path_3: adds one path to the explanation-graph node whose
 * ID is given in call argument 1.  Call arguments 2 and 3 carry the
 * children term and the switches term of the path.
 *
 * Returns BP_TRUE on success.  A non-integer node ID raises
 * err_invalid_goal_id; errors from add_egraph_path() leave through
 * RET_ON_ERR.
 */
int pc_add_egraph_path_3(void)
{
    TERM id_term, children_term, sws_term;
    int node_id;

    /* the children and switches terms must be in the table area */
    id_term       = bpx_get_call_arg(1,3);
    children_term = bpx_get_call_arg(2,3);
    sws_term      = bpx_get_call_arg(3,3);

    if (!bpx_is_integer(id_term)) {
        RET_ERR(err_invalid_goal_id);
    }
    node_id = bpx_get_integer(id_term);

    XDEREF(children_term);
    XDEREF(sws_term);

    RET_ON_ERR(add_egraph_path(node_id,children_term,sws_term));

    return BP_TRUE;
}
示例#25
0
/*
 * pc_compute_n_viterbi_3: reads n and a goal ID from call arguments 1 and
 * 2, sorts the explanation graph rooted at that goal, runs
 * compute_n_max(n) and unifies call argument 3 with the list produced by
 * get_n_most_likely_path().
 *
 * Returns the result of bpx_unify(); a failing sort leaves through
 * RET_ON_ERR.
 */
int pc_compute_n_viterbi_3(void)
{
    TERM result_list;
    int n       = bpx_get_integer(bpx_get_call_arg(1,3));
    int goal_id = bpx_get_integer(bpx_get_call_arg(2,3));

    initialize_egraph_index();
    alloc_sorted_egraph(1);
    RET_ON_ERR(sort_one_egraph(goal_id,0,1));
    if (verb_graph) {
        print_egraph(0,PRINT_NEUTRAL);
    }

    compute_n_max(n);

    if (debug_level) {
        print_egraph(1,PRINT_VITERBI);
    }

    get_n_most_likely_path(n,goal_id,&result_list);

    return bpx_unify(bpx_get_call_arg(3,3),result_list);
}
示例#26
0
/*
 * pc_compute_probf_1: runs the probability computations selected by the
 * mode in call argument 1 on the already sorted explanation graph.
 *
 *   mode 3      : compute_max() only.
 *   mode 1      : inside computation only.
 *   other modes : inside, expectation and outside computations.
 *
 * The log-exp scaled variants are used when log_scale is set, the unscaled
 * ones otherwise.  failure_root_index is reset to -1 for every mode except
 * 3.
 *
 * Returns BP_TRUE on success; any failing computation leaves through
 * RET_ON_ERR.
 */
int pc_compute_probf_1(void)
{
    int prmode;

    prmode = bpx_get_integer(bpx_get_call_arg(1,1));

    if (prmode == 3) {
        compute_max();
        return BP_TRUE;
    }

    failure_root_index = -1;

    /* [31 Mar 2008, by yuizumi]
     * compute_outside_scaling_*() is needed to be called because
     * eg_ptr->outside computed by compute_expectation_scaling_*()
     * is different from the outside probability.
     */
    if (log_scale) {
        RET_ON_ERR(compute_inside_scaling_log_exp());
        if (prmode != 1) {
            RET_ON_ERR(compute_expectation_scaling_log_exp());
            RET_ON_ERR(compute_outside_scaling_log_exp());
        }
    }
    else {
        RET_ON_ERR(compute_inside_scaling_none());
        if (prmode != 1) {
            RET_ON_ERR(compute_expectation_scaling_none());
            RET_ON_ERR(compute_outside_scaling_none());
        }
    }

    return BP_TRUE;
}
示例#27
0
  // Builds a new Config from the settings of `c` and resolves which master
  // word list it should use.
  //
  // If the config already has a "master" entry, that value is used as the
  // dictionary name.  Otherwise every entry of the dictionary info list is
  // scored on language code, variety, module and size, and the best match
  // is written back into the config ("lang", "language-tag", "master",
  // "master-flags", "module", "jargon", "add-variety", "size").  In both
  // cases the dictionary name is finally looked up in the dictionary
  // aliases, and a hit replaces "master".
  //
  // Returns the newly allocated Config on success (presumably owned by the
  // caller -- TODO confirm).  On failure the Config is deleted and an error
  // is returned: the error from read_in_settings() or
  // get_dict_file_name(), or no_wordlist_for_lang when nothing matches.
  PosibErr<Config *> find_word_list(Config * c) 
  {
    Config * config = new_config();
    {
      // Do not leak the freshly allocated config when the settings cannot
      // be read; the other error paths below delete it as well.
      PosibErrBase read_err = config->read_in_settings(c);
      if (read_err.has_err()) {
        delete config;
        return read_err;
      }
    }
    String dict_name;

    if (config->have("master")) {
      dict_name = config->retrieve("master");

    } else {

      ////////////////////////////////////////////////////////////////////
      //
      // Give first preference to an exact match for the language-country
      // code, then give preference to those in the alternate code list
      // in the order they are presented, then if there is no match
      // look for one for just language.  If that fails give up.
      // Once the best matching code is found, try to find a matching
      // variety if one exists, other wise look for one with no variety.
      //

      BetterList b_code;
      //BetterList b_jargon;
      BetterVariety b_variety;
      BetterList b_module;
      BetterSize b_size;
      Better * better[4] = {&b_code,&b_variety,&b_module,&b_size};
      const DictInfo * best = 0;

      //
      // retrieve and normalize code
      //
      const char * p;
      String code;
      PosibErr<String> str = config->retrieve("lang");
      p = str.data.c_str();
      while (asc_isalpha(*p))
        code += asc_tolower(*p++);
      String lang = code;
      bool have_country = false;
      if (*p == '-' || *p == '_') {
        ++p;
        have_country = true;
        code += '_'; 
        while (asc_isalpha(*p))
          code += asc_toupper(*p++);
      }
  
      //
      // Retrieve acceptable code search orders
      //
      String lang_country_list;
      if (have_country) {
        lang_country_list = code;
        lang_country_list += ' ';
      }
      String lang_only_list = lang;
      lang_only_list += ' ';

      // read retrieve lang_country_list and lang_only_list from file(s)
      // FIXME: Write Me

      //
      split_string_list(b_code.list, lang_country_list);
      split_string_list(b_code.list, lang_only_list);
      b_code.init();

      //
      // Retrieve Variety
      // 
      config->retrieve_list("variety", &b_variety.list);
      if (b_variety.list.empty() && config->have("jargon")) 
        b_variety.list.add(config->retrieve("jargon"));
      b_variety.init();
      str.data.clear();

      //
      // Retrieve module list
      //
      if (config->have("module"))
        b_module.list.add(config->retrieve("module"));
      else if (config->have("module-search-order"))
        config->retrieve_list("module-search-order", &b_module.list);
      {
        StackPtr<ModuleInfoEnumeration> els(get_module_info_list(config)->elements());
        const ModuleInfo * entry;
        while ( (entry = els->next()) != 0)
          b_module.list.add(entry->name);
      }
      b_module.init();

      //
      // Retrieve size
      //
      str = config->retrieve("size");
      p = str.data.c_str();
      if (p[0] == '+' || p[0] == '-' || p[0] == '<' || p[0] == '>') {
        b_size.req_type = p[0];
        ++p;
      } else {
        b_size.req_type = '+';
      }
      // The size must be exactly two digits after the optional prefix.
      if (!asc_isdigit(p[0]) || !asc_isdigit(p[1]) || p[2] != '\0')
        abort(); //FIXME: create an error condition here
      b_size.requested = atoi(p);
      b_size.init();

      //
      // Walk every known dictionary and keep the best match
      //

      const DictInfoList * dlist = get_dict_info_list(config);
      DictInfoEnumeration * dels = dlist->elements();
      const DictInfo * entry;

      while ( (entry = dels->next()) != 0) {

        b_code  .cur = entry->code;
        b_module.cur = entry->module->name;

        b_variety.cur = entry->variety;
    
        b_size.cur_str = entry->size_str;
        b_size.cur     = entry->size;

        //
        // check to see if we got a better match than the current
        // best_match if any
        //

        IsBetter is_better = SameMatch;
        for (int i = 0; i != 4; ++i)
          is_better = better[i]->better_match(is_better);
    
        if (is_better == BetterMatch) {
          for (int i = 0; i != 4; ++i)
            better[i]->set_best_from_cur();
          best = entry;
        }
      }

      delete dels;

      //
      // set config to best match
      //
      if (best != 0) {
        String main_wl,flags;
        PosibErrBase ret = get_dict_file_name(best, main_wl, flags);
        if (ret.has_err()) {
          delete config;
          return ret;
        }
        dict_name = best->name;
        config->replace("lang", b_code.best);
        config->replace("language-tag", b_code.best);
        config->replace("master", main_wl.c_str());
        config->replace("master-flags", flags.c_str());
        config->replace("module", b_module.best);
        config->replace("jargon", b_variety.best);
        config->replace("clear-variety", "");
        // Add every '-' separated component as a variety.  The separator
        // has to be stepped over explicitly: strcspn() returns 0 when the
        // scan starts on a '-', so without the skip the loop would never
        // advance past the first separator.
        // NOTE(review): the components are taken from b_module.best, not
        // from the selected variety -- confirm this is the intended source.
        const char * seg = b_module.best;
        while (*seg != '\0') {
          unsigned seg_len = strcspn(seg, "-");
          if (seg_len != 0)
            config->replace("add-variety", String(seg, seg_len));
          seg += seg_len;
          if (*seg == '-')
            ++seg;
        }
        config->replace("size", b_size.best_str);
      } else {
        delete config;
        return make_err(no_wordlist_for_lang, code);
      }
    }

    // A dictionary alias, when present, overrides the master word list.
    const StringMap * dict_aliases = get_dict_aliases(config);
    const char * val = dict_aliases->lookup(dict_name);
    if (val) config->replace("master", val);
    return config;
  }
示例#28
0
 // Sets up the owned conversion object for converting `from` -> `to` with
 // the given normalisation, then stores its `ptr` member in `conv`.
 // An error from conv_obj.setup() is returned through RET_ON_ERR and
 // leaves `conv` untouched.
 PosibErr<void> setup(const Config & c, ParmStr from, ParmStr to,
                      Normalize norm)
 {
   RET_ON_ERR(conv_obj.setup(c, from, to, norm));

   conv = conv_obj.ptr;
   return no_err;
 }
int pc_compute_hindsight_4(void)
{
    TERM p_subgoal,p_hindsight_pairs,t,t1,p_pair;
    int goal_id,is_cond,j;

    goal_id   = bpx_get_integer(bpx_get_call_arg(1,4));
    p_subgoal = bpx_get_call_arg(2,4);
    is_cond   = bpx_get_integer(bpx_get_call_arg(3,4));

    initialize_egraph_index();
    alloc_sorted_egraph(1);
    RET_ON_ERR(sort_one_egraph(goal_id,0,1));
    if (verb_graph) print_egraph(0,PRINT_NEUTRAL);

    alloc_hindsight_goals();

    if (log_scale) {
        RET_ON_ERR(compute_inside_scaling_log_exp());
        RET_ON_ERR(compute_outside_scaling_log_exp());
        RET_ON_ERR(get_hindsight_goals_scaling_log_exp(p_subgoal,is_cond));
    }
    else {
        RET_ON_ERR(compute_inside_scaling_none());
        RET_ON_ERR(compute_outside_scaling_none());
        RET_ON_ERR(get_hindsight_goals_scaling_none(p_subgoal,is_cond));
    }

    if (hindsight_goal_size > 0) {
        /* Build the list of pairs of a subgoal and its hindsight probability */
        p_hindsight_pairs = bpx_build_list();
        t = p_hindsight_pairs;

        for (j = 0; j < hindsight_goal_size; j++) {
            p_pair = bpx_build_list();

            t1 = p_pair;
            bpx_unify(bpx_get_car(t1),
                      bpx_build_integer(hindsight_goals[j]));
            bpx_unify(bpx_get_cdr(t1),bpx_build_list());

            t1 = bpx_get_cdr(t1);
            bpx_unify(bpx_get_car(t1),bpx_build_float(hindsight_probs[j]));
            bpx_unify(bpx_get_cdr(t1),bpx_build_nil());

            bpx_unify(bpx_get_car(t),p_pair);

            if (j == hindsight_goal_size - 1) {
                bpx_unify(bpx_get_cdr(t),bpx_build_nil());
            }
            else {
                bpx_unify(bpx_get_cdr(t),bpx_build_list());
                t = bpx_get_cdr(t);
            }
        }
    }
    else {
        p_hindsight_pairs = bpx_build_nil();
    }

    FREE(hindsight_goals);
    FREE(hindsight_probs);

    return bpx_unify(bpx_get_call_arg(4,4),p_hindsight_pairs);
}
示例#30
0
文件: feature.c 项目: edechter/PRISM
/*
 * Main loop of gradient-based CRF learning.
 *
 * For each of num_restart restarts the loop repeatedly calls the engine's
 * compute_feature(), compute_crf_probs(), compute_likelihood() and
 * compute_gradient() callbacks, picks a step size according to
 * crf_learning_rate and applies it with update_lambdas(), until the
 * likelihood change drops to prism_epsilon or REACHED_MAX_ITERATE() holds.
 * The best likelihood over all restarts and its iteration count are stored
 * in crf_ptr->likelihood and crf_ptr->iterate.
 *
 * Returns BP_TRUE on success.  Errors are raised for Ctrl-C, a non-finite
 * likelihood, a positive log likelihood and a line-search step below EPS;
 * failing callbacks leave through RET_ON_ERR.
 */
static int run_grd(CRF_ENG_PTR crf_ptr) {
	int r,iterate,old_valid,converged,conv_time,saved = 0;
	double likelihood,old_likelihood = 0.0;
	double crf_max_iterate = 0.0;
	double tmp_epsilon,alpha0,gf_sd,old_gf_sd = 0.0;

	config_crf(crf_ptr);

	initialize_weights();

	/* crf_learn_mode == 1 selects L-BFGS */
	if (crf_learn_mode == 1) {
		initialize_LBFGS();
		printf("L-BFGS mode\n");
	}

	/* report the step-size strategy in use */
	if (crf_learning_rate==1) {
		printf("learning rate:annealing\n");
	} else if (crf_learning_rate==2) {
		printf("learning rate:backtrack\n");
	} else if (crf_learning_rate==3) {
		printf("learning rate:golden section\n");
	}

	/* NOTE(review): crf_max_iterate is assigned here but not read by any
	 * code visible in this function; REACHED_MAX_ITERATE() may or may not
	 * use it -- confirm against the macro definition. */
	if (max_iterate == -1) {
		crf_max_iterate = DEFAULT_MAX_ITERATE;
	} else if (max_iterate >= +1) {
		crf_max_iterate = max_iterate;
	}

	for (r = 0; r < num_restart; r++) {
		SHOW_PROGRESS_HEAD("#crf-iters", r);

		/* reset the per-restart state */
		initialize_crf_count();
		initialize_lambdas();
		initialize_visited_flags();

		old_valid = 0;
		iterate = 0;
		tmp_epsilon = crf_epsilon;

		LBFGS_index = 0;
		/* NOTE(review): conv_time is reset here but not read by any code
		 * visible in this function. */
		conv_time = 0;

		while (1) {
			if (CTRLC_PRESSED) {
				SHOW_PROGRESS_INTR();
				RET_ERR(err_ctrl_c_pressed);
			}

			RET_ON_ERR(crf_ptr->compute_feature());

			crf_ptr->compute_crf_probs();

			likelihood = crf_ptr->compute_likelihood();

			if (verb_em) {
				prism_printf("Iteration #%d:\tlog_likelihood=%.9f\n", iterate, likelihood);
			}

			if (debug_level) {
				prism_printf("After I-step[%d]:\n", iterate);
				prism_printf("likelihood = %.9f\n", likelihood);
				print_egraph(debug_level, PRINT_EM);
			}

			/* reject NaN and infinite likelihoods */
			if (!isfinite(likelihood)) {
				emit_internal_error("invalid log likelihood: %s (at iteration #%d)",
				                    isnan(likelihood) ? "NaN" : "infinity", iterate);
				RET_ERR(ierr_invalid_likelihood);
			}
			/*        if (old_valid && old_likelihood - likelihood > prism_epsilon) {
					  emit_error("log likelihood decreased [old: %.9f, new: %.9f] (at iteration #%d)",
					  old_likelihood, likelihood, iterate);
					  RET_ERR(err_invalid_likelihood);
					  }*/
			/* a log likelihood must not be positive */
			if (likelihood > 0.0) {
				emit_error("log likelihood greater than zero [value: %.9f] (at iteration #%d)",
				           likelihood, iterate);
				RET_ERR(err_invalid_likelihood);
			}

			if (crf_learn_mode == 1 && iterate > 0) restore_old_gradient();

			RET_ON_ERR(crf_ptr->compute_gradient());

			/* L-BFGS bookkeeping: the first iteration only initialises q,
			 * later iterations update y/rho and the Hessian estimate */
			if (crf_learn_mode == 1 && iterate > 0) {
				compute_LBFGS_y_rho();
				compute_hessian(iterate);
			} else if (crf_learn_mode == 1 && iterate == 0) {
				initialize_LBFGS_q();
			}

			/* convergence needs a previous likelihood to compare against */
			converged = (old_valid && fabs(likelihood - old_likelihood) <= prism_epsilon);

			if (converged || REACHED_MAX_ITERATE(iterate)) {
				break;
			}

			old_likelihood = likelihood;
			old_valid = 1;

			if (debug_level) {
				prism_printf("After O-step[%d]:\n", iterate);
				print_egraph(debug_level, PRINT_EM);
			}

			SHOW_PROGRESS(iterate);

			/* choose the step size tmp_epsilon for this iteration; for any
			 * other crf_learning_rate value it keeps its previous value */
			if (crf_learning_rate == 1) { // annealing
				tmp_epsilon = (annealing_weight / (annealing_weight + iterate)) * crf_epsilon;
			} else if (crf_learning_rate == 2) { // line-search(backtrack)
				if (crf_learn_mode == 1) {
					gf_sd = compute_gf_sd_LBFGS();
				} else {
					gf_sd = compute_gf_sd();
				}
				if (iterate==0) {
					alpha0 = 1;
				} else {
					/* NOTE(review): divides by gf_sd without a zero check */
					alpha0 = tmp_epsilon * old_gf_sd / gf_sd;
				}
				if (crf_learn_mode == 1) {
					tmp_epsilon = line_search_LBFGS(crf_ptr,alpha0,crf_ls_rho,crf_ls_c1,likelihood,gf_sd);
				} else {
					tmp_epsilon = line_search(crf_ptr,alpha0,crf_ls_rho,crf_ls_c1,likelihood,gf_sd);
				}

				if (tmp_epsilon < EPS) {
					emit_error("invalid alpha in line search(=0.0) (at iteration #%d)",iterate);
					RET_ERR(err_line_search);
				}
				old_gf_sd = gf_sd;
			} else if (crf_learning_rate == 3) { // line-search(golden section)
				if (crf_learn_mode == 1) {
					tmp_epsilon = golden_section_LBFGS(crf_ptr,0,crf_golden_b);
				} else {
					tmp_epsilon = golden_section(crf_ptr,0,crf_golden_b);
				}
			}
			crf_ptr->update_lambdas(tmp_epsilon);

			iterate++;
		}

		SHOW_PROGRESS_TAIL(converged, iterate, likelihood);

		/* keep the best restart; parameters are saved unless this is the
		 * last restart */
		if (r == 0 || likelihood > crf_ptr->likelihood) {
			crf_ptr->likelihood = likelihood;
			crf_ptr->iterate    = iterate;

			/* NOTE(review): `saved' is not read after the restart loop, so
			 * the parameters saved here are never restored in this function
			 * -- confirm whether a restore step is expected. */
			saved = (r < num_restart - 1);
			if (saved) {
				save_params();
			}
		}
	}

	if (crf_learn_mode == 1) clean_LBFGS();
	INIT_VISITED_FLAGS;
	return BP_TRUE;
}