コード例 #1
0
/*****************************************************************************
  FUNCTION : cc_getErr
  PURPOSE  : Walk the patterns selected by StartPattern..EndPattern, propagate
             each one through the output layer and accumulate the sum of
             squared errors  sse = sum (o_actual - y_desired)^2  over all
             output units.
  RETURNS  : the accumulated sse.
  NOTES    : Side effects visible in this function:
               - KernelErrorCode receives the result of
                 kr_initSubPatternOrder(),
               - SumSqError is reset to 0 and then accumulates
                 (devit * (act_deriv_func(unit) + cc_fse))^2 per output unit,
               - cc_actualNetSaved is set to TRUE.
             ERROR_CHECK is a project macro; presumably it leaves the
             function when KernelErrorCode signals an error -- confirm
             against its definition.
             Correct / WhichWin / CorrWin / MaxAct are classification
             bookkeeping that is written but never read here.

  UPDATE   : 19.01.96
******************************************************************************/
float cc_getErr (int StartPattern, int EndPattern)
{
    int p=0, sub, start, end, n,  pat, dummy;
    float sse=0, devit,error;
    register Patterns out_pat;
    register struct Unit *OutputUnitPtr;
    int Correct;
    int WhichWin,CorrWin;
    float MaxAct;

    KernelErrorCode = kr_initSubPatternOrder(StartPattern,EndPattern);
    ERROR_CHECK;
    /* translate the pattern range into the loop bounds start..end */
    cc_getPatternParameter(StartPattern,EndPattern,&start,&end,&n);
    ERROR_CHECK;
    SumSqError = 0.0;

    for(p=start; p<=end;p++){
        Correct=TRUE;
        MaxAct=0.0;
        cc_getActivationsForActualPattern(p,start,&pat,&sub);
        PROPAGATE_THROUGH_OUTPUT_LAYER(OutputUnitPtr,dummy,p);

        /* output data of the current (sub)pattern; the loop below walks it
           in step with the output units */
        out_pat = kr_getSubPatData(pat,sub,OUTPUT,NULL);

        FOR_ALL_OUTPUT_UNITS(OutputUnitPtr,dummy){
            /* remember which unit the pattern marks as the winner */
            if (*out_pat > 0.5) CorrWin = dummy;
            devit =  OutputUnitPtr->Out.output - *(out_pat++);
            /* remember which unit actually has the largest output */
            if  (OutputUnitPtr->Out.output > MaxAct)
            {
                MaxAct=OutputUnitPtr->Out.output;
                WhichWin=dummy;
            }
            /* Floating-point magnitude test.  The integer abs() used here
               before truncated devit to int, so the 0.2 tolerance only
               triggered once |devit| reached 1.0. */
            if (devit > 0.2 || devit < -0.2) Correct=FALSE;
            sse += devit*devit;
            /* deviation scaled by the activation derivative plus cc_fse */
            error = devit *
                ((*OutputUnitPtr->act_deriv_func)(OutputUnitPtr) + cc_fse);
            SumSqError += error*error;
        }
    }
    cc_actualNetSaved=TRUE;
    return sse;
}
コード例 #2
0
krui_err
SnnsCLib::LEARN_MonteCarlo(int start_pattern, int end_pattern, float *parameterInArray,
		 int NoOfInParams, float **parameterOutArray,
		 int *NoOfOutParams)
{
    //static float    LEARN_MonteCarlo_OutParameter[1]; /* LEARN_MonteCarlo_OutParameter[0] stores the learning error  */
    int             ret_code, pattern_no, sub_pat_no;
    float           error;
    register FlagWord flags;
    register struct Link *link_ptr;
    register struct Unit *unit_ptr;
    register struct Site *site_ptr;

    if (NoOfInParams < 2)
	return (KRERR_PARAMETERS); /* Not enough input parameters  */
    *NoOfOutParams = 1;		/* One return value is available (the
				 * learning error)  */
    *parameterOutArray = LEARN_MonteCarlo_OutParameter; /* set the output parameter reference  */
    ret_code = KRERR_NO_ERROR;	/* reset return code  */

    if (NetModified) {		/* Net has been modified */

	/* count the no. of I/O units and check the patterns  */
	ret_code = kr_IOCheck();
	if (ret_code < KRERR_NO_ERROR)
	    return (ret_code);

	/* sort units by topology and by topologic type  */
	ret_code = kr_topoSort(TOPOLOGICAL_FF);
	if ((ret_code != KRERR_NO_ERROR) && (ret_code != KRERR_DEAD_UNITS))
	    return (ret_code);
	MinimumError = 10000000;
	NetModified = FALSE;
    }
    if (NetInitialize || LearnFuncHasChanged) {	/* Net has been modified */
	MinimumError = 10000000;
    }
    /* randomize weights and biases */

    FOR_ALL_UNITS(unit_ptr) {
	unit_ptr->bias = (FlintType) u_drand48() *
	    (LEARN_PARAM2(parameterInArray) - LEARN_PARAM1(parameterInArray))
		+ LEARN_PARAM1(parameterInArray);
	flags = unit_ptr->flags;
	if ((flags & UFLAG_IN_USE) == UFLAG_IN_USE) { /* unit is in use  */
	    unit_ptr->value_a = (FlintType) 0;

	    if (flags & UFLAG_SITES) { /* unit has sites  */
		FOR_ALL_SITES_AND_LINKS(unit_ptr, site_ptr, link_ptr)
		    link_ptr->weight = (FlintType) u_drand48() *
			(LEARN_PARAM2(parameterInArray) -
			 LEARN_PARAM1(parameterInArray)) +
			     LEARN_PARAM1(parameterInArray);
	    } else {		/* unit has no sites   */
		if (flags & UFLAG_DLINKS) { /* unit has direct links */
		    FOR_ALL_LINKS(unit_ptr, link_ptr)
			link_ptr->weight = (FlintType) u_drand48() *
			    (LEARN_PARAM2(parameterInArray) -
			     LEARN_PARAM1(parameterInArray)) +
				 LEARN_PARAM1(parameterInArray);
		}
	    }
	}
    }

    /* compute the necessary sub patterns */
    KernelErrorCode = kr_initSubPatternOrder(start_pattern, end_pattern);
    if (KernelErrorCode != KRERR_NO_ERROR)
	return (KernelErrorCode);
    NET_ERROR(LEARN_MonteCarlo_OutParameter) = 0.0; /* reset network error value  */

    /* calculate performance of new net */
    while (kr_getSubPatternByOrder(&pattern_no, &sub_pat_no)) {
	propagateNetForward(pattern_no, sub_pat_no);
	/* Forward propagation */
	if ((error = calculate_SS_error(pattern_no, sub_pat_no)) == -1)
	    return (-1);
	NET_ERROR(LEARN_MonteCarlo_OutParameter) += error;
    }

    /* store weights and bias if error decreased */
    if (NET_ERROR(LEARN_MonteCarlo_OutParameter) < MinimumError) {
	MinimumError = NET_ERROR(LEARN_MonteCarlo_OutParameter);
	FOR_ALL_UNITS(unit_ptr) {
	    flags = unit_ptr->flags;
	    unit_ptr->value_b = unit_ptr->bias;
	    if ((flags & UFLAG_IN_USE) == UFLAG_IN_USE) {
		/* unit is in use  */
		if (flags & UFLAG_SITES) { /* unit has sites  */
		    FOR_ALL_SITES_AND_LINKS(unit_ptr, site_ptr, link_ptr)
			link_ptr->value_b = link_ptr->weight;
		} else {	/* unit has no sites   */
		    if (flags & UFLAG_DLINKS) {
			/* unit has direct links         */
			FOR_ALL_LINKS(unit_ptr, link_ptr)
			    link_ptr->value_b = link_ptr->weight;
		    }
		}
	    }
	}

    }