Example #1
0
/**
 *  Receive one transaction packet: a header of CF_INBAND_OFFSET bytes of
 *  the form "<status> <length>", followed by <length> bytes of payload
 *  which are stored in @p buffer.
 *
 *  @param conn_info Connection to read from. Its protocol version selects
 *                   RecvSocketStream() (classic) or TLSRecv() (TLS).
 *  @param buffer    Destination for the payload. The largest accepted
 *                   payload is CF_BUFSIZE - CF_INBAND_OFFSET bytes.
 *                   NOTE(review): the receive helpers may append a
 *                   terminator -- confirm the required buffer size.
 *  @param more      If not NULL, set to true when the header status is
 *                   CF_MORE and to false when it is CF_DONE.
 *
 *  @return 0 in case of socket closed, -1 in case of other error, or
 *          >0 the number of bytes read.
 */
int ReceiveTransaction(const ConnectionInfo *conn_info, char *buffer, int *more)
{
    char proto[CF_INBAND_OFFSET + 1] = { 0 };
    char status = 'x';
    unsigned int len = 0;
    int ret;

    /* Get control channel. */
    switch(ConnectionInfoProtocolVersion(conn_info))
    {
    case CF_PROTOCOL_CLASSIC:
        ret = RecvSocketStream(ConnectionInfoSocket(conn_info), proto, CF_INBAND_OFFSET);
        break;
    case CF_PROTOCOL_TLS:
        ret = TLSRecv(ConnectionInfoSSL(conn_info), proto, CF_INBAND_OFFSET);
        break;
    default:
        UnexpectedError("ReceiveTransaction: ProtocolVersion %d!",
                        ConnectionInfoProtocolVersion(conn_info));
        ret = -1;
    }

    if (ret == -1 || ret == 0)
    {
        return ret;
    }
    else if (ret != CF_INBAND_OFFSET)
    {
        /* A truncated header cannot be parsed reliably; proto is
         * zero-initialised so it is safe to print as a string. */
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: Bad packet -- bogus short header (%d bytes: '%s')",
            ret, proto);
        return -1;
    }

    LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction header: ", proto, ret);

    ret = sscanf(proto, "%c %u", &status, &len);
    if (ret != 2)
    {
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: Bad packet -- bogus header: %s", proto);
        return -1;
    }
    if (len > CF_BUFSIZE - CF_INBAND_OFFSET)
    {
        /* len is unsigned, so it must be printed with %u. */
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: Bad packet -- too long (len=%u)", len);
        return -1;
    }
    if (len == 0)
    {
        /* Zero-length packets are disallowed, because a return value of 0
         * from this function means that the connection was closed. */
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: Bad packet -- too short (len=%u)", len);
        return -1;
    }
    if (status != CF_MORE && status != CF_DONE)
    {
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: Bad packet -- bogus header (more='%c')",
            status);
        return -1;
    }

    if (more != NULL)
    {
        switch (status)
        {
        case CF_MORE:
                *more = true;
                break;
        case CF_DONE:
                *more = false;
                break;
        default:
            ProgrammingError("Unreachable, "
                             "bogus headers have already been checked!");
        }
    }

    /* Get data. */
    switch(ConnectionInfoProtocolVersion(conn_info))
    {
    case CF_PROTOCOL_CLASSIC:
        ret = RecvSocketStream(ConnectionInfoSocket(conn_info), buffer, len);
        break;
    case CF_PROTOCOL_TLS:
        ret = TLSRecv(ConnectionInfoSSL(conn_info), buffer, len);
        break;
    default:
        UnexpectedError("ReceiveTransaction: ProtocolVersion %d!",
                        ConnectionInfoProtocolVersion(conn_info));
        ret = -1;
    }

    /* Bail out before logging: ret is used as a byte count by LogRaw(). */
    if (ret == -1 || ret == 0)
    {
        return ret;
    }
    else if (ret < 0 || (unsigned int) ret != len)
    {
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: Partial transaction read %d != %u bytes!",
            ret, len);
        return -1;
    }

    LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction data: ", buffer, ret);

    return ret;
}
Example #2
0
/*
 * This test checks for the three basic operations:
 * - TLSSend
 * - TLSRecv
 * - TLSRecvLine
 * It is difficult to test each one separately, so we test all at once.
 * The test consists of establishing a connection to our child process and then
 * sending and receiving data. We switch between the original functions and the
 * mock functions.
 * We do not test SSL_new, SSL_accept and such because those will be covered by either
 * the client or server tests.
 *
 * NOTE(review): the TCP socket opened here is never closed by this test;
 * consider closing it after SSL_free() once the required declaration is
 * confirmed to be in scope for this file.
 */
