コード例 #1
0
ファイル: lwpr_binio.c プロジェクト: mc01104/CTR
/**
 * Read one receptive-field record from a binary stream and append it to
 * the given sub-model.
 *
 * The record must start with the 4-byte tag "[RF]", followed by an integer
 * nReg with 0 < nReg <= nIn. A new receptive field is then created via
 * lwpr_aux_add_rf() and its members are filled in the fixed order below.
 *
 * @param fp   Stream positioned at the start of a receptive-field record.
 * @param sub  Sub-model that receives the new receptive field; its model
 *             supplies nIn and nInStore.
 * @return 0 if the tag, nReg, or the allocation is rejected. Otherwise the
 *         bitwise AND of 1 with every read helper's result, so the value is
 *         1 only when each helper returned a value with its low bit set
 *         (presumably 1 on success -- confirm against the helpers).
 *
 * Note: when a member read fails, the receptive field has already been
 * added to the sub-model and is left partially filled.
 */
int lwpr_io_read_rf(FILE *fp, LWPR_SubModel *sub) {
   const int dimIn = sub->model->nIn;
   const int dimInStore = sub->model->nInStore;
   char tag[5];
   int numReg;
   int status;
   LWPR_ReceptiveField *field;

   /* Record header: exactly four tag bytes, compared as a C string. */
   if (fread(tag, 1, 4, fp) != 4) {
      return 0;
   }
   tag[4] = '\0';
   if (strcmp(tag, "[RF]") != 0) {
      return 0;
   }

   /* Number of regression directions; numReg is only inspected once the
      read itself reported success. */
   status = lwpr_io_read_int(fp, &numReg);
   if (status != 1) {
      return 0;
   }
   if (numReg <= 0 || numReg > dimIn) {
      return 0;
   }

   field = lwpr_aux_add_rf(sub, numReg);
   if (field == NULL) {
      return 0;
   }

   /* From here on status is 1; every helper result is AND-ed into it and
      all reads are attempted regardless of earlier failures. The order
      mirrors the on-disk layout and must not be changed. */
   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, dimIn, field->D);
   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, dimIn, field->M);
   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, dimIn, field->alpha);
   status &= lwpr_io_read_scalar(fp, &field->beta0);
   status &= lwpr_io_read_vector(fp, numReg, field->beta);
   status &= lwpr_io_read_vector(fp, dimIn, field->c);

   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, numReg, field->SXresYres);
   status &= lwpr_io_read_vector(fp, numReg, field->SSs2);
   status &= lwpr_io_read_vector(fp, numReg, field->SSYres);
   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, numReg, field->SSXres);
   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, numReg, field->U);
   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, numReg, field->P);
   status &= lwpr_io_read_vector(fp, numReg, field->H);
   status &= lwpr_io_read_vector(fp, numReg, field->r);

   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, dimIn, field->h);
   status &= lwpr_io_read_matrix(fp, dimIn, dimInStore, dimIn, field->b);
   status &= lwpr_io_read_vector(fp, numReg, field->sum_w);
   status &= lwpr_io_read_vector(fp, numReg, field->sum_e_cv2);
   status &= lwpr_io_read_scalar(fp, &field->sum_e2);
   status &= lwpr_io_read_scalar(fp, &field->SSp);
   status &= lwpr_io_read_vector(fp, numReg, field->n_data);
   status &= lwpr_io_read_int(fp, &field->trustworthy);
   status &= lwpr_io_read_vector(fp, numReg, field->lambda);
   status &= lwpr_io_read_vector(fp, dimIn, field->mean_x);
   status &= lwpr_io_read_vector(fp, dimIn, field->var_x);
   status &= lwpr_io_read_scalar(fp, &field->w);
   status &= lwpr_io_read_vector(fp, numReg, field->s);

   return status;
}
コード例 #2
0
ファイル: lwpr_xml.c プロジェクト: RhobanDeps/lwpr
/**
 * Element-start callback of the LWPR XML reader.
 *
 * The signature matches an Expat-style start handler: `atts` is walked as a
 * NULL-terminated list of name/value pairs.
 *
 * For every opening tag the function
 *   1. resets the per-element counters in the parser state,
 *   2. classifies the tag into ud->curType
 *      (1 = <integer>, 2 = <scalar>, 3 = <vector>, 4 = <matrix>,
 *       0 = anything else, i.e. structural or unknown elements),
 *   3. depending on the nesting level ud->level
 *      (0 = outside <LWPR>, 1 = inside <LWPR>, 2 = inside <SubModel>,
 *       3 = inside <ReceptiveField>) either opens the structural element or
 *      points ud->curPtr at the destination of the named field and publishes
 *      the expected element count in ud->N / ud->M.
 *
 * Dimensions declared in the file are compared with the sizes the model
 * expects; on a mismatch an error is reported and no size is published, so
 * that a malformed file cannot announce more values than the destination
 * holds.
 *
 * @param userData  Pointer to the LWPR_ParserData state of this parse run.
 * @param name      Tag name of the element being opened.
 * @param atts      NULL-terminated attribute name/value pairs.
 */
