예제 #1
0
파일: udp.c 프로젝트: EPiCS/reconos_v2
/*
 * Handle incoming UDP packets.
 *
 * Verifies the UDP checksum (seeded with the IP pseudo-header sum) and,
 * when it is good, passes the payload to the handler of the first socket
 * on udp_list whose port matches the datagram's destination port.
 * The packet buffer is freed on every path before returning.
 *
 *   pkt - packet buffer; pkt->udp_hdr and pkt->ip_hdr are read.
 *   r   - route of the sender, forwarded unchanged to the socket handler.
 */
void
__udp_handler(pktbuf_t *pkt, ip_route_t *r)
{
    udp_header_t *udp = pkt->udp_hdr;
    ip_header_t  *ip = pkt->ip_hdr;
    udp_socket_t *s;
    word         dest_port;

    /* a checksum field of 0xffff is treated as 0 before verification */
    if (udp->checksum == 0xffff)
	udp->checksum = 0;

    /* copy length for pseudo sum calculation */
    ip->length = udp->length;

    if (__sum((word *)udp, ntohs(udp->length), __pseudo_sum(ip)) == 0) {
	/*
	 * Header fields arrive in network byte order.  The source port is
	 * converted with ntohs() before it is handed to the handler, so the
	 * destination port needs the same conversion before it is compared
	 * with the socket's port (presumably stored in host order, matching
	 * the htons() applied to ports on the send side).
	 */
	dest_port = ntohs(udp->dest_port);

	for (s = udp_list; s; s = s->next) {
	    if (s->our_port == dest_port) {
		(*s->handler)(s, ((char *)udp) + sizeof(udp_header_t),
			      ntohs(udp->length) - sizeof(udp_header_t),
			      r, ntohs(udp->src_port));
		__pktbuf_free(pkt);
		return;
	    }
	}
    }

    /* bad checksum or nobody listening on that port: drop it */
    __pktbuf_free(pkt);
}
예제 #2
0
/*
 * Send an IP packet.
 *
 * The IP data field should contain pkt->pkt_bytes of data.
 * pkt->[udp|tcp|icmp]_hdr points to the IP data field. Any
 * IP options are assumed to be already in place in the IP
 * options field.
 *
 *   pkt      - packet to send; pkt->pkt_bytes is increased by the IP
 *              header size and the IP header fields are filled in here.
 *   protocol - value stored in the IP protocol field.
 *   dest     - route supplying the destination IP and ethernet address.
 *
 * Always returns 0; the packet is handed to __enet_send().
 */
int
__ip_send(pktbuf_t *pkt, int protocol, ip_route_t *dest)
{
    ip_header_t *ip = pkt->ip_hdr;
    int         hdr_bytes;
    unsigned short cksum;
    
    /*
     * Figure out header length. The use of udp_hdr is
     * somewhat arbitrary, but works because it is
     * a union with other IP protocol headers.
     */
    hdr_bytes = (((char *)pkt->udp_hdr) - ((char *)ip));

    pkt->pkt_bytes += hdr_bytes;

    ip->version = 4;
    ip->hdr_len = hdr_bytes >> 2;	/* header length in 32-bit words */
    ip->tos = 0;
    ip->length = htons(pkt->pkt_bytes);
    ip->ident = htons(ip_ident);
    ip_ident++;
    ip->fragment = 0;
    ip->ttl = 64;
    ip->protocol = protocol;

    /* checksum field must be zero while the header sum is computed */
    ip->checksum = 0;
    memcpy(ip->source, __local_ip_addr, sizeof(ip_addr_t));
    memcpy(ip->destination, dest->ip_addr, sizeof(ip_addr_t));
    cksum = __sum((word *)ip, hdr_bytes, 0);
    ip->checksum = htons(cksum);

    __enet_send(pkt, &dest->enet_addr, ETH_TYPE_IP);    
    return 0;
}
예제 #3
0
/*
 * Default ICMP handler: turns a valid incoming echo request into an
 * echo reply and sends it back.  Every other ICMP message is ignored.
 */
static void
default_icmp_handler(pktbuf_t *pkt, ip_route_t *dest)
{
    word reply_sum;

    /* Only echo requests with code 0 and a good checksum are answered. */
    if (pkt->icmp_hdr->type != ICMP_TYPE_ECHOREQUEST)
	return;
    if (pkt->icmp_hdr->code != 0)
	return;
    if (__sum((word *)pkt->icmp_hdr, pkt->pkt_bytes, 0) != 0)
	return;

    /* Reuse the request buffer: flip the type and redo the checksum. */
    pkt->icmp_hdr->type = ICMP_TYPE_ECHOREPLY;
    pkt->icmp_hdr->checksum = 0;
    reply_sum = __sum((word *)pkt->icmp_hdr, pkt->pkt_bytes, 0);
    pkt->icmp_hdr->checksum = htons(reply_sum);

    __ip_send(pkt, IP_PROTO_ICMP, dest);
}
예제 #4
0
/*
 * ICMP listener used while pinging.
 *
 * A valid echo request is answered with an echo reply.  An echo reply
 * has its header copied into hold_hdr and sets icmp_received.
 */
