/*
 * Prepare a newly created server-side client for I/O.
 *
 * For a plain socket (no TLS context, or a build without GnuTLS) the
 * client's socket event is registered straight away so the first
 * message can be read.
 *
 * For a TLS client a server-side TLS session is created, attached to the
 * client's socket, and the handshake is started. A negative handshake
 * result is fatal. A zero result means the handshake already finished, so
 * the peer is checked with virNetServerClientCheckAccess() before
 * proceeding. A positive result means more handshake data is needed.
 * Every non-failing path ends by registering the client's event.
 *
 * The client object is locked for the duration of the call.
 *
 * Returns 0 on success; on failure sets client->wantClose and returns -1.
 */
int virNetServerClientInit(virNetServerClientPtr client)
{
    virObjectLock(client);

#if WITH_GNUTLS
    if (client->tlsCtxt) {
        int handshake;

        client->tls = virNetTLSSessionNew(client->tlsCtxt, NULL);
        if (client->tls == NULL)
            goto error;

        virNetSocketSetTLSSession(client->sock, client->tls);

        /* Kick off the TLS handshake */
        handshake = virNetTLSSessionHandshake(client->tls);
        if (handshake < 0)
            goto error;

        /* Handshake already complete (unlikely): verify the peer's
         * certificate before we start reading messages. */
        if (handshake == 0 && virNetServerClientCheckAccess(client) < 0)
            goto error;
    }
#endif

    /* Whether plain, finished TLS, or TLS still handshaking, the next
     * step is to start watching the socket. */
    if (virNetServerClientRegisterEvent(client) < 0)
        goto error;

    virObjectUnlock(client);
    return 0;

 error:
    client->wantClose = true;
    virObjectUnlock(client);
    return -1;
}
/*
 * Advance an in-progress TLS handshake for @client.
 *
 * - Handshake error: the client is marked for closing.
 * - Handshake finished: the peer is checked with
 *   virNetServerClientCheckAccess(). If that fails, the client is marked
 *   for closing; otherwise its event is updated.
 * - Handshake still in progress: the event is updated, because the
 *   direction of handshake data flow may have changed.
 */
static void virNetServerClientDispatchHandshake(virNetServerClientPtr client)
{
    int rc = virNetTLSSessionHandshake(client->tls);

    if (rc < 0) {
        /* Fatal error in handshake */
        client->wantClose = true;
        return;
    }

    if (rc == 0 && virNetServerClientCheckAccess(client) < 0) {
        /* Handshake done, but the peer failed the access check */
        client->wantClose = true;
        return;
    }

    virNetServerClientUpdateEvent(client);
}
/* * This tests validation checking of peer certificates * * This is replicating the checks that are done for an * active TLS session after handshake completes. To * simulate that we create our TLS contexts, skipping * sanity checks. When then get a socketpair, and * initiate a TLS session across them. Finally do * do actual cert validation tests */ static int testTLSSessionInit(const void *opaque) { struct testTLSSessionData *data = (struct testTLSSessionData *)opaque; virNetTLSContextPtr clientCtxt = NULL; virNetTLSContextPtr serverCtxt = NULL; virNetTLSSessionPtr clientSess = NULL; virNetTLSSessionPtr serverSess = NULL; int ret = -1; int channel[2]; bool clientShake = false; bool serverShake = false; /* We'll use this for our fake client-server connection */ if (socketpair(AF_UNIX, SOCK_STREAM, 0, channel) < 0) abort(); /* * We have an evil loop to do the handshake in a single * thread, so we need these non-blocking to avoid deadlock * of ourselves */ ignore_value(virSetNonBlock(channel[0])); ignore_value(virSetNonBlock(channel[1])); /* We skip initial sanity checks here because we * want to make sure that problems are being * detected at the TLS session validation stage */ serverCtxt = virNetTLSContextNewServer(data->servercacrt, NULL, data->servercrt, KEYFILE, data->wildcards, NULL, false, true); clientCtxt = virNetTLSContextNewClient(data->clientcacrt, NULL, data->clientcrt, KEYFILE, NULL, false, true); if (!serverCtxt) { VIR_WARN("Unexpected failure loading %s against %s", data->servercacrt, data->servercrt); goto cleanup; } if (!clientCtxt) { VIR_WARN("Unexpected failure loading %s against %s", data->clientcacrt, data->clientcrt); goto cleanup; } /* Now the real part of the test, setup the sessions */ serverSess = virNetTLSSessionNew(serverCtxt, NULL); clientSess = virNetTLSSessionNew(clientCtxt, data->hostname); if (!serverSess) { VIR_WARN("Unexpected failure using %s against %s", data->servercacrt, data->servercrt); goto cleanup; } if (!clientSess) { 
VIR_WARN("Unexpected failure using %s against %s", data->clientcacrt, data->clientcrt); goto cleanup; } /* For handshake to work, we need to set the I/O callbacks * to read/write over the socketpair */ virNetTLSSessionSetIOCallbacks(serverSess, testWrite, testRead, &channel[0]); virNetTLSSessionSetIOCallbacks(clientSess, testWrite, testRead, &channel[1]); /* * Finally we loop around & around doing handshake on each * session until we get an error, or the handshake completes. * This relies on the socketpair being nonblocking to avoid * deadlocking ourselves upon handshake */ do { int rv; if (!serverShake) { rv = virNetTLSSessionHandshake(serverSess); if (rv < 0) goto cleanup; if (rv == VIR_NET_TLS_HANDSHAKE_COMPLETE) serverShake = true; } if (!clientShake) { rv = virNetTLSSessionHandshake(clientSess); if (rv < 0) goto cleanup; if (rv == VIR_NET_TLS_HANDSHAKE_COMPLETE) clientShake = true; } } while (!clientShake && !serverShake); /* Finally make sure the server validation does what * we were expecting */ if (virNetTLSContextCheckCertificate(serverCtxt, serverSess) < 0) { if (!data->expectServerFail) { VIR_WARN("Unexpected server cert check fail"); goto cleanup; } else { VIR_DEBUG("Got expected server cert fail"); } } else { if (data->expectServerFail) { VIR_WARN("Expected server cert check fail"); goto cleanup; } else { VIR_DEBUG("No unexpected server cert fail"); } } /* * And the same for the client validation check */ if (virNetTLSContextCheckCertificate(clientCtxt, clientSess) < 0) { if (!data->expectClientFail) { VIR_WARN("Unexpected client cert check fail"); goto cleanup; } else { VIR_DEBUG("Got expected client cert fail"); } } else { if (data->expectClientFail) { VIR_WARN("Expected client cert check fail"); goto cleanup; } else { VIR_DEBUG("No unexpected client cert fail"); } } ret = 0; cleanup: virObjectUnref(serverCtxt); virObjectUnref(clientCtxt); virObjectUnref(serverSess); virObjectUnref(clientSess); VIR_FORCE_CLOSE(channel[0]); 
VIR_FORCE_CLOSE(channel[1]); return ret; }
/*
 * Establish a TLS session on @client's socket using context @tls.
 *
 * The function:
 *  1. Creates a client-side TLS session for client->hostname and attaches
 *     it to the socket.
 *  2. Drives the handshake to completion, polling the socket for the
 *     direction the handshake is waiting on.
 *  3. Checks the server's certificate with virNetTLSContextCheckCertificate().
 *  4. Waits for, and reads, a single confirmation byte that must be '\1'.
 *
 * While blocked in poll(), SIGWINCH, SIGCHLD (where defined) and SIGPIPE
 * are temporarily blocked. The caller's original signal mask is put back
 * straight afterwards.
 *
 * Returns 0 on success. On failure it reports an error (for the read and
 * confirmation failures shown here), frees the TLS session, resets
 * client->tls to NULL and returns -1.
 */
