コード例 #1
0
ファイル: hidden.hpp プロジェクト: jiyfeng/drlm
  /*********************************************
   * Build the computation graph for one sentence under a fixed
   * latent value.
   *
   * sent:   Sent instance; each token sent[t] is used to predict
   *         sent[t+1], so it must hold at least two tokens.
   * sidx:   position of the sentence in its document. Sentence 0 is
   *         conditioned on the context parameter p_context; later
   *         sentences on the stored hidden state final_h.
   * cg:     computation graph to build into.
   * latval: latent value; selects entry latval of p_T (the input
   *         transformation) and the prior term for that value.
   *
   * Returns -(sum_t NLL(sent[t+1]) + NLL(latval)), i.e. the log of
   * the joint probability of the words and latval under the model.
   *
   * Side effect: appends the last hidden state of this sentence to
   * final_hlist.
   *
   * All parameters enter the graph as constants (input() copies of
   * their values, const_lookup()), so no gradient reaches them from
   * this graph.
   *********************************************/
  Expression BuildSentGraph(const Sent& sent, const unsigned sidx,
			    ComputationGraph& cg,
			    const int latval){
    // Need at least one (word, next word) pair: for an empty sentence
    // sent.size() - 1 wraps around (unsigned arithmetic), and for a
    // one-token sentence no hidden state is produced, so i_h_t.value()
    // below would read an empty Expression.
    assert(sent.size() >= 2);
    builder.new_graph(cg);
    builder.start_new_sequence();
    // parameters copied into the graph as constants
    Expression i_R = input(cg, p_R->dim, as_vector(p_R->values));
    Expression i_bias = input(cg, p_bias->dim, as_vector(p_bias->values));
    Expression i_context = input(cg, p_context->dim, as_vector(p_context->values));
    Expression i_L = input(cg, p_L->dim, as_vector(p_L->values));
    Expression i_lbias = input(cg, p_lbias->dim, as_vector(p_lbias->values));
    // context vector: the learned context for the first sentence,
    // otherwise the hidden state saved in final_h
    Expression cvec;
    if (sidx == 0)
      cvec = i_context;
    else
      cvec = input(cg, {(unsigned)final_h.size()}, final_h);
    // transformation matrix and prior term for the given latent value
    Expression i_Tk = const_lookup(cg, p_T, latval);
    Expression lv_neglogprob = pickneglogsoftmax(((i_L * cvec) + i_lbias), latval);
    const unsigned slen = sent.size() - 1;
    vector<Expression> negloglik;
    negloglik.reserve(slen);
    Expression i_h_t;
    for (unsigned t = 0; t < slen; t++){
      // word representation extended with the context vector
      Expression i_x_t = concatenate({const_lookup(cg, p_W, sent[t]), cvec});
      // hidden state
      i_h_t = builder.add_input(i_Tk * i_x_t);
      // prediction of the next word and its error
      Expression i_y_t = (i_R * i_h_t) + i_bias;
      negloglik.push_back(pickneglogsoftmax(i_y_t, sent[t+1]));
    }
    // record the final hidden state of this sentence
    final_hlist.push_back(as_vector(i_h_t.value()));
    return (sum(negloglik) + lv_neglogprob) * (-1.0);
  }