static void handle_icmp(pktbuf_t * pkt, ip_route_t * src_route)
{
	icmp_header_t *hdr = pkt->icmp_hdr;
	unsigned short reply_sum;

	if (hdr->type == ICMP_TYPE_ECHOREQUEST
	    && hdr->code == 0
	    && __sum((word *) hdr, pkt->pkt_bytes, 0) == 0) {
		/* Answer in place: same buffer, new type, fresh checksum. */
		hdr->type = ICMP_TYPE_ECHOREPLY;
		hdr->checksum = 0;
		reply_sum = __sum((word *) hdr, pkt->pkt_bytes, 0);
		hdr->checksum = htons(reply_sum);
		__ip_send(pkt, IP_PROTO_ICMP, src_route);
		return;
	}

	if (hdr->type == ICMP_TYPE_ECHOREPLY) {
		/* Keep the reply header and flag its arrival. */
		memcpy(&hold_hdr, hdr, sizeof(*hdr));
		icmp_received = true;
	}
}
예제 #5
0
/*
 * Handle IP packets coming from the polled ethernet interface.
 *
 * Packets that are not addressed to us, have a bad header checksum, or
 * carry inconsistent length fields are dropped.  Otherwise a route back
 * to the sender is built, pkt->pkt_bytes is set to the IP payload size
 * and the packet is dispatched to the handler of its protocol.  Packets
 * of unsupported protocols are freed here.
 *
 *   pkt           - received packet; pkt->ip_hdr must be set.
 *   src_enet_addr - ethernet address of the sender, copied into the route.
 */
void
__ip_handler(pktbuf_t *pkt, enet_addr_t *src_enet_addr)
{
    ip_header_t *ip = pkt->ip_hdr;
    ip_route_t  r;
    int         hdr_bytes;

    /* header length field counts 32-bit words */
    hdr_bytes = ip->hdr_len << 2;

    /*
     * First make sure it's ours, is well formed and has a good checksum.
     * An IPv4 header is at least 5 words long, and the total length must
     * cover the header; otherwise the payload size computed below would
     * be negative and be passed on to the protocol handlers.
     */
    if (!ip_addr_match(ip->destination) ||
	ip->hdr_len < 5 ||
	__sum((word *)ip, hdr_bytes, 0) != 0 ||
	ntohs(ip->length) < hdr_bytes) {
	__pktbuf_free(pkt);
	return;
    }

    /* remember how to reach the sender */
    memcpy(r.ip_addr, ip->source, sizeof(ip_addr_t));
    memcpy(r.enet_addr, src_enet_addr, sizeof(enet_addr_t));

    /* from here on pkt_bytes is the size of the IP payload */
    pkt->pkt_bytes = ntohs(ip->length) - hdr_bytes;

    switch (ip->protocol) {

#if NET_SUPPORT_ICMP
      case IP_PROTO_ICMP:
	pkt->icmp_hdr = (icmp_header_t *)(((char *)ip) + hdr_bytes);
	__icmp_handler(pkt, &r);
	break;
#endif

#if NET_SUPPORT_TCP
      case IP_PROTO_TCP:
	pkt->tcp_hdr = (tcp_header_t *)(((char *)ip) + hdr_bytes);
	__tcp_handler(pkt, &r);
	break;
#endif

#if NET_SUPPORT_UDP
      case IP_PROTO_UDP:
	pkt->udp_hdr = (udp_header_t *)(((char *)ip) + hdr_bytes);
	__udp_handler(pkt, &r);
	break;
#endif

      default:
	__pktbuf_free(pkt);
	break;
    }
}
예제 #6
0
파일: udp.c 프로젝트: EPiCS/reconos_v2
/*
 * Send a UDP packet.
 *
 *   buf       - payload to send.
 *   len       - payload size in bytes; must be in 0..MAX_UDP_DATA.
 *   dest_ip   - route to the destination.
 *   dest_port - destination port (converted with htons() here).
 *   src_port  - source port (converted with htons() here).
 *
 * Returns -1 if len is out of range or no packet buffer is available,
 * otherwise the result of __ip_send().
 */