static void test_TLSBasicIO(void)
{
    ASSERT_IF_NOT_INITIALIZED;
    RESET_STATUS;
    SSL *ssl = NULL;
    char output_buffer[] = "this is a buffer";
    int output_buffer_length = strlen(output_buffer);
    char input_buffer[4096];
    int result = 0;

    /*
     * Open a socket and establish a tcp connection.
     */
    struct sockaddr_in server_addr;
    int server = 0;

    memset(&server_addr, 0, sizeof(struct sockaddr_in));
    server = socket(AF_INET, SOCK_STREAM, 0);
    assert_int_not_equal(-1, server);
    server_addr.sin_family = AF_INET;
    /* We should not use inet_addr, but it is easier for this particular case. */
    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
    server_addr.sin_port = htons(8035);

    /*
     * Connect
     */
    result = connect(server, (struct sockaddr *)&server_addr, sizeof(struct sockaddr_in));
    assert_int_not_equal(-1, result);
    /*
     * Create a SSL instance
     */
    ssl = SSL_new(SSLCLIENTCONTEXT);
    assert_true(ssl != NULL);
    /* SSL_set_fd() reports success as 1 and failure as 0. */
    result = SSL_set_fd(ssl, server);
    assert_int_equal(1, result);
    /*
     * Establish the TLS connection over the socket.
     * SSL_connect() returns 1 only when the handshake completed; 0 and
     * negative values are both failures, so "not -1" is not enough.
     */
    result = SSL_connect(ssl);
    assert_int_equal(1, result);
    /*
     * Start testing. The first obvious thing to test is to send data.
     */
    result = TLSSend(ssl, output_buffer, output_buffer_length);
    assert_int_equal(result, output_buffer_length);
    /*
     * Good we sent data and the data was sent. Let's check what we get back
     * by using TLSRecv.
     */
    result = TLSRecv(ssl, input_buffer, output_buffer_length);
    assert_int_equal(output_buffer_length, result);

    input_buffer[output_buffer_length] = '\0';
    assert_string_equal(output_buffer, input_buffer);
    /*
     * Brilliant! We transmitted and received data using simple communication.
     * Let's try the line sending.
     */
    char output_line_buffer[] = "hello\ngoodbye\n";
    int output_line_buffer_length = strlen(output_line_buffer);
    char output_just_hello[] = "hello";
    int output_just_hello_length = strlen(output_just_hello);

    result = TLSSend(ssl, output_line_buffer, output_line_buffer_length);
    assert_int_equal(result, output_line_buffer_length);

    result = TLSRecvLine(ssl, input_buffer, output_line_buffer_length);
    /* The reply should be up to the first hello */
    assert_int_equal(result, output_just_hello_length);
    assert_string_equal(input_buffer, output_just_hello);

    /*
     * Basic check: with the mocks enabled and no return value configured
     * yet, a zero-length send yields 0 and the other calls yield -1.
     */
    USE_MOCK(SSL_write);
    USE_MOCK(SSL_read);
    assert_int_equal(0, TLSSend(ssl, output_buffer, 0));
    assert_int_equal(-1, TLSSend(ssl, output_buffer, output_buffer_length));
    assert_int_equal(-1, TLSRecv(ssl, input_buffer, output_buffer_length));
    RESET_STATUS;

    /*
     * Start replacing the functions inside to check that the logic works
     * We start by testing TLSSend, then TLSRecv and at last TLSRecvLine.
     */
    USE_MOCK(SSL_write);
    SSL_WRITE_RETURN(0);
    assert_int_equal(0, TLSSend(ssl, output_buffer, output_buffer_length));
    USE_MOCK(SSL_get_shutdown);
    SSL_GET_SHUTDOWN_RETURN(1);
    assert_int_equal(0, TLSSend(ssl, output_buffer, output_buffer_length));
    SSL_WRITE_RETURN(-1);
    assert_int_equal(-1, TLSSend(ssl, output_buffer, output_buffer_length));

    USE_MOCK(SSL_read);
    SSL_READ_RETURN(0);
    SSL_GET_SHUTDOWN_RETURN(0);
    assert_int_equal(0, TLSRecv(ssl, input_buffer, output_buffer_length));
    SSL_GET_SHUTDOWN_RETURN(1);
    assert_int_equal(0, TLSRecv(ssl, input_buffer, output_buffer_length));
    SSL_READ_RETURN(-1);
    assert_int_equal(-1, TLSRecv(ssl, input_buffer, output_buffer_length));

    /* SSL_read stays mocked for the TLSRecvLine checks below. */
    USE_ORIGINAL(SSL_write);
    SSL_READ_RETURN(0);
    assert_int_equal(-1, TLSRecvLine(ssl, input_buffer, output_buffer_length));
    SSL_READ_RETURN(-1);
    assert_int_equal(-1, TLSRecvLine(ssl, input_buffer, output_buffer_length));
    SSL_READ_RETURN(5);
    assert_int_equal(-1, TLSRecvLine(ssl, input_buffer, 10));
    SSL_READ_USE_BUFFER(output_line_buffer);
    assert_int_equal(5, TLSRecvLine(ssl, input_buffer, output_line_buffer_length));
    assert_string_equal(output_just_hello, input_buffer);

    /* Best-effort shutdown; the result is not part of what is tested.
     * ssl was asserted non-NULL above, so it can be freed unconditionally. */
    (void) SSL_shutdown(ssl);
    SSL_free(ssl);
    RESET_STATUS;
}
Example #3
0
/**
 *  Read up to @p count bytes into @p dst using the transport selected by
 *  conn_info->protocol: RecvSocketStream() on conn_info->sd for the classic
 *  protocol, TLSRecv() on conn_info->ssl for TLS.
 *
 *  @return the value returned by the underlying receive call, or -1 if the
 *          protocol is neither of the two.
 */
