/** * @return 0 in case of socket closed, -1 in case of other error, or * >0 the number of bytes read. */ int ReceiveTransaction(const ConnectionInfo *conn_info, char *buffer, int *more) { char proto[CF_INBAND_OFFSET + 1] = { 0 }; char status = 'x'; unsigned int len = 0; int ret; /* Get control channel. */ switch(ConnectionInfoProtocolVersion(conn_info)) { case CF_PROTOCOL_CLASSIC: ret = RecvSocketStream(ConnectionInfoSocket(conn_info), proto, CF_INBAND_OFFSET); break; case CF_PROTOCOL_TLS: ret = TLSRecv(ConnectionInfoSSL(conn_info), proto, CF_INBAND_OFFSET); break; default: UnexpectedError("ReceiveTransaction: ProtocolVersion %d!", ConnectionInfoProtocolVersion(conn_info)); ret = -1; } if (ret == -1 || ret == 0) return ret; LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction header: ", proto, ret); ret = sscanf(proto, "%c %u", &status, &len); if (ret != 2) { Log(LOG_LEVEL_ERR, "ReceiveTransaction: Bad packet -- bogus header: %s", proto); return -1; } if (len > CF_BUFSIZE - CF_INBAND_OFFSET) { Log(LOG_LEVEL_ERR, "ReceiveTransaction: Bad packet -- too long (len=%d)", len); return -1; } if (status != CF_MORE && status != CF_DONE) { Log(LOG_LEVEL_ERR, "ReceiveTransaction: Bad packet -- bogus header (more='%c')", status); return -1; } if (more != NULL) { switch (status) { case CF_MORE: *more = true; break; case CF_DONE: *more = false; break; default: ProgrammingError("Unreachable, " "bogus headers have already been checked!"); } } /* Get data. */ switch(ConnectionInfoProtocolVersion(conn_info)) { case CF_PROTOCOL_CLASSIC: ret = RecvSocketStream(ConnectionInfoSocket(conn_info), buffer, len); break; case CF_PROTOCOL_TLS: ret = TLSRecv(ConnectionInfoSSL(conn_info), buffer, len); break; default: UnexpectedError("ReceiveTransaction: ProtocolVersion %d!", ConnectionInfoProtocolVersion(conn_info)); ret = -1; } LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction data: ", buffer, ret); return ret; }
/*
 * This test checks the three basic operations:
 * - TLSSend
 * - TLSRecv
 * - TLSRecvLine
 * It is difficult to test each one separately, so we test all at once.
 * The test consists of establishing a connection to our child process and then
 * sending and receiving data. We switch between the original functions and the
 * mock functions.
 * We do not test SSL_new, SSL_accept and such because those will be covered by either
 * the client or server tests.
 *
 * NOTE(review): the USE_MOCK / USE_ORIGINAL / SSL_*_RETURN / RESET_STATUS
 * macros are defined outside this function; the comments below describe what
 * the assertions expect, not how the macros are implemented. The statement
 * order matters because each macro changes state seen by the next assertion.
 */