int
__udp_send(char *buf, int len, ip_route_t *dest_ip,
	   word dest_port, word src_port)
{
    pktbuf_t *pkt;
    udp_header_t *udp;
    ip_header_t *ip;
    unsigned short cksum;
    int ret;

    /*
     * Reject payloads that cannot fit.  A negative length must be
     * refused as well: it would pass the upper-bound test and then be
     * converted to a huge size by memcpy() below.
     */
    if (len < 0 || len > MAX_UDP_DATA)
	return -1;

    /* just drop it if can't get a buffer */
    if ((pkt = __pktbuf_alloc(ETH_MAX_PKTLEN)) == NULL)
	return -1;

    udp = pkt->udp_hdr;
    ip = pkt->ip_hdr;

    pkt->pkt_bytes = len + sizeof(udp_header_t);

    udp->src_port = htons(src_port);
    udp->dest_port = htons(dest_port);
    udp->length = htons(pkt->pkt_bytes);
    udp->checksum = 0;	/* must be zero while the sum is computed */

    memcpy(((char *)udp) + sizeof(udp_header_t), buf, len);

    /* fill in some pseudo-header fields */
    memcpy(ip->source, __local_ip_addr, sizeof(ip_addr_t));
    memcpy(ip->destination, dest_ip->ip_addr, sizeof(ip_addr_t));
    ip->protocol = IP_PROTO_UDP;
    ip->length = udp->length;

    cksum = __sum((word *)udp, pkt->pkt_bytes, __pseudo_sum(ip));
    udp->checksum = htons(cksum);

    ret = __ip_send(pkt, IP_PROTO_UDP, dest_ip);
    __pktbuf_free(pkt);
    return ret;
}
예제 #7
0
파일: sparse.cpp 프로젝트: 4Liamk/KFusion
/*
 * Dot product of two vectors on the GPU.
 *
 * Runs dotproduct_kernel over a and other, writing into a temporary
 * device buffer of a->lsize elements, then calls __sum() on that buffer
 * to produce the final value in val.
 *
 *   a, other - input vectors; both go through vector_transfer(..., 1)
 *              first (presumably an upload to the device - confirm
 *              against vector_transfer).
 *   val      - receives the result via __sum().
 */
