/**
 * Stores the configuration strings, reads the binary GraphDef found at
 * model_path into the graph_def member and creates a TensorFlow session
 * that holds that graph.
 *
 * TensorFlow reports failures through Status return values rather than
 * exceptions. Each Status is checked and logged here instead of being
 * discarded:
 *  - a failed graph read is logged; session creation still proceeds so that
 *    `session` is non-null in every case where it was before this check
 *    existed;
 *  - a failed NewSession leaves `session` as nullptr and returns early, since
 *    calling Create() through a null session pointer would be undefined
 *    behaviour;
 *  - a failed Create() is logged and the session is kept.
 *
 * @param operation              stored in this->operation
 * @param model_path             path of the binary GraphDef file to load
 * @param classifier_model_path  stored in this->classifier_model_path
 * @param labels_file_path       stored in this->labels_file_path
 */
FacenetClassifier::FacenetClassifier (string operation, string model_path, string classifier_model_path, string labels_file_path) {
	this->operation = operation;
	this->model_path = model_path;
	this->classifier_model_path = classifier_model_path;
	this->labels_file_path = labels_file_path;

	// Give the pointer a defined value before anything can fail.
	session = nullptr;

	tensorflow::Status status = ReadBinaryProto(tensorflow::Env::Default(), model_path.c_str(), &graph_def);
	if (!status.ok()) {
		LOG(ERROR) << "Could not read TensorFlow graph file " << model_path << ": " << status.ToString();
	}

	tensorflow::SessionOptions options;
	status = tensorflow::NewSession (options, &session);
	if (!status.ok() || session == nullptr) {
		LOG(ERROR) << "Could not create TensorFlow session: " << status.ToString();
		session = nullptr;
		return;
	}

	status = session->Create (graph_def);
	if (!status.ok()) {
		LOG(ERROR) << "Could not add graph " << model_path << " to the TensorFlow session: " << status.ToString();
	}
}
// Example no. 2
// 0
void TensorFlowEngine::load() {
	const auto networkFilename = getFilename();
	tensorflow::SessionOptions options;
	tensorflow::ConfigProto &config = options.config;
	tensorflow::GPUOptions* gpuOptions = config.mutable_gpu_options();
	gpuOptions->set_allow_growth(true); // Set this so that tensorflow will not use up all GPU memory
	//gpuOptions->set_per_process_gpu_memory_fraction(0.5);
	mSession.reset(tensorflow::NewSession(options));
	tensorflow::GraphDef tensorflow_graph;

	{
		reportInfo() << "Loading network file: " << networkFilename << reportEnd();
		tensorflow::Status s = ReadBinaryProto(tensorflow::Env::Default(), networkFilename, &tensorflow_graph);
		if (!s.ok()) {
			throw Exception("Could not read TensorFlow graph file " + networkFilename);
		}
	}

	bool nodesSpecified = true;
    int inputCounter = 0;
	if(mInputNodes.size() == 0) {
		nodesSpecified = false;
	}
    for(int i = 0; i < tensorflow_graph.node_size(); ++i) {
		tensorflow::NodeDef node = tensorflow_graph.node(i);
		if(mInputNodes.count(node.name()) > 0) {
		}
		if(node.op() == "Placeholder") {
			if(node.name().find("keras_learning_phase") != std::string::npos) {
				//mLearningPhaseTensors.insert(node.name());
				mLearningPhaseTensors.push_back(node.name());
			} else {
				// Input node found:
				// Get its shape
				// Input nodes use the Op Placeholder
				reportInfo() << "Found input node: " << i << " with name " << node.name() << reportEnd();
				auto shape = getShape(node);
				reportInfo() << "Node has shape " << shape.toString() << reportEnd();
				if(mInputNodes.count(node.name()) == 0) {
					if(nodesSpecified) {
						throw Exception("Encountered unknown node " + node.name());
					}
					reportInfo() << "Node was not specified by user" << reportEnd();
					// If node has not been specified by user, we need to add it
					// and thus know its type (fast image or tensor)
					// It is assumed to be an image if input shape has at least 4 dimensions
					NodeType type = NodeType::TENSOR;
					if(shape.getKnownDimensions() >= 2) {
						reportInfo() << "Assuming node is an image" << reportEnd();
						type = NodeType::IMAGE;
					} else {
						reportInfo() << "Assuming node is a tensor" << reportEnd();
					}
					addInputNode(inputCounter, tensorflow_graph.node(0).name(), type, shape);
					++inputCounter;
				}

				// Set its shape
				mInputNodes[node.name()].shape = shape;
			}
		}
	}

	reportInfo() << "Creating session." << reportEnd();
	tensorflow::Status s = mSession->Create(tensorflow_graph);
	if (!s.ok()) {
		throw Exception("Could not create TensorFlow Graph");
	}

	//tensorflow::graph::SetDefaultDevice("/gpu:0", &tensorflow_graph);

	// Clear the proto to save memory space.
	tensorflow_graph.Clear();
	reportInfo() << "TensorFlow graph loaded from: " << networkFilename << reportEnd();

	setIsLoaded(true);
}