static int RecvTransactionPart(const ConnectionInfo *conn_info,
                               char *dst, int count)
{
    if (conn_info->protocol == CF_PROTOCOL_CLASSIC)
    {
        return RecvSocketStream(conn_info->sd, dst, count);
    }
    if (conn_info->protocol == CF_PROTOCOL_TLS)
    {
        return TLSRecv(conn_info->ssl, dst, count);
    }

    UnexpectedError("ReceiveTransaction: ProtocolVersion %d!",
                    conn_info->protocol);
    return -1;
}

/**
 *  Receive a single transaction packet. The packet starts with a header of
 *  CF_INBAND_OFFSET bytes ("<status> <length>"); the payload that follows
 *  is stored in @p buffer and may be at most CF_BUFSIZE - CF_INBAND_OFFSET
 *  bytes long. Empty payloads are rejected.
 *
 *  NOTE(review): this function does not terminate @p buffer itself;
 *  presumably the receive helpers do -- confirm.
 *
 *  @param conn_info Connection to read from.
 *  @param buffer    Destination for the payload.
 *  @param more      If not NULL, receives true for a CF_MORE header and
 *                   false for a CF_DONE header.
 *
 *  @return 0 in case of socket closed, -1 in case of other error, or
 *          >0 the number of bytes read.
 */