void vector_gpu_dot(Vector * a, Vector * other, TYPE &val)
{
	vector_transfer(a,1);
	vector_transfer(other,1);

	int call;
	size_t woot = a->lsize;
	int reductionlength = a->length/a->lsize;

	// Scratch buffer for the kernel output. It is the only device
	// allocation needed here and is released before returning.
	cl_mem tmp = clCreateBuffer(context,CL_MEM_READ_WRITE,sizeof(TYPE)*woot,NULL,&call);
	check(call);

	check(clSetKernelArg(dotproduct_kernel,0,sizeof(cl_mem),&a->gpu_vals));
	check(clSetKernelArg(dotproduct_kernel,1,sizeof(cl_mem),&other->gpu_vals));
	check(clSetKernelArg(dotproduct_kernel,2,sizeof(cl_mem),&tmp));
	check(clEnqueueNDRangeKernel(queue,dotproduct_kernel,1,0,&a->length,&a->lsize,0,NULL,NULL));

	// The second kernel execution is hidden in this call. This preserves
	// synchronization and allows for fusion.
	__sum(&tmp, val,call,reductionlength,woot);
	clReleaseMemObject(tmp);
}
예제 #8
0
static void do_ping(int argc, char *argv[])
{
	struct option_info opts[7];
	long count, timeout, length, rate, start_time, end_time, timer,
	    received, tries;
	char *local_ip_addr, *host_ip_addr;
	bool local_ip_addr_set, host_ip_addr_set, count_set,
	    timeout_set, length_set, rate_set, verbose;
	struct sockaddr_in local_addr, host_addr;
	ip_addr_t hold_addr;
	icmp_header_t *icmp;
	pktbuf_t *pkt;
	ip_header_t *ip;
	unsigned short cksum;
	ip_route_t dest_ip;

	init_opts(&opts[0], 'n', true, OPTION_ARG_TYPE_NUM,
		  (void *)&count, (bool *) & count_set,
		  "<count> - number of packets to test");
	init_opts(&opts[1], 't', true, OPTION_ARG_TYPE_NUM, (void *)&timeout,
		  (bool *) & timeout_set,
		  "<timeout> - max #ms per packet [rount trip]");
	init_opts(&opts[2], 'i', true, OPTION_ARG_TYPE_STR,
		  (void *)&local_ip_addr, (bool *) & local_ip_addr_set,
		  "local IP address");
	init_opts(&opts[3], 'h', true, OPTION_ARG_TYPE_STR,
		  (void *)&host_ip_addr, (bool *) & host_ip_addr_set,
		  "host name or IP address");
	init_opts(&opts[4], 'l', true, OPTION_ARG_TYPE_NUM, (void *)&length,
		  (bool *) & length_set, "<length> - size of payload");
	init_opts(&opts[5], 'v', false, OPTION_ARG_TYPE_FLG, (void *)&verbose,
		  (bool *) 0, "verbose operation");
	init_opts(&opts[6], 'r', true, OPTION_ARG_TYPE_NUM, (void *)&rate,
		  (bool *) & rate_set, "<rate> - time between packets");
	if (!scan_opts(argc, argv, 1, opts, 7, (void **)0, 0, "")) {
		diag_printf("PING - Invalid option specified\n");
		return;
	}
	// Set defaults; this has to be done _after_ the scan, since it will
	// have destroyed all values not explicitly set.
	if (local_ip_addr_set) {
		if (!_gethostbyname(local_ip_addr, (in_addr_t *) & local_addr)) {
			diag_printf("PING - Invalid local name: %s\n",
				    local_ip_addr);
			return;
		}
	} else {
		memcpy((in_addr_t *) & local_addr, __local_ip_addr,
		       sizeof(__local_ip_addr));
	}
	if (host_ip_addr_set) {
		if (!_gethostbyname(host_ip_addr, (in_addr_t *) & host_addr)) {
			diag_printf("PING - Invalid host name: %s\n",
				    host_ip_addr);
			return;
		}
		if (__arp_lookup((ip_addr_t *) & host_addr.sin_addr, &dest_ip) <
		    0) {
			diag_printf("PING: Cannot reach server '%s' (%s)\n",
				    host_ip_addr,
				    inet_ntoa((in_addr_t *) & host_addr));
			return;
		}
	} else {
		diag_printf("PING - host name or IP address required\n");
		return;
	}
#define DEFAULT_LENGTH   64
#define DEFAULT_COUNT    10
#define DEFAULT_TIMEOUT  1000
#define DEFAULT_RATE     1000
	if (!rate_set) {
		rate = DEFAULT_RATE;
	}
	if (!length_set) {
		length = DEFAULT_LENGTH;
	}
	if ((length < 64) || (length > 1400)) {
		diag_printf("Invalid length specified: %ld\n", length);
		return;
	}
	if (!count_set) {
		count = DEFAULT_COUNT;
	}
	if (!timeout_set) {
		timeout = DEFAULT_TIMEOUT;
	}
	// Note: two prints here because 'inet_ntoa' returns a static pointer
	diag_printf("Network PING - from %s",
		    inet_ntoa((in_addr_t *) & local_addr));
	diag_printf(" to %s\n", inet_ntoa((in_addr_t *) & host_addr));
	received = 0;
	__icmp_install_listener(handle_icmp);
	// Save default "local" address
	memcpy(hold_addr, __local_ip_addr, sizeof(hold_addr));
	for (tries = 0; tries < count; tries++) {
		// The network stack uses the global variable '__local_ip_addr'
		memcpy(__local_ip_addr, &local_addr, sizeof(__local_ip_addr));
		// Build 'ping' request
		if ((pkt = __pktbuf_alloc(ETH_MAX_PKTLEN)) == NULL) {
			// Give up if no packets - something is wrong
			break;
		}

		icmp = pkt->icmp_hdr;
		ip = pkt->ip_hdr;
		pkt->pkt_bytes = length + sizeof(icmp_header_t);

		icmp->type = ICMP_TYPE_ECHOREQUEST;
		icmp->code = 0;
		icmp->checksum = 0;
		icmp->seqnum = htons(tries + 1);
		cksum = __sum((word *) icmp, pkt->pkt_bytes, 0);
		icmp->checksum = htons(cksum);

		memcpy(ip->source, (in_addr_t *) & local_addr,
		       sizeof(ip_addr_t));
		memcpy(ip->destination, (in_addr_t *) & host_addr,
		       sizeof(ip_addr_t));
		ip->protocol = IP_PROTO_ICMP;
		ip->length = htons(pkt->pkt_bytes);

		__ip_send(pkt, IP_PROTO_ICMP, &dest_ip);
		__pktbuf_free(pkt);

		start_time = MS_TICKS();
		timer = start_time + timeout;
		icmp_received = false;
		while (!icmp_received && (MS_TICKS_DELAY() < timer)) {
			if (_rb_break(1)) {
				goto abort;
			}
			MS_TICKS_DELAY();
			__enet_poll();
		}
		end_time = MS_TICKS();

		timer = MS_TICKS() + rate;
		while (MS_TICKS_DELAY() < timer) {
			if (_rb_break(1)) {
				goto abort;
			}
			MS_TICKS_DELAY();
			__enet_poll();
		}

		if (icmp_received) {
			received++;
			if (verbose) {
				diag_printf(" seq: %ld, time: %ld (ticks)\n",
					    ntohs(hold_hdr.seqnum),
					    end_time - start_time);
			}
		}
	}
abort:
	__icmp_remove_listener();
	// Clean up
	memcpy(__local_ip_addr, &hold_addr, sizeof(__local_ip_addr));
	// Report
	diag_printf("PING - received %ld of %ld expected\n", received, count);
}
예제 #9
0
 // Combines three parts: the point query at the node returned by
 // links.find(u, v), plus __sum from u to that node and __sum from v
 // to that node.
 T sum(int u, int v) {
     int meet = links.find(u, v);
     return st.query(index[meet], index[meet])
          + __sum(u, meet)
          + __sum(v, meet);
 }