static void test_TLSBasicIO(void)
{
    ASSERT_IF_NOT_INITIALIZED;
    RESET_STATUS;
    SSL *ssl = NULL;
    char output_buffer[] = "this is a buffer";
    int output_buffer_length = strlen(output_buffer);
    char input_buffer[4096];
    int result = 0;
    /*
     * Open a socket and establish a tcp connection.
     */
    struct sockaddr_in server_addr;
    int server = 0;
    memset(&server_addr, 0, sizeof(struct sockaddr_in));
    server = socket(AF_INET, SOCK_STREAM, 0);
    assert_int_not_equal(-1, server);
    server_addr.sin_family = AF_INET;
    /* We should not use inet_addr, but it is easier for this particular case. */
    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
    /* Port 8035 on loopback: presumably where the child process listens - confirm
     * against the test setup code. */
    server_addr.sin_port = htons(8035);
    /*
     * Connect
     */
    result = connect(server, (struct sockaddr *)&server_addr, sizeof(struct sockaddr_in));
    assert_int_not_equal(-1, result);
    /*
     * Create a SSL instance
     */
    ssl = SSL_new(SSLCLIENTCONTEXT);
    assert_true(ssl != NULL);
    SSL_set_fd(ssl, server);
    /*
     * Establish the TLS connection over the socket.
     */
    result = SSL_connect(ssl);
    assert_int_not_equal(-1, result);
    /*
     * Start testing. The first obvious thing to test is to send data.
     */
    result = TLSSend(ssl, output_buffer, output_buffer_length);
    assert_int_equal(result, output_buffer_length);
    /*
     * Good, the data was sent. Let's check what we get back by using
     * TLSRecv. The peer is expected to echo the same bytes.
     */
    result = TLSRecv(ssl, input_buffer, output_buffer_length);
    assert_int_equal(output_buffer_length, result);
    /* Terminate explicitly before comparing as a C string. */
    input_buffer[output_buffer_length] = '\0';
    assert_string_equal(output_buffer, input_buffer);
    /*
     * Brilliant! We transmitted and received data using simple communication.
     * Let's try the line sending.
     */
    char output_line_buffer[] = "hello\ngoodbye\n";
    int output_line_buffer_length = strlen(output_line_buffer);
    char output_just_hello[] = "hello";
    int output_just_hello_length = strlen(output_just_hello);
    result = TLSSend(ssl, output_line_buffer, output_line_buffer_length);
    assert_int_equal(result, output_line_buffer_length);
    result = TLSRecvLine(ssl, input_buffer, output_line_buffer_length);
    /* The reply should be up to the first newline, i.e. just "hello". */
    assert_int_equal(result, output_just_hello_length);
    assert_string_equal(input_buffer, output_just_hello);
    /*
     * Basic check: with SSL_write and SSL_read mocked and no return value
     * configured, a zero-length send is expected to yield 0 and the other
     * calls -1.
     */
    USE_MOCK(SSL_write);
    USE_MOCK(SSL_read);
    assert_int_equal(0, TLSSend(ssl, output_buffer, 0));
    assert_int_equal(-1, TLSSend(ssl, output_buffer, output_buffer_length));
    assert_int_equal(-1, TLSRecv(ssl, input_buffer, output_buffer_length));
    RESET_STATUS;
    /*
     * Start replacing the functions inside to check that the logic works.
     * We start by testing TLSSend, then TLSRecv and at last TLSRecvLine.
     */
    /* TLSSend: expected to report 0 when the mocked SSL_write returns 0
     * (both before and after SSL_get_shutdown is mocked to return 1), and -1
     * when it returns -1. */
    USE_MOCK(SSL_write);
    SSL_WRITE_RETURN(0);
    assert_int_equal(0, TLSSend(ssl, output_buffer, output_buffer_length));
    USE_MOCK(SSL_get_shutdown);
    SSL_GET_SHUTDOWN_RETURN(1);
    assert_int_equal(0, TLSSend(ssl, output_buffer, output_buffer_length));
    SSL_WRITE_RETURN(-1);
    assert_int_equal(-1, TLSSend(ssl, output_buffer, output_buffer_length));
    /* TLSRecv: same pattern, driven by the mocked SSL_read return value. */
    USE_MOCK(SSL_read);
    SSL_READ_RETURN(0);
    SSL_GET_SHUTDOWN_RETURN(0);
    assert_int_equal(0, TLSRecv(ssl, input_buffer, output_buffer_length));
    SSL_GET_SHUTDOWN_RETURN(1);
    assert_int_equal(0, TLSRecv(ssl, input_buffer, output_buffer_length));
    SSL_READ_RETURN(-1);
    assert_int_equal(-1, TLSRecv(ssl, input_buffer, output_buffer_length));
    /* TLSRecvLine: SSL_write goes back to the original, SSL_read stays mocked.
     * Every case is expected to fail with -1 until the mock is handed a
     * buffer that actually contains a newline. */
    USE_ORIGINAL(SSL_write);
    SSL_READ_RETURN(0);
    assert_int_equal(-1, TLSRecvLine(ssl, input_buffer, output_buffer_length));
    SSL_READ_RETURN(-1);
    assert_int_equal(-1, TLSRecvLine(ssl, input_buffer, output_buffer_length));
    SSL_READ_RETURN(5);
    assert_int_equal(-1, TLSRecvLine(ssl, input_buffer, 10));
    SSL_READ_USE_BUFFER(output_line_buffer);
    assert_int_equal(5, TLSRecvLine(ssl, input_buffer, output_line_buffer_length));
    assert_string_equal(output_just_hello, input_buffer);
    /* Tear down the TLS session. The return value of SSL_shutdown is stored
     * but not checked.
     * NOTE(review): ssl is used by SSL_shutdown before the NULL check below,
     * and the socket 'server' is not closed in this function. */
    result = SSL_shutdown(ssl);
    if (ssl)
    {
        SSL_free(ssl);
    }
    RESET_STATUS;
}
/** * Receive a transaction packet of at most CF_BUFSIZE-1 bytes, and * NULL-terminate it. * * @return 0 in case of socket closed, -1 in case of other error, or * >0 the number of bytes read. */ int ReceiveTransaction(const ConnectionInfo *conn_info, char *buffer, int *more) { char proto[CF_INBAND_OFFSET + 1] = { 0 }; int ret; /* Get control channel. */ switch(conn_info->protocol) { case CF_PROTOCOL_CLASSIC: ret = RecvSocketStream(conn_info->sd, proto, CF_INBAND_OFFSET); break; case CF_PROTOCOL_TLS: ret = TLSRecv(conn_info->ssl, proto, CF_INBAND_OFFSET); break; default: UnexpectedError("ReceiveTransaction: ProtocolVersion %d!", conn_info->protocol); ret = -1; } if (ret == -1 || ret == 0) { return ret; } else if (ret != CF_INBAND_OFFSET) { Log(LOG_LEVEL_ERR, "ReceiveTransaction: bogus short header (%d bytes: '%s')", ret, proto); return -1; } LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction header: ", proto, ret); char status = 'x'; int len = 0; ret = sscanf(proto, "%c %d", &status, &len); if (ret != 2) { Log(LOG_LEVEL_ERR, "ReceiveTransaction: bogus header: %s", proto); return -1; } if (status != CF_MORE && status != CF_DONE) { Log(LOG_LEVEL_ERR, "ReceiveTransaction: bogus header (more='%c')", status); return -1; } if (len > CF_BUFSIZE - CF_INBAND_OFFSET) { Log(LOG_LEVEL_ERR, "ReceiveTransaction: packet too long (len=%d)", len); return -1; } else if (len <= 0) { /* Zero-length packets are disallowed, because * ReceiveTransaction() == 0 currently means connection closed. */ Log(LOG_LEVEL_ERR, "ReceiveTransaction: packet too short (len=%d)", len); return -1; } if (more != NULL) { switch (status) { case CF_MORE: *more = true; break; case CF_DONE: *more = false; break; default: ProgrammingError("Unreachable, " "bogus headers have already been checked!"); } } /* Get data. 
*/ switch(conn_info->protocol) { case CF_PROTOCOL_CLASSIC: ret = RecvSocketStream(conn_info->sd, buffer, len); break; case CF_PROTOCOL_TLS: ret = TLSRecv(conn_info->ssl, buffer, len); break; default: UnexpectedError("ReceiveTransaction: ProtocolVersion %d!", conn_info->protocol); ret = -1; } if (ret == -1 || ret == 0) { return ret; } else if (ret != len) { /* Should never happen given that we are using SSL_MODE_AUTO_RETRY and * that transaction payload < CF_BUFSIZE < TLS record size. */ Log(LOG_LEVEL_ERR, "Partial transaction read %d != %d bytes!", ret, len); return -1; } LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction data: ", buffer, ret); return ret; }
/**
 * Fetch a file from the server behind @p conn and write it to @p dest.
 *
 * The destination is unlinked and re-created exclusively with mode 0600
 * before a "GET <buf_size> <source>" transaction is sent. The reply is then
 * read in chunks of at most buf_size bytes until @p size bytes have arrived.
 *
 * @param source  path of the file on the server.
 * @param dest    local path to write to; any existing file is removed first.
 * @param size    number of bytes expected from the server.
 * @param encrypt request an encrypted transfer. Honoured only when the
 *                connection uses CF_PROTOCOL_CLASSIC, in which case the work
 *                is delegated to EncryptCopyRegularFileNet().
 * @param conn    established agent connection; conn->error is set when the
 *                local write fails.
 * @return true on success, false on any failure (or the result of
 *         EncryptCopyRegularFileNet() for encrypted transfers).
 */
int CopyRegularFileNet(const char *source, const char *dest, off_t size,
                       bool encrypt, AgentConnection *conn)
{
    char *buf, workbuf[CF_BUFSIZE], cfchangedstr[265];
    const int buf_size = 2048;

    /* We encrypt only for CLASSIC protocol. The TLS protocol is always over
     * encrypted layer, so it does not support encrypted (S*) commands. */
    encrypt = encrypt && conn->conn_info->protocol == CF_PROTOCOL_CLASSIC;

    if (encrypt)
    {
        return EncryptCopyRegularFileNet(source, dest, size, conn);
    }

    /* No cipher context is declared here: this path never encrypts, and the
     * earlier code called EVP_CIPHER_CTX_cleanup() on a context that had
     * never been initialised. */

    snprintf(cfchangedstr, 255, "%s%s", CF_CHANGEDSTR1, CF_CHANGEDSTR2);
    const size_t changed_len = strlen(cfchangedstr);

    if ((strlen(dest) > CF_BUFSIZE - 20))
    {
        Log(LOG_LEVEL_ERR, "Filename too long");
        return false;
    }

    unlink(dest);                /* To avoid link attacks */

    int dd = safe_open(dest, O_WRONLY | O_CREAT | O_TRUNC | O_EXCL | O_BINARY, 0600);
    if (dd == -1)
    {
        Log(LOG_LEVEL_ERR,
            "Copy from server '%s' to destination '%s' failed (open: %s)",
            conn->this_server, dest, GetErrorStr());
        unlink(dest);
        return false;
    }

    workbuf[0] = '\0';
    int tosend = snprintf(workbuf, CF_BUFSIZE, "GET %d %s", buf_size, source);
    if (tosend <= 0 || tosend >= CF_BUFSIZE)
    {
        Log(LOG_LEVEL_ERR, "Failed to compose GET command for file %s",
            source);
        close(dd);
        return false;
    }

    /* Send proposition C0 */
    if (SendTransaction(conn->conn_info, workbuf, tosend, CF_DONE) == -1)
    {
        Log(LOG_LEVEL_ERR, "Couldn't send GET command");
        close(dd);
        return false;
    }

    buf = xmalloc(CF_BUFSIZE + sizeof(int)); /* Note CF_BUFSIZE not buf_size !! */

    Log(LOG_LEVEL_VERBOSE,
        "Copying remote file '%s:%s', expecting %jd bytes",
        conn->this_server, source, (intmax_t)size);

    off_t n_read_total = 0;
    while (n_read_total < size)
    {
        int toget = MIN(size - n_read_total, buf_size);

        assert(toget != 0);

        /* Stage C1 - receive */
        int n_read;
        switch (conn->conn_info->protocol)
        {
        case CF_PROTOCOL_CLASSIC:
            n_read = RecvSocketStream(conn->conn_info->sd, buf, toget);
            break;
        case CF_PROTOCOL_TLS:
            n_read = TLSRecv(conn->conn_info->ssl, buf, toget);
            break;
        default:
            UnexpectedError("CopyRegularFileNet: ProtocolVersion %d!",
                            conn->conn_info->protocol);
            n_read = -1;
        }

        if (n_read <= 0)
        {
            /* This may happen on race conditions, where the file has shrunk
             * since we asked for its size in SYNCH ... STAT source */
            Log(LOG_LEVEL_ERR,
                "Error in client-server stream, has %s:%s shrunk? (code %d)",
                conn->this_server, source, n_read);
            close(dd);
            free(buf);
            return false;
        }

        /* If the first thing we get is an error message, break. */
        if ((n_read_total == 0) &&
            (strncmp(buf, CF_FAILEDSTR, strlen(CF_FAILEDSTR)) == 0))
        {
            Log(LOG_LEVEL_INFO, "Network access to '%s:%s' denied",
                conn->this_server, source);
            close(dd);
            free(buf);
            return false;
        }

        if (strncmp(buf, cfchangedstr, changed_len) == 0)
        {
            Log(LOG_LEVEL_INFO, "Source '%s:%s' changed while copying",
                conn->this_server, source);
            close(dd);
            free(buf);
            return false;
        }

        /* Check for mismatch between encryption here and on server:
         * the chunk looks like a transaction header followed by "BAD: ". */
        int value = -1;
        sscanf(buf, "t %d", &value);

        if ((value > 0) && (strncmp(buf + CF_INBAND_OFFSET, "BAD: ", 5) == 0))
        {
            Log(LOG_LEVEL_INFO, "Network access to cleartext '%s:%s' denied",
                conn->this_server, source);
            close(dd);
            free(buf);
            return false;
        }

        if (!FSWrite(dest, dd, buf, n_read))
        {
            Log(LOG_LEVEL_ERR,
                "Local disk write failed copying '%s:%s' to '%s'. (FSWrite: %s)",
                conn->this_server, source, dest, GetErrorStr());
            /* conn has been dereferenced throughout, so it cannot be NULL
             * here; no guard is needed. */
            conn->error = true;
            free(buf);
            unlink(dest);
            close(dd);
            /* Drain what the server is still going to send for this file. */
            FlushFileStream(conn->conn_info->sd, size - n_read_total);
            return false;
        }

        n_read_total += n_read;
    }

    /* If the file ends with a `hole', something needs to be written at
       the end.  Otherwise the kernel would truncate the file at the end
       of the last write operation. Write a null character and truncate
       it again.  */

    if (ftruncate(dd, n_read_total) < 0)
    {
        Log(LOG_LEVEL_ERR,
            "Copy failed (no space?) while copying '%s' from network '%s'",
            dest, GetErrorStr());
        free(buf);
        unlink(dest);
        close(dd);
        FlushFileStream(conn->conn_info->sd, size - n_read_total);
        return false;
    }

    close(dd);
    free(buf);
    return true;
}