int ReceiveTransaction(const ConnectionInfo *conn_info, char *buffer, int *more)
{
    char header[CF_INBAND_OFFSET + 1] = { 0 };

    /* Stage 1: the fixed-size control header. */
    const int header_bytes =
        RecvTransactionPart(conn_info, header, CF_INBAND_OFFSET);

    if (header_bytes == 0 || header_bytes == -1)
    {
        return header_bytes;
    }
    if (header_bytes != CF_INBAND_OFFSET)
    {
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: bogus short header (%d bytes: '%s')",
            header_bytes, header);
        return -1;
    }

    LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction header: ", header, header_bytes);

    /* Stage 2: parse and validate "<status> <length>". */
    char flag = 'x';
    int payload_len = 0;

    if (sscanf(header, "%c %d", &flag, &payload_len) != 2)
    {
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: bogus header: %s", header);
        return -1;
    }
    if (!(flag == CF_MORE || flag == CF_DONE))
    {
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: bogus header (more='%c')", flag);
        return -1;
    }
    if (payload_len > CF_BUFSIZE - CF_INBAND_OFFSET)
    {
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: packet too long (len=%d)", payload_len);
        return -1;
    }
    if (payload_len <= 0)
    {
        /* An empty packet would make this function return 0, which callers
         * interpret as "connection closed", so it is refused. */
        Log(LOG_LEVEL_ERR,
            "ReceiveTransaction: packet too short (len=%d)", payload_len);
        return -1;
    }

    if (more != NULL)
    {
        if (flag == CF_MORE)
        {
            *more = true;
        }
        else if (flag == CF_DONE)
        {
            *more = false;
        }
        else
        {
            ProgrammingError("Unreachable, "
                             "bogus headers have already been checked!");
        }
    }

    /* Stage 3: the payload itself. */
    const int payload_bytes =
        RecvTransactionPart(conn_info, buffer, payload_len);

    if (payload_bytes == 0 || payload_bytes == -1)
    {
        return payload_bytes;
    }
    if (payload_bytes != payload_len)
    {
        /* Not expected in practice: SSL_MODE_AUTO_RETRY is in use and a
         * transaction payload < CF_BUFSIZE < TLS record size. */
        Log(LOG_LEVEL_ERR,
            "Partial transaction read %d != %d bytes!",
            payload_bytes, payload_len);
        return -1;
    }

    LogRaw(LOG_LEVEL_DEBUG, "ReceiveTransaction data: ", buffer, payload_bytes);

    return payload_bytes;
}
Example #4
0
/**
 * Fetch a remote file by sending a "GET" request over @p conn and writing
 * the received byte stream to a freshly created local file.
 *
 * @param source  Remote path, sent verbatim in the GET command.
 * @param dest    Local path. Any existing entry is unlinked first and the
 *                file is then created exclusively with mode 0600.
 * @param size    Number of bytes expected from the server.
 * @param encrypt Request an encrypted transfer. Honoured only when the
 *                connection uses CF_PROTOCOL_CLASSIC, in which case the
 *                whole copy is delegated to EncryptCopyRegularFileNet().
 * @param conn    Established connection; must not be NULL. conn->error is
 *                set when writing to the local disk fails.
 *
 * @return true on success, false on failure (or whatever
 *         EncryptCopyRegularFileNet() returns when the copy is delegated).
 */