コード例 #2
0
ファイル: genlex.c プロジェクト: agaurav/QT-GRETL
/* Resolve the "$"-prefixed word @s, recording the result in @p.
   Lookups are tried in this order: dollar variables (DVAR), then
   constants (CON, except $sysinfo which becomes a BUNDLE), then
   model variables (MVAR).  If none matches, undefined_symbol_error()
   is called.
*/
static void look_up_dollar_word (const char *s, parser *p)
{
    p->idnum = dvar_lookup(s);

    if (p->idnum > 0) {
	p->sym = DVAR;
    } else {
	p->idnum = const_lookup(s);
	if (p->idnum <= 0) {
	    /* not a constant either: last chance is a model variable */
	    p->idnum = mvar_lookup(s);
	    if (p->idnum > 0) {
		p->sym = MVAR;
	    } else {
		undefined_symbol_error(s, p);
	    }
	} else if (p->idnum != CONST_SYSINFO) {
	    p->sym = CON;
	} else {
	    /* $sysinfo is looked up as a constant but yields a bundle */
	    p->sym = BUNDLE;
	    p->idstr = gretl_strdup("$sysinfo");
	    p->uval = get_sysinfo_bundle(&p->err);
	}
    }

#if LDEBUG
    fprintf(stderr, "look_up_dollar_word: '%s' -> %d\n",
	    s, p->idnum);
#endif
}
コード例 #3
0
ファイル: genlex.c プロジェクト: agaurav/QT-GRETL
/* Resolve the bare (non-"$") word @s, setting p->sym and, as
   appropriate, p->idnum, p->idstr or p->uval.

   If @s names a built-in function and the next character is '(',
   it is left as that function symbol.  Otherwise @s is tried, in
   order, as: a constant, a dummy, a series in the dataset, "time",
   a user variable (scalar, matrix, bundle, string or list), a
   built-in string, a named object, a user function, a name matching
   dataset variables (only when the target is a list), "t", an R
   function, or -- when parsing a query -- an undefined name.  If
   everything fails an error is reported; a known function name used
   without an argument list gets function_noargs_error() rather than
   undefined_symbol_error().
*/
static void look_up_word (const char *s, parser *p)
{
    int fsym, err = 0;

    /* remember whether @s is a function name, for error reporting */
    fsym = p->sym = function_lookup_with_alias(s);

    if (p->sym == 0 || p->ch != '(') {
	p->idnum = const_lookup(s);
	if (p->idnum > 0) {
	    p->sym = CON;
	} else {
	    p->idnum = dummy_lookup(s);
	    if (p->idnum > 0) {
		p->sym = DUM;
	    } else {
		GretlType vtype = 0;
		char *bstr;

		if ((p->idnum = current_series_index(p->dset, s)) >= 0) {
		    p->sym = UVEC;
		    p->idstr = gretl_strdup(s);
		} else if (!strcmp(s, "time")) {
		    p->sym = DUM;
		    p->idnum = DUM_TREND;
		} else if ((p->uval = user_var_get_value_and_type(s, &vtype)) != NULL) {
		    /* map the user-variable type onto a parser symbol */
		    switch (vtype) {
		    case GRETL_TYPE_DOUBLE:
			p->sym = UNUM;
			break;
		    case GRETL_TYPE_MATRIX:
			p->sym = UMAT;
			break;
		    case GRETL_TYPE_BUNDLE:
			p->sym = BUNDLE;
			break;
		    case GRETL_TYPE_STRING:
			p->sym = USTR;
			break;
		    case GRETL_TYPE_LIST:
			p->sym = ULIST;
			break;
		    default:
			/* NOTE(review): any other type leaves p->sym as set
			   by function_lookup_with_alias() above -- confirm
			   this cannot happen */
			break;
		    }
		    p->idstr = gretl_strdup(s);
		} else if ((bstr = get_built_in_string_by_name(s))) {
		    /* FIXME should use $-accessors? */
		    p->sym = STR;
		    p->idstr = gretl_strdup(bstr);
		} else if (gretl_get_object_by_name(s)) {
		    p->sym = UOBJ;
		    p->idstr = gretl_strdup(s);
		} else if (get_user_function_by_name(s)) {
		    p->sym = UFUN;
		    p->idstr = gretl_strdup(s);
		} else if (p->targ == LIST && varname_match_any(p->dset, s)) {
		    p->sym = WLIST;
		    p->idstr = gretl_strdup(s);
		} else if (!strcmp(s, "t")) {
		    /* if "t" has not been otherwise defined, treat it
		       as an alias for "obs"
		    */
		    p->sym = DVAR;
		    p->idnum = R_INDEX;
		} else if (maybe_get_R_function(s)) {
		    /* note: all "native" types take precedence over this;
		       the first two characters of @s (presumably an "R."
		       prefix) are dropped from the stored name */
		    p->sym = RFUN;
		    p->idstr = gretl_strdup(s + 2);
		} else if (parsing_query) {
		    p->sym = UNDEF;
		    p->idstr = gretl_strdup(s);
		} else {
		    err = E_UNKVAR;
		}
	    }
	}
    }

    if (err) {
	if (fsym) {
	    function_noargs_error(s, p);
	} else {
	    undefined_symbol_error(s, p);
	}
    }
}
コード例 #4
0
ファイル: hidden.hpp プロジェクト: jiyfeng/drlm
  /************************************************
   * Build the training objective of a document whose latent values
   * are (partially) observed.
   *
   * doc:    document, a sequence of sentences; every sentence must
   *         hold at least two tokens (each token predicts the next).
   * cg:     computation graph to build into.
   * latseq: latent sequence from decoding; only checked here to be
   *         no longer than the document.
   * obsseq: observed latent value per sentence; must cover every
   *         sentence. A negative entry marks an unobserved sentence,
   *         which adds nothing to the objective.
   *
   * For each sentence and each latent value z the RNN is run to get
   * log P(Y, z) (word log-likelihood plus latent prior). A sentence
   * with observed value z_obs contributes
   *     logsumexp_z log P(Y, z) - log P(Y, z_obs) = -log P(z_obs | Y).
   * Returns the sum of these terms, or a constant 0 when no sentence
   * is observed.
   *
   * Side effect: final_h is overwritten with the last hidden state of
   * the run with latent value nlatvar - 1; it is the context vector
   * of the next sentence.
   ************************************************/
  Expression BuildRelaGraph(const Doc& doc, ComputationGraph& cg,
			    LatentSeq latseq, LatentSeq obsseq){
    // check the latent sequences against the document
    assert(latseq.size() <= doc.size());
    // obsseq[k] is read for every sentence k below
    assert(obsseq.size() >= doc.size());
    builder.new_graph(cg);
    // trainable parameters
    Expression i_R = parameter(cg, p_R);
    Expression i_bias = parameter(cg, p_bias);
    Expression i_context = parameter(cg, p_context);
    Expression i_L = parameter(cg, p_L);
    Expression i_lbias = parameter(cg, p_lbias);
    vector<Expression> obj;
    for (unsigned k = 0; k < doc.size(); k++){
      auto& sent = doc[k];
      // sent.size() - 1 would wrap for an empty sentence, and a
      // one-token sentence produces no hidden state for final_h
      assert(sent.size() >= 2);
      const unsigned slen = sent.size() - 1;
      // context vector: the learned context for the first sentence,
      // otherwise the hidden state left by the previous sentence
      Expression cvec;
      if (k == 0){
        cvec = i_context;
      } else {
        cvec = input(cg, {(unsigned)final_h.size()}, final_h);
      }
      // sent_objpart1: log P(Y, z_obs); sent_objpart2: log P(Y, z) for every z
      Expression sent_objpart1;
      vector<Expression> sent_objpart2;
      for (int latval = 0; latval < nlatvar; latval ++){
        // each latent value gets a fresh RNN run over the sentence
        builder.start_new_sequence();
        // negative log prior of this latent value
        Expression l_neglogprob = pickneglogsoftmax((i_L * cvec) + i_lbias, latval);
        // latent-specific input transformation
        Expression i_Tk = lookup(cg, p_T, latval);
        vector<Expression> l_negloglik;
        l_negloglik.reserve(slen);
        Expression i_h_t;
        for (unsigned t = 0; t < slen; t++){
          // word representation extended with the context vector
          Expression i_x_t = concatenate({const_lookup(cg, p_W, sent[t]), cvec});
          // hidden state, next-word prediction and its error
          i_h_t = builder.add_input(i_Tk * i_x_t);
          Expression i_y_t = (i_R * i_h_t) + i_bias;
          l_negloglik.push_back(pickneglogsoftmax(i_y_t, sent[t+1]));
        }
        // the run with the last latent value provides the context of
        // the next sentence
        if (latval == (nlatvar - 1)){
          final_h = as_vector(i_h_t.value());
        }
        // log P(Y, Z = latval)
        Expression logpyz = (sum(l_negloglik) + l_neglogprob) * (-1.0);
        sent_objpart2.push_back(logpyz);
        if (obsseq[k] == latval){
          sent_objpart1 = logpyz;
        }
      }
      // only sentences with an observed latent value contribute
      if (obsseq[k] >= 0){
        // an observed value outside [0, nlatvar) would leave
        // sent_objpart1 as an empty Expression
        assert(obsseq[k] < nlatvar);
        // -log P(z_obs | Y)
        obj.push_back(logsumexp(sent_objpart2) - sent_objpart1);
      }
    }
    if (obj.empty()){
      // no observed latent value anywhere: constant zero objective
      return input(cg, 0.0);
    }
    return sum(obj);
  }