// Builds a layer instance from a textual layer descriptor.
//
// Returns nullptr when the descriptor does not pass IsValidDescriptor(), when
// the extracted layer type is empty, or when none of the CONV_LAYER_TYPE
// entries below assigns `layer`.
//
// NOTE(review): CONV_LAYER_TYPE is defined outside this block. It presumably
// expands to an `else if` branch that compares `layertype` against the given
// name and constructs the given class from `configuration` - confirm against
// the macro definition. Because of that, the locals `layer`, `layertype` and
// `configuration` must keep their names, and the leading `if` must stay
// directly in front of the macro chain.
Layer* LayerFactory::ConstructLayer(std::string descriptor) {
  // Guard clause: nothing can be constructed from an invalid descriptor.
  if (!IsValidDescriptor(descriptor)) {
    return nullptr;
  }

  std::string configuration = ExtractConfiguration(descriptor);
  std::string layertype = ExtractLayerType(descriptor);

  Layer* layer = nullptr;

  if (layertype.empty()) {
    // No layer type was extracted: keep layer == nullptr.
  }
  CONV_LAYER_TYPE("convolution", ConvolutionLayer)
  CONV_LAYER_TYPE("maxpooling", MaxPoolingLayer)
  CONV_LAYER_TYPE("amaxpooling", AdvancedMaxPoolingLayer)
  CONV_LAYER_TYPE("tanh", TanhLayer)
  CONV_LAYER_TYPE("sigm", SigmoidLayer)
  CONV_LAYER_TYPE("relu", ReLULayer)
  CONV_LAYER_TYPE("gradientaccumulation", GradientAccumulationLayer)
  CONV_LAYER_TYPE("resize", ResizeLayer)

  return layer;
}
std::string LayerFactory::InjectSeed(std::string descriptor, unsigned int seed) { if(IsValidDescriptor(descriptor)) { const std::string has_seed_regex = ".*seed=[0-9]+.*"; const std::string new_seed_regex = "seed=([0-9])+"; std::string configuration = ExtractConfiguration(descriptor); std::string layertype = ExtractLayerType(descriptor); std::stringstream seed_ss; seed_ss << "seed=" << seed; #ifdef BUILD_BOOST bool already_has_seed = boost::regex_match(configuration, boost::regex(has_seed_regex, boost::regex::extended)); #else bool already_has_seed = std::regex_match(configuration, std::regex(has_seed_regex, std::regex::extended)); #endif if(already_has_seed) { #ifdef BUILD_BOOST std::string new_descriptor = boost::regex_replace(descriptor, boost::regex(new_seed_regex, boost::regex::extended), seed_ss.str()); #else std::string new_descriptor = std::regex_replace(descriptor, std::regex(new_seed_regex, std::regex::extended), seed_ss.str()); #endif return new_descriptor; } else { std::stringstream new_descriptor_ss; new_descriptor_ss << layertype << "("; if(configuration.length() > 0) { new_descriptor_ss << configuration << " "; } new_descriptor_ss << seed_ss.str() << ")"; std::string new_descriptor = new_descriptor_ss.str(); return new_descriptor; } } else { return descriptor; } }