/** ------------------------------------------------------------------ ** @internal ** @brief MEX driver **/ void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[]) { vl_uint8 *data ; enum {IN_DATA = 0, IN_K, IN_NLEAVES, IN_END} ; enum {OUT_TREE = 0, OUT_ASGN} ; int M, N, K = 2, depth = 0 ; int opt ; int next = IN_END ; mxArray const *optarg ; int nleaves = 1 ; int method_type = VL_IKM_LLOYD ; int max_niters = 200 ; int verb = 0 ; VlHIKMTree* tree ; VL_USE_MATLAB_ENV ; /* ------------------------------------------------------------------ * Check the arguments * --------------------------------------------------------------- */ if (nin < 3) { mexErrMsgTxt ("At least three arguments required."); } else if (nout > 2) { mexErrMsgTxt ("Too many output arguments."); } if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS) { mexErrMsgTxt ("DATA must be of class UINT8."); } if (! vlmxIsPlainScalar (in[IN_NLEAVES]) || (nleaves = (int) *mxGetPr (in[IN_NLEAVES])) < 1) { mexErrMsgTxt ("NLEAVES must be a scalar not smaller than 2.") ; } M = mxGetM (in[IN_DATA]); /* n of components */ N = mxGetN (in[IN_DATA]); /* n of elements */ if (! vlmxIsPlainScalar (in[IN_K]) || (K = (int) *mxGetPr (in[IN_K])) > N ) { mexErrMsgTxt ("Cannot have more clusters than data.") ; } data = (vl_uint8 *) mxGetPr (in[IN_DATA]) ; while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) { char buf [1024] ; switch (opt) { case opt_verbose : ++ verb ; break ; case opt_max_niters : if (!vlmxIsPlainScalar(optarg) || (max_niters = (int) *mxGetPr(optarg)) < 1) { mexErrMsgTxt("MaxNiters must be not smaller than 1.") ; } break ; case opt_method : if (!vlmxIsString (optarg, -1)) { mexErrMsgTxt("'Method' must be a string.") ; } if (mxGetString (optarg, buf, sizeof(buf))) { mexErrMsgTxt("Option argument too long.") ; } if (strcmp("lloyd", buf) == 0) { method_type = VL_IKM_LLOYD ; } else if (strcmp("elkan", buf) == 0) { method_type = VL_IKM_ELKAN ; } else { mexErrMsgTxt("Unknown cost type.") ; } break ; default : abort() ; break ; } } /* --------------------------------------------------------------- * Do the job * ------------------------------------------------------------ */ depth = VL_MAX(1, ceil (log (nleaves) / log(K))) ; tree = vl_hikm_new (method_type) ; if (verb) { mexPrintf("hikmeans: # dims: %d\n", M) ; mexPrintf("hikmeans: # data: %d\n", N) ; mexPrintf("hikmeans: K: %d\n", K) ; mexPrintf("hikmeans: depth: %d\n", depth) ; } vl_hikm_set_verbosity (tree, verb) ; vl_hikm_init (tree, M, K, depth) ; vl_hikm_train (tree, data, N) ; out[OUT_TREE] = hikm_to_matlab (tree) ; if (nout > 1) { vl_uint *asgn ; int j ; out [OUT_ASGN] = mxCreateNumericMatrix (vl_hikm_get_depth (tree), N, mxUINT32_CLASS, mxREAL) ; asgn = mxGetData(out[OUT_ASGN]) ; vl_hikm_push (tree, asgn, data, N) ; for (j = 0 ; j < N*depth ; ++ j) asgn [j] ++ ; } if (verb) { mexPrintf("hikmeans: done.\n") ; } /* vl_hikm_delete (tree) ; */ }
void hikm() { const vl_size data_dim = 2; const vl_size num_data = 1000; vl_uint8 data[data_dim * num_data] = { 0, }; for (vl_size i = 0; i < data_dim * num_data; ++i) data[i] = (vl_uint8)std::floor((((float)std::rand() / RAND_MAX) * 255) + 0.5f); // std::cout << "start processing ..." << std::endl; const VlIKMAlgorithms algorithm = VL_IKM_LLOYD; const vl_size num_clusters = 3; // number of clusters per node. const vl_size tree_depth = 5; const vl_size num_max_iterations = 100; VlHIKMTree *hikm = vl_hikm_new(algorithm); vl_hikm_set_max_niters(hikm, num_max_iterations); // initilization. vl_hikm_init(hikm, data_dim, num_clusters, tree_depth); // training. vl_hikm_train(hikm, data, num_data); { //const int ndims = vl_hikm_get_ndims(hikm); // dim. of data. //const int K = vl_hikm_get_K(hikm); // num. of clusters. //const int depth = vl_hikm_get_depth(hikm); // tree depth. //std::cout << "ndims: " << ndims << ", K: " << K << ", depth: " << depth << std::endl; const VlHIKMNode *root = vl_hikm_get_root(hikm); // TODO [implement] >> visualize the tree. // { vl_uint assignments[num_data * tree_depth] = { 0, }; vl_hikm_push(hikm, assignments, data, num_data); for (int i = 0; i < num_data; ++i) { std::cout << '('; for (int j = 0; j < data_dim; ++j) std::cout << (0 == j ? "" : ",") << (int)data[i * data_dim + j]; // TODO [check] >> is it correct? std::cout << ") => "; for (int k = 0; k < tree_depth; ++k) std::cout << (0 == k ? "" : "-") << assignments[i * tree_depth + k]; // TODO [check] >> is it correct? std::cout << std::endl; } } } std::cout << "end processing ..." << std::endl; // if (hikm) { vl_hikm_delete(hikm); hikm = NULL; } }