int virNetClientSetTLSSession(virNetClientPtr client,
                              virNetTLSContextPtr tls)
{
    int ret;
    char buf[1];
    int len;
    struct pollfd fds[1];
    sigset_t oldmask, blockedsigs;

    sigemptyset(&blockedsigs);
#ifdef SIGWINCH
    sigaddset(&blockedsigs, SIGWINCH);
#endif
#ifdef SIGCHLD
    sigaddset(&blockedsigs, SIGCHLD);
#endif
    sigaddset(&blockedsigs, SIGPIPE);

    virNetClientLock(client);

    if (!(client->tls = virNetTLSSessionNew(tls,
                                            client->hostname)))
        goto error;

    virNetSocketSetTLSSession(client->sock,
                              client->tls);

    for (;;) {
        ret = virNetTLSSessionHandshake(client->tls);

        if (ret < 0)
            goto error;
        if (ret == 0)
            break;

        fds[0].fd = virNetSocketGetFD(client->sock);
        fds[0].revents = 0;
        if (virNetTLSSessionGetHandshakeStatus(client->tls) ==
            VIR_NET_TLS_HANDSHAKE_RECVING)
            fds[0].events = POLLIN;
        else
            fds[0].events = POLLOUT;

        /* Block SIGWINCH from interrupting poll in curses programs,
         * then restore the original signal mask again immediately
         * after the call (RHBZ#567931).  Same for SIGCHLD and SIGPIPE
         * at the suggestion of Paolo Bonzini and Daniel Berrange.
         */
        ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask));

    repoll:
        ret = poll(fds, ARRAY_CARDINALITY(fds), -1);
        /* Signals outside blockedsigs can still interrupt poll() */
        if (ret < 0 && (errno == EAGAIN || errno == EINTR))
            goto repoll;

        /* SIG_SETMASK (not SIG_BLOCK) actually restores the caller's
         * mask; SIG_BLOCK would leave blockedsigs blocked forever. */
        ignore_value(pthread_sigmask(SIG_SETMASK, &oldmask, NULL));
    }

    ret = virNetTLSContextCheckCertificate(tls, client->tls);

    if (ret < 0)
        goto error;

    /* At this point, the server is verifying _our_ certificate, IP address,
     * etc.  If we make the grade, it will send us a '\1' byte.
     */

    fds[0].fd = virNetSocketGetFD(client->sock);
    fds[0].revents = 0;
    fds[0].events = POLLIN;

    /* Block SIGWINCH from interrupting poll in curses programs */
    ignore_value(pthread_sigmask(SIG_BLOCK, &blockedsigs, &oldmask));

 repoll2:
    ret = poll(fds, ARRAY_CARDINALITY(fds), -1);
    /* Retry on EINTR too, so we don't attempt the read before the
     * confirmation byte has arrived. */
    if (ret < 0 && (errno == EAGAIN || errno == EINTR))
        goto repoll2;

    ignore_value(pthread_sigmask(SIG_SETMASK, &oldmask, NULL));

    len = virNetTLSSessionRead(client->tls, buf, 1);
    if (len < 0 && errno != ENOMSG) {
        virReportSystemError(errno, "%s",
                             _("Unable to read TLS confirmation"));
        goto error;
    }
    if (len != 1 || buf[0] != '\1') {
        virNetError(VIR_ERR_RPC, "%s",
                    _("server verification (of our certificate or IP "
                      "address) failed"));
        goto error;
    }

    virNetClientUnlock(client);
    return 0;

 error:
    virNetTLSSessionFree(client->tls);
    client->tls = NULL;
    virNetClientUnlock(client);
    return -1;
}