예제 #10
0
/**
 * @brief Transition function of the aggregate that computes the topic
 * counts. For every word occurrence in a document it bumps the counter of
 * the (word, assigned topic) pair in the state.
 *
 * The state is a bigint[] that is reinterpreted as a matrix of int32 cells
 * with voc_size rows and (topic_num + 1) columns; the extra column is a
 * flag that is set once a counter of that word has grown past 2e9.
 *
 * @param args[0]   The state variable, current topic counts
 * @param args[1]   The unique words in the document
 * @param args[2]   The counts of each unique word in the document
 * @param args[3]   The topic assignments in the document
 * @param args[4]   The size of vocabulary
 * @param args[5]   The number of topics
 * @return          The updated state
 **/
AnyType lda_count_topic_sfunc::run(AnyType & args)
{
    if(args[4].isNull() || args[5].isNull())
        throw std::invalid_argument("null parameter - voc_size and/or \
        topic_num is null");

    // A row without document data leaves the state untouched.
    if(args[1].isNull() || args[2].isNull() || args[3].isNull())
        return args[0];

    const int32_t voc_size = args[4].getAs<int32_t>();
    const int32_t topic_num = args[5].getAs<int32_t>();
    if(voc_size <= 0)
        throw std::invalid_argument("invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument("invalid argument - topic_num");

    ArrayHandle<int32_t> words = args[1].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[2].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> topic_assignment = args[3].getAs<ArrayHandle<int32_t> >();

    // Validate the document before any state is created or modified.
    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch - words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument("invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument("invalid values in counts");
    if(__min(topic_assignment) < 0 || __max(topic_assignment) >= topic_num)
        throw std::invalid_argument("invalid values in topics");
    if((size_t)__sum(counts) != topic_assignment.size())
        throw std::invalid_argument(
            "dimension mismatch - sum(counts) != topic_assignment.size()");

    MutableArrayHandle<int64_t> state(NULL);
    if(!args[0].isNull()) {
        state = args[0].getAs<MutableArrayHandle<int64_t> >();
    } else {
        // First call: allocate the state. A voc_size x (topic_num + 1)
        // int32 matrix is stored in a bigint[], so the cell count is
        // rounded up to an even number before it is halved:
        //   (2n+1) cells -> n+1 bigints, (2n) cells -> n bigints.
        // bigint[] is used because integer[] has a limit on the number
        // of elements and thus cannot be larger than 500MB.
        int dims[1] = {static_cast<int>( (voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t) )};
        int lbs[1] = {1};
        state = madlib_construct_md_array(
            NULL, NULL, 1, dims, lbs, INT8TI.oid, INT8TI.len, INT8TI.byval,
            INT8TI.align);
    }
    int32_t *model = reinterpret_cast<int32_t *>(state.ptr());

    const int32_t row_width = topic_num + 1;
    const int32_t n_unique = static_cast<int32_t>(words.size());
    int32_t cursor = 0;     // position in topic_assignment
    for(int32_t w = 0; w < n_unique; ++w) {
        int32_t *row = model + words[w] * row_width;
        for(int32_t rep = 0; rep < counts[w]; ++rep, ++cursor) {
            const int32_t topic = topic_assignment[cursor];
            if (row[topic] <= 2e9) {
                ++row[topic];
            } else {
                // counter is about to overflow: raise the flag instead
                row[topic_num] = 1;
            }
        }
    }
    return state;
}
예제 #11
0
/**
 * @brief This function learns the topics of words in a document and is the
 * main step of a Gibbs sampling iteration. The word topic counts and
 * corpus topic counts are passed to this function in the first call and
 * then transferred to the remaining calls through the user function
 * context (args.getUserFuncContext()) for efficiency.
 * @param args[0]   The unique words in the documents
 * @param args[1]   The counts of each unique words
 * @param args[2]   The topic counts and topic assignments in the document
 * @param args[3]   The model (word topic counts and corpus topic
 *                  counts)
 * @param args[4]   The Dirichlet parameter for per-document topic
 *                  multinomial, i.e. alpha
 * @param args[5]   The Dirichlet parameter for per-topic word
 *                  multinomial, i.e. beta
 * @param args[6]   The size of vocabulary
 * @param args[7]   The number of topics
 * @param args[8]   The number of iterations (=1:training, >1:prediction)
 * @return          The updated topic counts and topic assignments for
 *                  the document
 **/
AnyType lda_gibbs_sample::run(AnyType & args)
{
    ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
    MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
    double alpha = args[4].getAs<double>();
    double beta = args[5].getAs<double>();
    int32_t voc_size = args[6].getAs<int32_t>();
    int32_t topic_num = args[7].getAs<int32_t>();
    int32_t iter_num = args[8].getAs<int32_t>();
    // Expected number of int64 elements in the model array: a
    // voc_size x (topic_num + 1) matrix of int32 cells, rounded up to a
    // whole number of int64 elements.
    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

    if(alpha <= 0)
        throw std::invalid_argument("invalid argument - alpha");
    if(beta <= 0)
        throw std::invalid_argument("invalid argument - beta");
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");
    if(iter_num <= 0)
        throw std::invalid_argument(
            "invalid argument - iter_num");

    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch: words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");

    // doc_topic holds topic_num per-topic counts followed by one topic
    // assignment per word occurrence.
    int32_t word_count = __sum(counts);
    if(doc_topic.size() != (size_t)(word_count + topic_num))
        throw std::invalid_argument(
            "invalid dimension - doc_topic.size() != word_count + topic_num");
    if(__min(doc_topic, 0, topic_num) < 0)
        throw std::invalid_argument("invalid values in topic_count");
    if(
        __min(doc_topic, topic_num, word_count) < 0 ||
        __max(doc_topic, topic_num, word_count) >= topic_num)
        throw std::invalid_argument( "invalid values in topic_assignment");

    // First call only: validate the model, copy it into memory taken from
    // the cache memory context and register that copy as the user function
    // context, so later calls reuse it instead of reading args[3] again.
    if (!args.getUserFuncContext()) {
        ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
        if (model64.size() != model64_size) {
            std::stringstream ss;
            ss << "invalid dimension: model64.size() = " << model64.size();
            throw std::invalid_argument(ss.str());
        }
        if (__min(model64) < 0) {
            throw std::invalid_argument("invalid topic counts in model");
        }

        // Layout of the zero-initialized allocation:
        //   [ model64.size() int64 elements: copy of the model ]
        //   [ topic_num int64 elements: running per-topic totals ]
        int32_t *context =
            static_cast<int32_t *>(
                MemoryContextAllocZero(
                    args.getCacheMemoryContext(),
                    model64.size() * sizeof(int64_t)
                        + topic_num * sizeof(int64_t)));
        memcpy(context, model64.ptr(), model64.size() * sizeof(int64_t));
        int32_t *model = context;

        // context is an int32 pointer, so the offset of the second region
        // is expressed in int32 units.
        int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
                context + model64_size * sizeof(int64_t) / sizeof(int32_t));
        // Per-topic totals are the column sums of the model; the extra
        // column (index topic_num) of each row is not included.
        for (int i = 0; i < voc_size; i ++) {
            for (int j = 0; j < topic_num; j ++) {
                running_topic_counts[j] += model[i * (topic_num + 1) + j];
            }
        }

        args.setUserFuncContext(context);
    }

    int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
    if (context == NULL) {
        throw std::runtime_error("args.mSysInfo->user_fctx is null");
    }
    // Same two regions as set up above: model first, totals after it.
    int32_t *model = context;
    int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
            context + model64_size * sizeof(int64_t) / sizeof(int32_t));

    int32_t unique_word_count = static_cast<int32_t>(words.size());
    for(int32_t it = 0; it < iter_num; it++){
        // topic assignments start right after the topic_num counts
        int32_t word_index = topic_num;
        for(int32_t i = 0; i < unique_word_count; i++) {
            int32_t wordid = words[i];
            for(int32_t j = 0; j < counts[i]; j++){
                int32_t topic = doc_topic[word_index];
                // The sampler is given the document counts, this word's
                // row of the model and the per-topic totals, all still
                // including the current assignment.
                int32_t retopic = __lda_gibbs_sample(
                    topic_num, topic, doc_topic.ptr(),
                    model + wordid * (topic_num + 1),
                    running_topic_counts, alpha, beta);
                // Move this occurrence from the old topic to the new one
                // in the document's assignment and counts.
                doc_topic[word_index] = retopic;
                doc_topic[topic]--;
                doc_topic[retopic]++;

                // The cached model is only updated when iter_num == 1
                // (training, per the parameter description above).
                if(iter_num == 1) {
                    if (model[wordid * (topic_num + 1) + retopic] <= 2e9) {
                        running_topic_counts[topic] --;
                        running_topic_counts[retopic] ++;
                        model[wordid * (topic_num + 1) + topic]--;
                        model[wordid * (topic_num + 1) + retopic]++;
                    } else {
                        // counter above 2e9: leave all counts as they
                        // are and set the word's extra column to 1
                        model[wordid * (topic_num + 1) + topic_num] = 1;
                    }
                }
                word_index++;
            }
        }
    }

    return doc_topic;
}
예제 #12
0
파일: lda.cpp 프로젝트: adirastogi/madlib
/**
 * @brief Transition function of the aggregate that computes the topic
 * counts. For every word occurrence in a document it increments both the
 * (word, topic) counter and the corpus-wide counter of the assigned topic.
 *
 * The state is a (voc_size + 1) x topic_num bigint matrix: one row per
 * word plus a last row holding the per-topic totals.
 *
 * @param args[0]   The state variable, current topic counts
 * @param args[1]   The unique words in the document
 * @param args[2]   The counts of each unique word in the document
 * @param args[3]   The topic assignments in the document
 * @param args[4]   The size of vocabulary
 * @param args[5]   The number of topics
 * @return          The updated state
 **/
