Ejemplo n.º 1
0
/** ------------------------------------------------------------------
 ** @internal
 ** @brief MEX driver
 **
 ** Trains a hierarchical integer k-means tree on UINT8 data.
 **
 ** Inputs (positional): in[0] = DATA, a UINT8 matrix whose rows are the
 ** components and whose columns are the data points; in[1] = K, the
 ** number of clusters per node; in[2] = NLEAVES, the requested number
 ** of leaves. Remaining inputs are options decoded through the
 ** file-level @c options table (verbosity, maximum number of
 ** iterations, method "lloyd" or "elkan").
 **
 ** Outputs: out[0] = the tree converted by hikm_to_matlab(); out[1]
 ** (optional) = a UINT32 depth-by-N matrix of assignments, each
 ** incremented by one (presumably to give 1-based labels).
 **
 ** @param nout number of output arguments requested (at most 2).
 ** @param out  output array.
 ** @param nin  number of input arguments (at least 3).
 ** @param in   input array.
 **/
void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
{
  vl_uint8 *data ;
  enum {IN_DATA = 0, IN_K, IN_NLEAVES, IN_END} ;
  enum {OUT_TREE = 0, OUT_ASGN} ;
  int M, N, K = 2, depth = 0 ;

  int             opt ;
  int             next = IN_END ;
  mxArray const  *optarg ;

  int nleaves     = 1 ;
  int method_type = VL_IKM_LLOYD ;
  int max_niters  = 200 ;
  int verb        = 0 ;

  VlHIKMTree* tree ;

  VL_USE_MATLAB_ENV ;

  /* ------------------------------------------------------------------
   *                                                Check the arguments
   * --------------------------------------------------------------- */

  if (nin < 3)
    {
      mexErrMsgTxt ("At least three arguments required.");
    }
  else if (nout > 2)
    {
      mexErrMsgTxt ("Too many output arguments.");
    }

  if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS)
    {
      mexErrMsgTxt ("DATA must be of class UINT8.");
    }

  /* The check accepts any value >= 1, so the message must say so. */
  if (! vlmxIsPlainScalar (in[IN_NLEAVES])           ||
      (nleaves = (int) *mxGetPr (in[IN_NLEAVES])) < 1) {
    mexErrMsgTxt ("NLEAVES must be a scalar not smaller than 1.") ;
  }

  M = (int) mxGetM (in[IN_DATA]);   /* n of components */
  N = (int) mxGetN (in[IN_DATA]);   /* n of elements */

  /* K < 2 cannot build a hierarchy (K^depth never grows), and the old
     depth formula divided by log(K) == 0 in that case. */
  if (! vlmxIsPlainScalar (in[IN_K])         ||
      (K = (int) *mxGetPr (in[IN_K])) < 2  ) {
    mexErrMsgTxt ("K must be a scalar not smaller than 2.") ;
  }
  if (K > N) {
    mexErrMsgTxt ("Cannot have more clusters than data.") ;
  }

  /* DATA is UINT8: use the untyped accessor rather than mxGetPr, which
     is meant for double arrays. */
  data = (vl_uint8 *) mxGetData (in[IN_DATA]) ;

  while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
    char buf [1024] ;

    switch (opt) {

    case opt_verbose :
      ++ verb ;
      break ;

    case opt_max_niters :
      if (!vlmxIsPlainScalar(optarg) || (max_niters = (int) *mxGetPr(optarg)) < 1) {
        mexErrMsgTxt("MaxNiters must be not smaller than 1.") ;
      }
      break ;

    case opt_method :
      if (!vlmxIsString (optarg, -1)) {
        mexErrMsgTxt("'Method' must be a string.") ;
      }
      if (mxGetString (optarg, buf, sizeof(buf))) {
        mexErrMsgTxt("Option argument too long.") ;
      }
      if (strcmp("lloyd", buf) == 0) {
        method_type = VL_IKM_LLOYD ;
      } else if (strcmp("elkan", buf) == 0) {
        method_type = VL_IKM_ELKAN ;
      } else {
        mexErrMsgTxt("Unknown cost type.") ;
      }

      break ;

    default :
      abort() ;
      break ;
    }
  }

  /* ---------------------------------------------------------------
   *                                                      Do the job
   * ------------------------------------------------------------ */

  /* Smallest depth >= 1 such that K^depth >= nleaves. Counting
     multiplications avoids the rounding error of
     ceil (log (nleaves) / log (K)), which can over-estimate the depth
     when nleaves is an exact power of K (log(125)/log(5) evaluates to
     slightly more than 3 in double precision). A double accumulator
     cannot overflow for int-ranged operands. */
  {
    double capacity = (double) K ;
    depth = 1 ;
    while (capacity < (double) nleaves) {
      capacity *= (double) K ;
      ++ depth ;
    }
  }

  tree  = vl_hikm_new  (method_type) ;

  if (verb) {
    mexPrintf("hikmeans: # dims: %d\n", M) ;
    mexPrintf("hikmeans: # data: %d\n", N) ;
    mexPrintf("hikmeans: K: %d\n", K) ;
    mexPrintf("hikmeans: depth: %d\n", depth) ;
  }

  vl_hikm_set_verbosity  (tree, verb) ;
  /* Forward the parsed option; previously it was validated and then
     silently ignored. */
  vl_hikm_set_max_niters (tree, max_niters) ;
  vl_hikm_init           (tree, M, K, depth) ;
  vl_hikm_train          (tree, data, N) ;

  out[OUT_TREE] = hikm_to_matlab (tree) ;

  if (nout > 1) {
    vl_uint *asgn ;
    int j ;
    out [OUT_ASGN] = mxCreateNumericMatrix
      (vl_hikm_get_depth (tree), N, mxUINT32_CLASS, mxREAL) ;
    asgn = (vl_uint *) mxGetData(out[OUT_ASGN]) ;
    vl_hikm_push (tree, asgn, data, N) ;
    /* Shift every label by one. */
    for (j = 0 ; j < N*depth ; ++ j) asgn [j] ++ ;
  }

  if (verb) {
    mexPrintf("hikmeans: done.\n") ;
  }

  /* NOTE(review): the tree is never released here. The delete call was
     already disabled in the original; confirm that hikm_to_matlab does
     not keep pointers into the tree before re-enabling it. */
  /* vl_hikm_delete (tree) ; */
}
void hikm()
{
	const vl_size data_dim = 2;
	const vl_size num_data = 1000;

	vl_uint8 data[data_dim * num_data] = { 0, };
	for (vl_size i = 0; i < data_dim * num_data; ++i)
		data[i] = (vl_uint8)std::floor((((float)std::rand() / RAND_MAX) * 255) + 0.5f);

	//
	std::cout << "start processing ..." << std::endl;

	const VlIKMAlgorithms algorithm = VL_IKM_LLOYD;
	const vl_size num_clusters = 3;  // number of clusters per node.
	const vl_size tree_depth = 5;
	const vl_size num_max_iterations = 100;

	VlHIKMTree *hikm = vl_hikm_new(algorithm);

	vl_hikm_set_max_niters(hikm, num_max_iterations);

	// initilization.
	vl_hikm_init(hikm, data_dim, num_clusters, tree_depth);

	// training.
	vl_hikm_train(hikm, data, num_data);
	{
		//const int ndims = vl_hikm_get_ndims(hikm);  // dim. of data.
		//const int K = vl_hikm_get_K(hikm);  // num. of clusters.
		//const int depth = vl_hikm_get_depth(hikm);  // tree depth.
		//std::cout << "ndims: " << ndims << ", K: " << K << ", depth: " << depth << std::endl;

		const VlHIKMNode *root = vl_hikm_get_root(hikm);

		// TODO [implement] >> visualize the tree.

		//
		{
			vl_uint assignments[num_data * tree_depth] = { 0, };
			vl_hikm_push(hikm, assignments, data, num_data);
			for (int i = 0; i < num_data; ++i)
			{
				std::cout << '(';
				for (int j = 0; j < data_dim; ++j)
					std::cout << (0 == j ? "" : ",") << (int)data[i * data_dim + j];  // TODO [check] >> is it correct?
				std::cout << ") => ";
				for (int k = 0; k < tree_depth; ++k)
					std::cout << (0 == k ? "" : "-") << assignments[i * tree_depth + k];  // TODO [check] >> is it correct?
				std::cout << std::endl;
			}
		}
	}

	std::cout << "end processing ..." << std::endl;

	//
	if (hikm)
	{
		vl_hikm_delete(hikm);
		hikm = NULL;
	}
}
Ejemplo n.º 3
0
/** @brief MEX driver entry point
 **
 ** Projects UINT8 data down an existing hierarchical integer k-means
 ** tree.
 **
 ** Inputs (positional): in[0] = TREE, converted by matlab_to_hikm();
 ** in[1] = DATA, a UINT8 matrix with one data point per column.
 ** Remaining inputs are options decoded through the file-level
 ** @c options table (verbosity, method "lloyd" or "elkan").
 **
 ** Output: out[0] = a UINT32 depth-by-N matrix of assignments, each
 ** incremented by one (presumably to give 1-based labels).
 **
 ** @param nout number of output arguments requested (at most 1).
 ** @param out  output array.
 ** @param nin  number of input arguments (at least 2).
 ** @param in   input array.
 **/