void lwpr_xml_start_element(void *userData, const char *name, const char **atts) {
   int M=0, N=0;
   const char **at;
   const char *fieldName = NULL;
   int wishM = 0, wishN = 0;

   LWPR_ParserData *ud = (LWPR_ParserData *) userData;
   LWPR_Model *model = ud->model;
   LWPR_SubModel *sub=NULL;
   LWPR_ReceptiveField *RF=NULL;

   ud->readN = ud->readM = ud->N = ud->M = 0;

   /* Resolve the sub-model / receptive field the parser is currently in. */
   if (model->sub!=NULL) {
      sub = &(model->sub[ud->curSub]);
      if (sub->rf != NULL) RF = sub->rf[ud->curRF];
   }

   /* Classify the element. Data elements carry a "name" attribute (plus
      dimensions for vectors and matrices) which the parse helpers extract. */
   if (!strcmp(name,"integer")) {
      ud->curType = 1;
      if (!lwpr_xml_parse_scalar(atts,&fieldName)) {
         lwpr_xml_error(ud,"<integer> element without name.\n");
         return;
      }
   } else if (!strcmp(name,"scalar")) {
      ud->curType = 2;
      if (!lwpr_xml_parse_scalar(atts,&fieldName)) {
         lwpr_xml_error(ud,"<scalar> element without name.");
         return;
      }
   } else if (!strcmp(name,"vector")) {
      ud->curType = 3;
      if (!lwpr_xml_parse_vector(atts,&fieldName,&N)) {
         lwpr_xml_error(ud,"Parse error: <vector> element without name or length.\n");
         return;
      }
   } else if (!strcmp(name,"matrix")) {
      ud->curType = 4;
      if (!lwpr_xml_parse_matrix(atts,&fieldName,&M,&N)) {
         lwpr_xml_error(ud,"Parse error: <matrix> element without name, rows or columns.\n");
         return;
      }
   } else {
      /* Not a data element. Without this reset a stale type from the
         previous element would send the code below into a branch that
         reads fieldName although it was never assigned. */
      ud->curType = 0;
   }

   /* ---- Level 0: only the <LWPR> root element is accepted. ---- */
   if (ud->level == 0) {
      if (!strcmp(name,"LWPR")) {
         int nIn = 0,nOut = 0;
         const char *model_name = NULL;
         LWPR_Kernel kern = LWPR_GAUSSIAN_KERNEL;
         at = atts;

         ud->curType = 0;
         while (at[0]!=NULL && at[1]!=NULL) {
            if (!strcmp(at[0],"name")) {
               model_name = at[1];
            } else if (!strcmp(at[0],"nIn")) {
               nIn = atoi(at[1]);
            } else if (!strcmp(at[0],"nOut")) {
               nOut = atoi(at[1]);
            } else if (!strcmp(at[0],"kernel")) {
               /* Both spellings of the bi-square kernel are accepted;
                  anything other than "Gaussian" falls back with a warning. */
               kern = LWPR_GAUSSIAN_KERNEL;
               if (!strcmp(at[1],"BiSquare")) {
                  kern = LWPR_BISQUARE_KERNEL;
               } else if (!strcmp(at[1],"Bisquare")) {
                  kern = LWPR_BISQUARE_KERNEL;
               } else {
                  if (strcmp(at[1],"Gaussian")) {
                     ud->numWarnings++;
                     if (ud->errFile) fprintf(ud->errFile,"Unknown kernel, using Gaussian.\n");
                  }
               }
            }
            at+=2;
         }
         if (nIn>0 && nOut > 0) {
            lwpr_init_model(model,nIn,nOut,model_name);
            model->kernel = kern;
         } else {
            ud->numErrors++;
            if (ud->errFile) fprintf(ud->errFile,"Error parsing LWPR element.\n");
         }
         /* The level advances even on error so that nesting stays balanced. */
         ud->level = 1;
      } else {
         lwpr_xml_report_unknown(ud,name);
      }
      return;
   }

   /* ---- Level 1: model-wide fields and <SubModel> openers. ---- */
   if (ud->level == 1) {
      if (!strcmp(name,"SubModel")) {
         int out_dim=-1;
         int numRFS=0;
         ud->curType = 0;
         at = atts;
         while (at[0]!=NULL && at[1]!=NULL) {
            if (!strcmp(at[0],"out_dim")) {
               out_dim = atoi(at[1]);
            } else if (!strcmp(at[0],"numRFS")) {
               numRFS = atoi(at[1]);
            }
            at+=2;
         }
         /* A negative count from a malformed file must not shrink (or
            turn negative) the allocation request below. */
         if (numRFS < 0) numRFS = 0;
         if (model->sub != NULL && out_dim >= 0 && out_dim < model->nOut) {
            lwpr_mem_alloc_sub(&(model->sub[out_dim]), numRFS + 16);
            ud->level = 2;
            ud->curSub = out_dim;
            ud->curRF = 0;
         } else {
            ud->numErrors++;
            if (ud->errFile) fprintf(ud->errFile,"Error parsing SubModel element.\n");
         }
         return;
      }
      switch(ud->curType) {
         case 0:
            lwpr_xml_report_unknown(ud,name);
            break;
         case 1:
            ud->N = 1;
            if (!strcmp(fieldName,"n_data")) {
               ud->curPtr = (void *) &(model->n_data);
            } else if (!strcmp(fieldName,"diag_only")) {
               ud->curPtr = (void *) &(model->diag_only);
            } else if (!strcmp(fieldName,"update_D")) {
               ud->curPtr = (void *) &(model->update_D);
            } else if (!strcmp(fieldName,"meta")) {
               ud->curPtr = (void *) &(model->meta);
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
            }
            break;
         case 2:
            ud->N = 1;
            if (!strcmp(fieldName,"meta_rate")) {
               ud->curPtr = (void *) &(model->meta_rate);
            } else if (!strcmp(fieldName,"penalty")) {
               ud->curPtr = (void *) &(model->penalty);
            } else if (!strcmp(fieldName,"w_gen")) {
               ud->curPtr = (void *) &(model->w_gen);
            } else if (!strcmp(fieldName,"w_prune")) {
               ud->curPtr = (void *) &(model->w_prune);
            } else if (!strcmp(fieldName,"init_lambda")) {
               ud->curPtr = (void *) &(model->init_lambda);
            } else if (!strcmp(fieldName,"final_lambda")) {
               ud->curPtr = (void *) &(model->final_lambda);
            } else if (!strcmp(fieldName,"tau_lambda")) {
               ud->curPtr = (void *) &(model->tau_lambda);
            } else if (!strcmp(fieldName,"init_S2")) {
               ud->curPtr = (void *) &(model->init_S2);
            } else if (!strcmp(fieldName,"add_threshold")) {
               ud->curPtr = (void *) &(model->add_threshold);
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
            }
            break;
         case 3:
            if (!strcmp(fieldName,"mean_x")) {
               ud->curPtr = (void *) model->mean_x;
               wishN = model->nIn;
            } else if (!strcmp(fieldName,"var_x")) {
               ud->curPtr = (void *) model->var_x;
               wishN = model->nIn;
            } else if (!strcmp(fieldName,"norm_in")) {
               ud->curPtr = (void *) model->norm_in;
               wishN = model->nIn;
            } else if (!strcmp(fieldName,"norm_out")) {
               ud->curPtr = (void *) model->norm_out;
               wishN = model->nOut;
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
               break;
            }
            /* The length is only published when it matches the model. */
            if (wishN != N) {
               lwpr_xml_dim_error(ud,fieldName,0,wishN);
            } else {
               ud->N = N;
            }
            break;
         case 4:
            wishN = wishM = model->nIn;
            if (!strcmp(fieldName,"init_alpha")) {
               ud->curPtr = (void *) model->init_alpha;
            } else if (!strcmp(fieldName,"init_D")) {
               ud->curPtr = (void *) model->init_D;
            } else if (!strcmp(fieldName,"init_M")) {
               ud->curPtr = (void *) model->init_M;
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
               break;
            }
            if (wishN != N || wishM != M) {
               lwpr_xml_dim_error(ud,fieldName,wishM,wishN);
            } else {
               ud->M = ud->N = model->nIn;
               ud->MS = model->nInStore;
            }
            break;
         default:
            break;
      }

      return;
   }

   /* ---- Level 2: sub-model fields and <ReceptiveField> openers. ---- */
   if (ud->level == 2) {
      if (!strcmp(name,"ReceptiveField")) {
         int nReg = 0;
         LWPR_ReceptiveField *added = NULL;

         at = atts;
         while (at[0]!=NULL && at[1]!=NULL) {
            if (!strcmp(at[0],"nReg")) {
               nReg = atoi(at[1]);
            }
            at+=2;
         }
         if (nReg > 0 && sub != NULL) {
            added = lwpr_aux_add_rf(sub,nReg);
         }
         if (added == NULL) {
            /* Covers a missing/invalid nReg as well as a failed
               lwpr_aux_add_rf(), which previously went unnoticed. */
            ud->numErrors++;
            if (ud->errFile) fprintf(ud->errFile,"Error parsing ReceptiveField element %d/%d.\n",ud->curSub,ud->curRF);
         }

         /* Entered unconditionally (as before) so that the child elements
            and the closing tag are consumed at the right nesting depth. */
         ud->level = 3;

         return;
      }

      ud->curPtr = NULL;
      if (ud->curType == 1 && !strcmp(fieldName,"n_pruned")) {
         ud->curPtr = (void *) &sub->n_pruned;
         ud->N = 1;
         return;
      }
      if (ud->curType == 0) {
         lwpr_xml_report_unknown(ud,name);
      } else {
         lwpr_xml_report_unknown(ud,fieldName);
      }
      return;
   }

   /* ---- Level 3: fields of the current receptive field. ---- */
   if (ud->level == 3) {
      ud->curPtr = NULL;

      if (RF == NULL && ud->curType != 0) {
         /* The enclosing <ReceptiveField> could not be created (already
            reported there); there is nothing to store the data in, so the
            element is skipped with no destination and a zero count. */
         return;
      }

      /* NOTE(review): ud->MS is only assigned in the level-1 matrix case of
         this function; confirm that value is still the right one for the
         receptive-field matrices handled below. */
      switch(ud->curType) {
         case 0:
            lwpr_xml_report_unknown(ud,name);
            break;
         case 1:
            if (!strcmp(fieldName,"trustworthy")) {
               ud->curPtr = (void *) &RF->trustworthy;
               ud->N = 1;
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
            }
            break;
         case 2:
            ud->N = 1;
            if (!strcmp(fieldName,"beta0")) {
               ud->curPtr = (void *) &RF->beta0;
            } else if (!strcmp(fieldName,"sum_e2")) {
               ud->curPtr = (void *) &RF->sum_e2;
            } else if (!strcmp(fieldName,"SSp")) {
               ud->curPtr = (void *) &RF->SSp;
            } else if (!strcmp(fieldName,"w")) {
               ud->curPtr = (void *) &RF->w;
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
            }
            break;
         case 3:
            if (!strcmp(fieldName,"beta")) {
               ud->curPtr = (void *) RF->beta;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"c")) {
               ud->curPtr = (void *) RF->c;
               wishN = model->nIn;
            } else if (!strcmp(fieldName,"SSs2")) {
               ud->curPtr = (void *) RF->SSs2;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"SSYres")) {
               ud->curPtr = (void *) RF->SSYres;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"H")) {
               ud->curPtr = (void *) RF->H;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"r")) {
               ud->curPtr = (void *) RF->r;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"sum_w")) {
               ud->curPtr = (void *) RF->sum_w;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"sum_e_cv2")) {
               ud->curPtr = (void *) RF->sum_e_cv2;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"n_data")) {
               ud->curPtr = (void *) RF->n_data;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"lambda")) {
               ud->curPtr = (void *) RF->lambda;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"mean_x")) {
               ud->curPtr = (void *) RF->mean_x;
               wishN = model->nIn;
            } else if (!strcmp(fieldName,"var_x")) {
               ud->curPtr = (void *) RF->var_x;
               wishN = model->nIn;
            } else if (!strcmp(fieldName,"s")) {
               ud->curPtr = (void *) RF->s;
               wishN = RF->nReg;
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
               return;
            }
            /* Same policy as for the model-level vectors: a length that
               does not fit the destination is reported and never published. */
            if (wishN != N) {
               lwpr_xml_dim_error(ud,fieldName,0,wishN);
               ud->curPtr = NULL;
            } else {
               ud->N = N;
            }
            break;
         case 4:
            if (!strcmp(fieldName,"D")) {
               ud->curPtr = (void *) RF->D;
               wishM = wishN = model->nIn;
            } else if (!strcmp(fieldName,"M")) {
               ud->curPtr = (void *) RF->M;
               wishM = wishN = model->nIn;
            } else if (!strcmp(fieldName,"alpha")) {
               ud->curPtr = (void *) RF->alpha;
               wishM = wishN = model->nIn;
            } else if (!strcmp(fieldName,"b")) {
               ud->curPtr = (void *) RF->b;
               wishM = wishN = model->nIn;
            } else if (!strcmp(fieldName,"h")) {
               ud->curPtr = (void *) RF->h;
               wishM = wishN = model->nIn;
            } else if (!strcmp(fieldName,"SXresYres")) {
               ud->curPtr = (void *) RF->SXresYres;
               wishM = model->nIn;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"SSXres")) {
               ud->curPtr = (void *) RF->SSXres;
               wishM = model->nIn;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"U")) {
               ud->curPtr = (void *) RF->U;
               wishM = model->nIn;
               wishN = RF->nReg;
            } else if (!strcmp(fieldName,"P")) {
               ud->curPtr = (void *) RF->P;
               wishM = model->nIn;
               wishN = RF->nReg;
            } else {
               lwpr_xml_report_unknown(ud,fieldName);
               return;
            }
            if (wishN != N || wishM != M) {
               lwpr_xml_dim_error(ud,fieldName,wishM,wishN);
               ud->curPtr = NULL;
            } else {
               ud->M = M;
               ud->N = N;
            }
            break;
         default:
            break;
      }
   }
}