AnyType lda_count_topic_sfunc::run(AnyType & args)
{
    if(args[4].isNull() || args[5].isNull())
        throw std::invalid_argument("null parameter - voc_size and/or \
        topic_num is null");

    // A row without document data leaves the state untouched.
    if(args[1].isNull() || args[2].isNull() || args[3].isNull())
        return args[0];

    const int32_t voc_size = args[4].getAs<int32_t>();
    const int32_t topic_num = args[5].getAs<int32_t>();
    if(voc_size <= 0)
        throw std::invalid_argument("invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument("invalid argument - topic_num");

    ArrayHandle<int32_t> words = args[1].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[2].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> topic_assignment = args[3].getAs<ArrayHandle<int32_t> >();

    // Validate the document before any state is created or modified.
    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch - words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument("invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument("invalid values in counts");
    if(__min(topic_assignment) < 0 || __max(topic_assignment) >= topic_num)
        throw std::invalid_argument("invalid values in topics");
    if((size_t)__sum(counts) != topic_assignment.size())
        throw std::invalid_argument(
            "dimension mismatch - sum(counts) != topic_assignment.size()");

    MutableArrayHandle<int64_t> state(NULL);
    if(!args[0].isNull()){
        state = args[0].getAs<MutableArrayHandle<int64_t> >();
    } else {
        // First call: build the (voc_size + 1) x topic_num matrix.
        int dims[2] = {voc_size + 1, topic_num};
        int lbs[2] = {1, 1};
        state = madlib_construct_md_array(
            NULL, NULL, 2, dims, lbs, INT8TI.oid, INT8TI.len, INT8TI.byval,
            INT8TI.align);
    }

    const int32_t totals_row = voc_size * topic_num;  // offset of last row
    const int32_t n_unique = static_cast<int32_t>(words.size());
    int32_t cursor = 0;     // position in topic_assignment
    for(int32_t w = 0; w < n_unique; ++w){
        const int32_t word_row = words[w] * topic_num;
        for(int32_t rep = 0; rep < counts[w]; ++rep, ++cursor){
            const int32_t topic = topic_assignment[cursor];
            state[word_row + topic]++;
            state[totals_row + topic]++;
        }
    }

    return state;
}
예제 #13
0
파일: lda.cpp 프로젝트: adirastogi/madlib
/**
 * @brief Runs Gibbs sampling over the words of one document: every word
 * occurrence gets a newly sampled topic and the document's topic counts
 * are adjusted accordingly. The model (word topic counts followed by the
 * corpus topic counts) is read from args[3] on the first call only, copied
 * into cached memory and reused by later calls through the user function
 * context.
 * @param args[0]   The unique words in the documents
 * @param args[1]   The counts of each unique words
 * @param args[2]   The topic counts and topic assignments in the document
 * @param args[3]   The model (word topic counts and corpus topic
 *                  counts)
 * @param args[4]   The Dirichlet parameter for per-document topic
 *                  multinomial, i.e. alpha
 * @param args[5]   The Dirichlet parameter for per-topic word
 *                  multinomial, i.e. beta
 * @param args[6]   The size of vocabulary
 * @param args[7]   The number of topics
 * @param args[8]   The number of iterations (=1:training, >1:prediction)
 * @return          The updated topic counts and topic assignments for
 *                  the document
 **/
AnyType lda_gibbs_sample::run(AnyType & args)
{
    ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
    MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
    const double alpha = args[4].getAs<double>();
    const double beta = args[5].getAs<double>();
    const int32_t voc_size = args[6].getAs<int32_t>();
    const int32_t topic_num = args[7].getAs<int32_t>();
    const int32_t iter_num = args[8].getAs<int32_t>();

    // Scalar parameters must all be strictly positive.
    if(alpha <= 0)
        throw std::invalid_argument("invalid argument - alpha");
    if(beta <= 0)
        throw std::invalid_argument("invalid argument - beta");
    if(voc_size <= 0)
        throw std::invalid_argument("invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument("invalid argument - topic_num");
    if(iter_num <= 0)
        throw std::invalid_argument("invalid argument - iter_num");

    // Document arrays must be consistent with each other and in range.
    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch: words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument("invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument("invalid values in counts");

    // doc_topic = topic_num counts followed by one assignment per word.
    const int32_t word_count = __sum(counts);
    if(doc_topic.size() != (size_t)(word_count + topic_num))
        throw std::invalid_argument(
            "invalid dimension - doc_topic.size() != word_count + topic_num");
    if(__min(doc_topic, 0, topic_num) < 0)
        throw std::invalid_argument("invalid values in topic_count");
    if(
        __min(doc_topic, topic_num, word_count) < 0 ||
        __max(doc_topic, topic_num, word_count) >= topic_num)
        throw std::invalid_argument( "invalid values in topic_assignment");

    // First call: validate the model and cache a private copy of it.
    if (!args.getUserFuncContext())
    {
        if(args[3].isNull())
            throw std::invalid_argument("invalid argument - the model \
            parameter should not be null for the first call");
        ArrayHandle<int64_t> model = args[3].getAs<ArrayHandle<int64_t> >();
        if(model.size() != (size_t)((voc_size + 1) * topic_num))
            throw std::invalid_argument(
                "invalid dimension - model.size() != (voc_size + 1) * topic_num");
        if(__min(model) < 0)
            throw std::invalid_argument("invalid topic counts in model");

        int64_t *cached =
            static_cast<int64_t *>(
                MemoryContextAllocZero(
                    args.getCacheMemoryContext(),
                    model.size() * sizeof(int64_t)));
        memcpy(cached, model.ptr(), model.size() * sizeof(int64_t));
        args.setUserFuncContext(cached);
    }

    int64_t *state = static_cast<int64_t *>(args.getUserFuncContext());
    if(NULL == state){
        throw std::runtime_error("args.mSysInfo->user_fctx is null");
    }

    // The corpus-wide topic totals are the last row of the model.
    int64_t *corpus_counts = state + voc_size * topic_num;

    const int32_t n_unique = static_cast<int32_t>(words.size());
    for(int32_t pass = 0; pass < iter_num; ++pass){
        // assignments start right after the topic_num counts
        int32_t cursor = topic_num;
        for(int32_t w = 0; w < n_unique; ++w) {
            int64_t *word_counts = state + words[w] * topic_num;
            for(int32_t rep = 0; rep < counts[w]; ++rep, ++cursor){
                const int32_t old_topic = doc_topic[cursor];
                const int32_t new_topic = __lda_gibbs_sample(
                    topic_num, old_topic, doc_topic.ptr(),
                    word_counts, corpus_counts, alpha, beta);

                // Move the occurrence to its new topic in the document.
                doc_topic[cursor] = new_topic;
                doc_topic[old_topic]--;
                doc_topic[new_topic]++;

                // The cached model follows along only when iter_num == 1.
                if(iter_num == 1){
                    corpus_counts[old_topic]--;
                    corpus_counts[new_topic]++;
                    word_counts[old_topic]--;
                    word_counts[new_topic]++;
                }
            }
        }
    }

    return doc_topic;
}
예제 #14
0
	// Single-argument query built from the two-argument __sum:
	// the difference of b between `size` and x, scaled by x, plus the
	// value of a at x.
	inline T __sum(T x) {
		return (__sum(b, size)-__sum(b, x))*x
		     + __sum(a, x);
	}
예제 #15
0
	// Range query over [l, r], expressed as the difference of two
	// single-argument __sum queries (at r and at l-1).
	inline long long sum(int l, int r) {
		const int before = l-1;
		return __sum(r) - __sum(before);
	}
예제 #16
0
/* Sum of an iterable plus a start value: the one-argument __sum of
 * `iter` is computed first, then both it and `b` are converted to the
 * combined result type and added with __add. */
template <class U, class B> typename __sumtype2<typename U::for_in_unit,B>::type __sum(U *iter, B b) {
    typedef typename U::for_in_unit unit_t;
    typedef typename __sumtype2<unit_t,B>::type sum_t;

    typename __sumtype1<unit_t>::type partial = __sum(iter);
    return __add((sum_t)b, (sum_t)partial);
}