void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
{
  enum {IN_TREE = 0, IN_DATA, IN_END} ;
  enum {OUT_ASGN = 0} ;
  vl_uint8 const *data; 

  int             opt ;
  int             next = IN_END ;
  mxArray const  *optarg ;
  
  int N = 0 ;
  int method_type = VL_IKM_LLOYD ;
  int verb = 0 ;
  
  /* -----------------------------------------------------------------
   *                                               Check the arguments
   * -------------------------------------------------------------- */
  if (nin < 2)
    mexErrMsgTxt ("At least two arguments required.");
  else if (nout > 1)
    mexErrMsgTxt ("Too many output arguments.");
  
  if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS) {
    mexErrMsgTxt ("DATA must be of class UINT8");
  }
  
  N = (int) mxGetN (in[IN_DATA]);   /* n of elements */
  /* DATA is UINT8: use the untyped accessor rather than mxGetPr, which
     is meant for double arrays; keep the pointer const. */
  data = (vl_uint8 const *) mxGetData (in[IN_DATA]);
  
  while ((opt = uNextOption(in, nin, options, &next, &optarg)) >= 0) {
    char buf [1024] ;
    
    switch (opt) {
      
    case opt_verbose :
      ++ verb ;
      break ;
      
    case opt_method :
      if (!uIsString (optarg, -1)) {
        mexErrMsgTxt("'Method' must be a string.") ;        
      }
      if (mxGetString (optarg, buf, sizeof(buf))) {
        mexErrMsgTxt("Option argument too long.") ;
      }
      if (strcmp("lloyd", buf) == 0) {
        method_type = VL_IKM_LLOYD ;
      } else if (strcmp("elkan", buf) == 0) {
        method_type = VL_IKM_ELKAN ;
      } else {
        mexErrMsgTxt("Unknown cost type.") ;
      }
      
      break ;
      
    default :
      assert(0) ;
      break ;
    }
  }

  /* -----------------------------------------------------------------
   *                                                        Do the job
   * -------------------------------------------------------------- */

  {
    VlHIKMTree *tree ;
    vl_uint  *ids  ;
    int j;
    int depth ;

    tree  = matlab_to_hikm (in[IN_TREE], method_type) ;
    depth = vl_hikm_get_depth (tree) ;

    /* vl_hikm_push is given only the number of points, so a DATA matrix
       with a different row count than the tree's dimensionality could be
       read past its end. Release the tree first: mexErrMsgTxt does not
       return. */
    if ((int) mxGetM (in[IN_DATA]) != (int) vl_hikm_get_ndims (tree)) {
      vl_hikm_delete (tree) ;
      mexErrMsgTxt ("DATA dimensionality does not match the tree.") ;
    }

    if (verb) {
      mexPrintf("hikmeanspush: ndims: %d K: %d depth: %d\n",
                vl_hikm_get_ndims (tree), 
                vl_hikm_get_K (tree),
                depth) ;
    }
    
    out[OUT_ASGN] = mxCreateNumericMatrix (depth, N, mxUINT32_CLASS, mxREAL) ;
    ids = (vl_uint *) mxGetData (out[OUT_ASGN]) ;

    vl_hikm_push   (tree, ids, data, N) ;    
    vl_hikm_delete (tree) ;
    
    /* Shift every label by one. */
    for (j = 0 ; j < N*depth ; j++) ids [j] ++ ;
  }
}
Ejemplo n.º 4
0
/** ----------------------------------------------------------------
 ** @brief Project data down a HIKM tree (Python binding).
 **
 ** Converts @a inTree with python_to_hikm(), pushes the data through
 ** it with vl_hikm_push() and returns the assignments, each incremented
 ** by one, in a newly created Fortran-contiguous depth-by-N INT32 array.
 **
 ** @param inTree tree description to convert.
 ** @param inData input array; its buffer is read as 8-bit unsigned
 **               values and its second dimension is taken as the number
 **               of data points.
 ** @param verb   non-zero to print the tree parameters.
 ** @param method "lloyd" or "elkan"; any other value prints a message
 **               and falls back to Lloyd.
 ** @return the assignment array, or NULL if the output array could not
 **         be created (NumPy is expected to have set the Python error).
 **
 ** NOTE(review): the number of dimensions and element type of
 ** @a inData, and its first dimension versus the tree dimensionality,
 ** are not validated here -- confirm that callers guarantee them.
 **/
