Example #1
0
void CVwRegressor::load_regressor(char* file)
{
	CIOBuffer source;
	int32_t fd = source.open_file(file, 'r');

	if (fd < 0)
		SG_SERROR("Unable to open file for loading regressor!\n")

	// Read version info
	vw_size_t v_length;
	source.read_file((char*)&v_length, sizeof(v_length));
	char* t = SG_MALLOC(char, v_length);
	source.read_file(t,v_length);
	if (strcmp(t,env->vw_version) != 0)
	{
		SG_FREE(t);
		SG_SERROR("Regressor source has an incompatible VW version!\n")
	}
void CVwRegressor::load_regressor(char* file)
{
	CIOBuffer source;
	int32_t fd = source.open_file(file, 'r');

	if (fd < 0)
		SG_SERROR("Unable to open file for loading regressor!\n");

	// Read version info
	size_t v_length;
	source.read_file((char*)&v_length, sizeof(v_length));
	char t[v_length];
	source.read_file(t,v_length);
	if (strcmp(t,env->vw_version) != 0)
		SG_SERROR("Regressor source has an incompatible VW version!\n");

	// Read min and max label
	source.read_file((char*)&env->min_label, sizeof(env->min_label));
	source.read_file((char*)&env->max_label, sizeof(env->max_label));

	// Read num_bits, multiple sources are not supported
	size_t local_num_bits;
	source.read_file((char *)&local_num_bits, sizeof(local_num_bits));

	if ((size_t) env->num_bits != local_num_bits)
		SG_SERROR("Wrong number of bits in regressor source!\n");

	env->num_bits = local_num_bits;

	size_t local_thread_bits;
	source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));

	env->thread_bits = local_thread_bits;

	int32_t len;
	source.read_file((char *)&len, sizeof(len));

	// Read paired namespace information
	DynArray<char*> local_pairs;
	for (; len > 0; len--)
	{
		char pair[3];
		source.read_file(pair, sizeof(char)*2);
		pair[2]='\0';
		local_pairs.push_back(pair);
	}

	env->pairs = local_pairs;

	// Initialize the weight vector
	if (weight_vectors)
		SG_FREE(weight_vectors);
	init(env);

	size_t local_ngram;
	source.read_file((char*)&local_ngram, sizeof(local_ngram));
	size_t local_skips;
	source.read_file((char*)&local_skips, sizeof(local_skips));

	env->ngram = local_ngram;
	env->skips = local_skips;

	// Read individual weights
	size_t stride = env->stride;
	while (true)
	{
		uint32_t hash;
		ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
		if (hash_bytes <= 0)
			break;

		float32_t w = 0.;
		ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
		if (weight_bytes <= 0)
			break;

		size_t num_threads = env->num_threads();

		weight_vectors[hash % num_threads][(hash*stride)/num_threads]
			= weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
	}
	source.close_file();
}