Пример #1
0
// Creates a layer from a textual layer descriptor.
//
// The descriptor is validated with IsValidDescriptor() and then split into a
// layer type (ExtractLayerType()) and a configuration string
// (ExtractConfiguration()). The layer type is compared against the names
// registered below via CONV_LAYER_TYPE.
//
// NOTE(review): CONV_LAYER_TYPE is defined elsewhere. The empty `if` block
// below suggests each expansion continues an if/else chain, roughly
// `else if (layertype == name) layer = new Type(configuration);`. That would
// also explain why `configuration` is not referenced directly in this body.
// Confirm against the macro definition, including who owns the returned
// object.
//
// @param descriptor  Layer descriptor string.
// @return nullptr if the descriptor is invalid or its layer type is empty.
//         Otherwise, whatever `layer` the matching CONV_LAYER_TYPE entry
//         assigned (presumably still nullptr when no entry matches).
Layer* LayerFactory::ConstructLayer(std::string descriptor) {
  if (!IsValidDescriptor(descriptor))
    return nullptr;
  // Presumably consumed inside the CONV_LAYER_TYPE expansions below.
  std::string configuration = ExtractConfiguration(descriptor);
  std::string layertype = ExtractLayerType(descriptor);
  
  Layer* layer = nullptr;
  if(layertype.length() == 0) {
    // Leave layer a nullptr. This (intentionally empty) branch also opens the
    // if/else chain that the CONV_LAYER_TYPE lines below appear to extend,
    // so it must stay directly in front of them.
  }
  // Registered layer type names and the classes they map to.
  CONV_LAYER_TYPE("convolution", ConvolutionLayer)
  CONV_LAYER_TYPE("maxpooling", MaxPoolingLayer)
  CONV_LAYER_TYPE("amaxpooling", AdvancedMaxPoolingLayer)
  CONV_LAYER_TYPE("tanh", TanhLayer)
  CONV_LAYER_TYPE("sigm", SigmoidLayer)
  CONV_LAYER_TYPE("relu", ReLULayer)
  CONV_LAYER_TYPE("gradientaccumulation", GradientAccumulationLayer)
  CONV_LAYER_TYPE("resize", ResizeLayer)
  
  return layer;
}
Пример #2
0
// Returns `descriptor` with its seed parameter set to `seed`.
//
// - If the descriptor's configuration already contains a `seed=<digits>`
//   parameter, every `seed=<digits>` occurrence in the descriptor is
//   rewritten to `seed=<seed>`.
// - Otherwise a new descriptor is assembled as
//   `<layertype>(<configuration> seed=<seed>)`. The configuration and the
//   separating space are omitted when the configuration is empty.
//
// @param descriptor  Layer descriptor. It is returned unchanged when
//                    IsValidDescriptor() rejects it.
// @param seed        Seed value to inject.
// @return The (possibly) updated descriptor string.
std::string LayerFactory::InjectSeed(std::string descriptor, unsigned int seed) {
  if (!IsValidDescriptor(descriptor))
    return descriptor;

  // Select the regex implementation once instead of duplicating every call
  // site behind the preprocessor switch.
#ifdef BUILD_BOOST
  namespace rx = boost;
#else
  namespace rx = std;
#endif

  // Compiled once and reused. Building a regex is expensive and the pattern
  // never changes. Searching for this pattern is equivalent to the former
  // full-string match against ".*seed=[0-9]+.*", and the same object drives
  // the replacement.
  static const rx::regex seed_regex("seed=[0-9]+", rx::regex::extended);

  const std::string configuration = ExtractConfiguration(descriptor);

  std::ostringstream seed_ss;
  seed_ss << "seed=" << seed;
  const std::string seed_param = seed_ss.str();

  if (rx::regex_search(configuration, seed_regex)) {
    // A seed is already present: overwrite its value in place.
    return rx::regex_replace(descriptor, seed_regex, seed_param);
  }

  // No seed yet: rebuild the descriptor with the seed appended to the
  // existing configuration.
  std::string new_descriptor = ExtractLayerType(descriptor);
  new_descriptor += '(';
  if (!configuration.empty()) {
    new_descriptor += configuration;
    new_descriptor += ' ';
  }
  new_descriptor += seed_param;
  new_descriptor += ')';
  return new_descriptor;
}