Example #1
0
/** ------------------------------------------------------------------
 ** @internal
 ** @brief MEX driver
 **/
void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
{
  vl_uint8 *data ;
  enum {IN_DATA = 0, IN_K, IN_NLEAVES, IN_END} ;
  enum {OUT_TREE = 0, OUT_ASGN} ;
  int M, N, K = 2, depth = 0 ;

  int             opt ;
  int             next = IN_END ;
  mxArray const  *optarg ;

  int nleaves     = 1 ;
  int method_type = VL_IKM_LLOYD ;
  int max_niters  = 200 ;
  int verb        = 0 ;

  VlHIKMTree* tree ;

  VL_USE_MATLAB_ENV ;

  /* ------------------------------------------------------------------
   *                                                Check the arguments
   * --------------------------------------------------------------- */

  if (nin < 3)
    {
      mexErrMsgTxt ("At least three arguments required.");
    }
  else if (nout > 2)
    {
      mexErrMsgTxt ("Too many output arguments.");
    }

  if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS)
    {
      mexErrMsgTxt ("DATA must be of class UINT8.");
    }

  if (! vlmxIsPlainScalar (in[IN_NLEAVES])           ||
      (nleaves = (int) *mxGetPr (in[IN_NLEAVES])) < 1) {
    mexErrMsgTxt ("NLEAVES must be a scalar not smaller than 2.") ;
  }

  M = mxGetM (in[IN_DATA]);   /* n of components */
  N = mxGetN (in[IN_DATA]);   /* n of elements */

  if (! vlmxIsPlainScalar (in[IN_K])         ||
      (K = (int) *mxGetPr (in[IN_K])) > N  ) {
    mexErrMsgTxt ("Cannot have more clusters than data.") ;
  }

  data = (vl_uint8 *) mxGetPr (in[IN_DATA]) ;

  while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
    char buf [1024] ;

    switch (opt) {

    case opt_verbose :
      ++ verb ;
      break ;

    case opt_max_niters :
      if (!vlmxIsPlainScalar(optarg) || (max_niters = (int) *mxGetPr(optarg)) < 1) {
        mexErrMsgTxt("MaxNiters must be not smaller than 1.") ;
      }
      break ;

    case opt_method :
      if (!vlmxIsString (optarg, -1)) {
        mexErrMsgTxt("'Method' must be a string.") ;
      }
      if (mxGetString (optarg, buf, sizeof(buf))) {
        mexErrMsgTxt("Option argument too long.") ;
      }
      if (strcmp("lloyd", buf) == 0) {
        method_type = VL_IKM_LLOYD ;
      } else if (strcmp("elkan", buf) == 0) {
        method_type = VL_IKM_ELKAN ;
      } else {
        mexErrMsgTxt("Unknown cost type.") ;
      }

      break ;

    default :
      abort() ;
      break ;
    }
  }

  /* ---------------------------------------------------------------
   *                                                      Do the job
   * ------------------------------------------------------------ */

  depth = VL_MAX(1, ceil (log (nleaves) / log(K))) ;
  tree  = vl_hikm_new  (method_type) ;

  if (verb) {
    mexPrintf("hikmeans: # dims: %d\n", M) ;
    mexPrintf("hikmeans: # data: %d\n", N) ;
    mexPrintf("hikmeans: K: %d\n", K) ;
    mexPrintf("hikmeans: depth: %d\n", depth) ;
  }

  vl_hikm_set_verbosity (tree, verb) ;
  vl_hikm_init          (tree, M, K, depth) ;
  vl_hikm_train         (tree, data, N) ;

  out[OUT_TREE] = hikm_to_matlab (tree) ;

  if (nout > 1) {
    vl_uint *asgn ;
    int j ;
    out [OUT_ASGN] = mxCreateNumericMatrix
      (vl_hikm_get_depth (tree), N, mxUINT32_CLASS, mxREAL) ;
    asgn = mxGetData(out[OUT_ASGN]) ;
    vl_hikm_push (tree, asgn, data, N) ;
    for (j = 0 ; j < N*depth ; ++ j) asgn [j] ++ ;
  }

  if (verb) {
    mexPrintf("hikmeans: done.\n") ;
  }

  /* vl_hikm_delete (tree) ; */
}
void hikm()
{
	const vl_size data_dim = 2;
	const vl_size num_data = 1000;

	vl_uint8 data[data_dim * num_data] = { 0, };
	for (vl_size i = 0; i < data_dim * num_data; ++i)
		data[i] = (vl_uint8)std::floor((((float)std::rand() / RAND_MAX) * 255) + 0.5f);

	//
	std::cout << "start processing ..." << std::endl;

	const VlIKMAlgorithms algorithm = VL_IKM_LLOYD;
	const vl_size num_clusters = 3;  // number of clusters per node.
	const vl_size tree_depth = 5;
	const vl_size num_max_iterations = 100;

	VlHIKMTree *hikm = vl_hikm_new(algorithm);

	vl_hikm_set_max_niters(hikm, num_max_iterations);

	// initilization.
	vl_hikm_init(hikm, data_dim, num_clusters, tree_depth);

	// training.
	vl_hikm_train(hikm, data, num_data);
	{
		//const int ndims = vl_hikm_get_ndims(hikm);  // dim. of data.
		//const int K = vl_hikm_get_K(hikm);  // num. of clusters.
		//const int depth = vl_hikm_get_depth(hikm);  // tree depth.
		//std::cout << "ndims: " << ndims << ", K: " << K << ", depth: " << depth << std::endl;

		const VlHIKMNode *root = vl_hikm_get_root(hikm);

		// TODO [implement] >> visualize the tree.

		//
		{
			vl_uint assignments[num_data * tree_depth] = { 0, };
			vl_hikm_push(hikm, assignments, data, num_data);
			for (int i = 0; i < num_data; ++i)
			{
				std::cout << '(';
				for (int j = 0; j < data_dim; ++j)
					std::cout << (0 == j ? "" : ",") << (int)data[i * data_dim + j];  // TODO [check] >> is it correct?
				std::cout << ") => ";
				for (int k = 0; k < tree_depth; ++k)
					std::cout << (0 == k ? "" : "-") << assignments[i * tree_depth + k];  // TODO [check] >> is it correct?
				std::cout << std::endl;
			}
		}
	}

	std::cout << "end processing ..." << std::endl;

	//
	if (hikm)
	{
		vl_hikm_delete(hikm);
		hikm = NULL;
	}
}