int CopyRegularFileNet(const char *source, const char *dest, off_t size,
                       bool encrypt, AgentConnection *conn)
{
    char *buf, workbuf[CF_BUFSIZE], cfchangedstr[265];
    const int buf_size = 2048;

    off_t n_read_total = 0;

    /* We encrypt only for CLASSIC protocol. The TLS protocol is always over
     * encrypted layer, so it does not support encrypted (S*) commands. */
    encrypt = encrypt && conn->conn_info->protocol == CF_PROTOCOL_CLASSIC;

    if (encrypt)
    {
        return EncryptCopyRegularFileNet(source, dest, size, conn);
    }

    snprintf(cfchangedstr, sizeof(cfchangedstr), "%s%s",
             CF_CHANGEDSTR1, CF_CHANGEDSTR2);
    /* Both marker lengths are loop-invariant; compute them once. */
    const size_t changed_len = strlen(cfchangedstr);
    const size_t failed_len = strlen(CF_FAILEDSTR);

    if ((strlen(dest) > CF_BUFSIZE - 20))
    {
        Log(LOG_LEVEL_ERR, "Filename too long");
        return false;
    }

    unlink(dest);                /* To avoid link attacks */

    int dd = safe_open(dest, O_WRONLY | O_CREAT | O_TRUNC | O_EXCL | O_BINARY, 0600);
    if (dd == -1)
    {
        Log(LOG_LEVEL_ERR,
            "Copy from server '%s' to destination '%s' failed (open: %s)",
            conn->this_server, dest, GetErrorStr());
        unlink(dest);
        return false;
    }

    workbuf[0] = '\0';
    int tosend = snprintf(workbuf, CF_BUFSIZE, "GET %d %s", buf_size, source);
    if (tosend <= 0 || tosend >= CF_BUFSIZE)
    {
        Log(LOG_LEVEL_ERR, "Failed to compose GET command for file %s",
            source);
        close(dd);
        return false;
    }

    /* Send proposition C0 */

    if (SendTransaction(conn->conn_info, workbuf, tosend, CF_DONE) == -1)
    {
        Log(LOG_LEVEL_ERR, "Couldn't send GET command");
        close(dd);
        return false;
    }

    buf = xmalloc(CF_BUFSIZE + sizeof(int));    /* Note CF_BUFSIZE not buf_size !! */

    Log(LOG_LEVEL_VERBOSE, "Copying remote file '%s:%s', expecting %jd bytes",
          conn->this_server, source, (intmax_t)size);

    n_read_total = 0;
    while (n_read_total < size)
    {
        int toget = MIN(size - n_read_total, buf_size);

        assert(toget != 0);

        /* Stage C1 - receive */
        int n_read;
        switch(conn->conn_info->protocol)
        {
        case CF_PROTOCOL_CLASSIC:
            n_read = RecvSocketStream(conn->conn_info->sd, buf, toget);
            break;
        case CF_PROTOCOL_TLS:
            n_read = TLSRecv(conn->conn_info->ssl, buf, toget);
            break;
        default:
            UnexpectedError("CopyRegularFileNet: ProtocolVersion %d!",
                            conn->conn_info->protocol);
            n_read = -1;
        }

        if (n_read <= 0)
        {
            /* This may happen on race conditions, where the file has shrunk
             * since we asked for its size in SYNCH ... STAT source */

            Log(LOG_LEVEL_ERR,
                "Error in client-server stream, has %s:%s shrunk? (code %d)",
                conn->this_server, source, n_read);
            close(dd);
            free(buf);
            return false;
        }

        /* If the first thing we get is an error message, break. */

        if ((n_read_total == 0) && (strncmp(buf, CF_FAILEDSTR, failed_len) == 0))
        {
            Log(LOG_LEVEL_INFO, "Network access to '%s:%s' denied",
                conn->this_server, source);
            close(dd);
            free(buf);
            return false;
        }

        if (strncmp(buf, cfchangedstr, changed_len) == 0)
        {
            Log(LOG_LEVEL_INFO, "Source '%s:%s' changed while copying",
                conn->this_server, source);
            close(dd);
            free(buf);
            return false;
        }

        /* Check for mismatch between encryption here and on server. */

        int value = -1;
        sscanf(buf, "t %d", &value);

        if ((value > 0) && (strncmp(buf + CF_INBAND_OFFSET, "BAD: ", 5) == 0))
        {
            Log(LOG_LEVEL_INFO, "Network access to cleartext '%s:%s' denied",
                conn->this_server, source);
            close(dd);
            free(buf);
            return false;
        }

        if (!FSWrite(dest, dd, buf, n_read))
        {
            /* Log first so that GetErrorStr() still reflects the failed
             * write and not one of the cleanup calls below. */
            Log(LOG_LEVEL_ERR,
                "Local disk write failed copying '%s:%s' to '%s'. (FSWrite: %s)",
                conn->this_server, source, dest, GetErrorStr());
            /* conn has been dereferenced since the top of the function, so
             * it cannot be NULL here. */
            conn->error = true;
            free(buf);
            unlink(dest);
            close(dd);
            /* Drain what the server is still going to send. */
            FlushFileStream(conn->conn_info->sd, size - n_read_total);
            return false;
        }

        n_read_total += n_read;
    }

    /* If the file ends with a `hole', the kernel would otherwise leave the
       file ending at the last write operation; set the final length
       explicitly to the number of bytes received. */

    if (ftruncate(dd, n_read_total) < 0)
    {
        Log(LOG_LEVEL_ERR,
            "Copy failed (no space?) while copying '%s:%s' to '%s'. (ftruncate: %s)",
            conn->this_server, source, dest, GetErrorStr());
        free(buf);
        unlink(dest);
        close(dd);
        FlushFileStream(conn->conn_info->sd, size - n_read_total);
        return false;
    }

    close(dd);
    free(buf);
    return true;
}