PyObject * vl_hikmeanspush_python(
		VlHIKMTree_python & inTree,
		PyArrayObject & inData,
		int verb,
		char * method)
{
	vl_uint8 const *data;

	int N = 0;
	int method_type = VL_IKM_LLOYD;

	N = inData.dimensions[1]; /* n of elements */

#ifdef DEBUG
	printf("n of elements: %d\n", N);
	printf("n of split: %d\n", inTree.K);
	printf("depth: %d\n", inTree.depth);
	/* size() is not an int; cast so the argument matches %d. */
	printf("n of children: %d\n", (int) inTree.sub.size());
#endif

	data = (vl_uint8 const *) inData.data;

	if (strcmp("lloyd", method) == 0) {
		method_type = VL_IKM_LLOYD;
	} else if (strcmp("elkan", method) == 0) {
		method_type = VL_IKM_ELKAN;
	} else {
		/* Best effort: report and keep the default (Lloyd). */
		printf("Unknown cost type.\n");
	}

	/* -----------------------------------------------------------------
	 *                                                        Do the job
	 * -------------------------------------------------------------- */

	VlHIKMTree * tree;
	vl_uint *ids;
	int j;
	int depth;

	tree = python_to_hikm(inTree, method_type);
	depth = vl_hikm_get_depth(tree);

	if (verb) {
		printf("hikmeanspush: ndims: %d K: %d depth: %d\n", vl_hikm_get_ndims(
			tree), vl_hikm_get_K(tree), depth);
	}

	npy_intp dims[2] = { depth, N };
	PyArrayObject * out_asgn = (PyArrayObject*) PyArray_NewFromDescr(
		&PyArray_Type, PyArray_DescrFromType(PyArray_INT32), 2, dims, NULL,
		NULL, NPY_F_CONTIGUOUS, NULL);

	if (out_asgn == NULL) {
		/* Creation failed: release the tree instead of dereferencing
		   a null array below. */
		vl_hikm_delete(tree);
		return NULL;
	}

	ids = (vl_uint *) out_asgn->data;

	vl_hikm_push(tree, ids, data, N);
	vl_hikm_delete(tree);

	/* Shift every label by one. */
	for (j = 0; j < N * depth; j++)
		ids[j]++;

	return PyArray_Return(out_asgn);

}