Пример #1
0
/******************************************************************************
 * NETCON_send
 * Sends len bytes from msg.  Plain connections hand the whole buffer to
 * sock_send(); secure connections send it through send_ssl_chunk() in pieces
 * no larger than connection->ssl_sizes.cbMaximumMessage.
 * The number of bytes sent is stored in *sent (-1 when sock_send() failed on
 * a plain connection).
 * Returns ERROR_SUCCESS, the WSAGetLastError() code of a failed plain send,
 * or ERROR_INTERNET_SECURITY_CHANNEL_ERROR when an SSL chunk cannot be sent.
 */
DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags,
		int *sent /* out */)
{
    const BYTE *cur;

    /* send is always blocking. */
    set_socket_blocking(connection, TRUE);

    if(connection->secure)
    {
        *sent = 0;
        for(cur = msg; len; )
        {
            size_t piece = min(len, connection->ssl_sizes.cbMaximumMessage);

            if(!send_ssl_chunk(connection, cur, piece))
                return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;

            /* Account for the piece only once it has been sent. */
            cur += piece;
            len -= piece;
            *sent += piece;
        }
        return ERROR_SUCCESS;
    }

    *sent = sock_send(connection->socket, msg, len, flags);
    if(*sent == -1)
        return WSAGetLastError();
    return ERROR_SUCCESS;
}
Пример #2
0
/* Probes the connection by peeking at one byte with the socket switched to
 * non-blocking mode.  A byte being available, or the peek reporting
 * WSAEWOULDBLOCK, counts as alive; any other result (a return of 0, or an
 * error other than WSAEWOULDBLOCK) counts as not alive. */
BOOL NETCON_is_alive(netconn_t *netconn)
{
    char probe;
    int got;

    set_socket_blocking(netconn, FALSE);
    got = sock_recv(netconn->socket, &probe, 1, MSG_PEEK);

    if(got == 1)
        return TRUE;
    if(got != -1)
        return FALSE;
    /* Nothing buffered yet, but the socket is still open. */
    return WSAGetLastError() == WSAEWOULDBLOCK;
}
Пример #3
0
/* Creates an AF_LOCAL stream socket, switches it to blocking mode and connects
 * it to the abstract-namespace local socket `name`.
 * Returns the connected descriptor, or -1 if the socket could not be created
 * or the connect failed (the descriptor is closed in that case). */
static inline int connect_server_socket(const char* name)
{
    int s = socket(AF_LOCAL, SOCK_STREAM, 0);
    if(s < 0)
    {
        /* No valid fd: do not configure, connect or close it. */
        APPL_TRACE_ERROR3("create local socket:%s failed, fd:%d, errno:%d", name, s, errno);
        return -1;
    }
    set_socket_blocking(s, TRUE);
    if(socket_local_client_connect(s, name, ANDROID_SOCKET_NAMESPACE_ABSTRACT, SOCK_STREAM) >= 0)
    {
        APPL_TRACE_DEBUG2("connected to local socket:%s, fd:%d", name, s);
        return s;
    }
    /* Log before close() so errno still reflects the connect failure. */
    APPL_TRACE_ERROR3("connect to local socket:%s, fd:%d failed, errno:%d", name, s, errno);
    close(s);
    return -1;
}
Пример #4
0
/******************************************************************************
 * NETCON_recv
 * Receives up to len bytes into buf.  Plain connections call sock_recv()
 * directly; secure connections first return buffered peek data and otherwise
 * decrypt SSL data via read_ssl_chunk().
 * blocking selects whether the receive may wait for data.
 * The number of bytes received is stored in *recvd (0 when len is 0; -1 when
 * sock_recv() failed on a plain connection).
 * Returns ERROR_SUCCESS or an error code (WSAGetLastError() for plain
 * connections, the read_ssl_chunk() result for secure ones).
 */
DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, BOOL blocking, int *recvd)
{
    *recvd = 0;
    if (!len)
        return ERROR_SUCCESS;

    if (!connection->secure)
    {
        set_socket_blocking(connection, blocking);
        *recvd = sock_recv(connection->socket, buf, len, 0);
        return *recvd == -1 ? WSAGetLastError() :  ERROR_SUCCESS;
    }
    else
    {
        SIZE_T size = 0;
        BOOL eof;
        DWORD res;

        /* Return already-decrypted leftover data before reading anything new. */
        if(connection->peek_msg) {
            size = min(len, connection->peek_len);
            memcpy(buf, connection->peek_msg, size);
            connection->peek_len -= size;
            connection->peek_msg += size;

            if(!connection->peek_len) {
                heap_free(connection->peek_msg_mem);
                connection->peek_msg_mem = connection->peek_msg = NULL;
            }

            *recvd = size;
            return ERROR_SUCCESS;
        }

        /* Keep reading until read_ssl_chunk() reports data or EOF. */
        do {
            res = read_ssl_chunk(connection, (BYTE*)buf, len, blocking, &size, &eof);
            if(res != ERROR_SUCCESS) {
                if(res == WSAEWOULDBLOCK) {
                    /* Having some data counts as success; otherwise report would-block. */
                    if(size)
                        res = ERROR_SUCCESS;
                }else {
                    WARN("read_ssl_chunk failed\n");
                }
                break;
            }
        }while(!size && !eof);

        /* SIZE_T is unsigned and pointer-sized: %ld mismatched it (notably on LLP64). */
        TRACE("received %zu bytes\n", size);
        *recvd = size;
        return res;
    }
}
Пример #5
0
/* Creates a TCP socket for server->addr and connects it, waiting at most
 * `timeout` milliseconds for a non-blocking connect to complete.
 * On success netconn->socket holds the connected socket (left in non-blocking
 * mode), TCP_NODELAY is requested, and ERROR_SUCCESS is returned.
 * On failure netconn->socket is -1 and ERROR_INTERNET_CANNOT_CONNECT is
 * returned. */
static DWORD create_netconn_socket(server_t *server, netconn_t *netconn, DWORD timeout)
{
    int result;
    ULONG flag;
    DWORD res;

    init_winsock();

    assert(server->addr_len);
    result = netconn->socket = socket(server->addr.ss_family, SOCK_STREAM, 0);
    if(result != -1) {
        set_socket_blocking(netconn, FALSE);
        result = connect(netconn->socket, (struct sockaddr*)&server->addr, server->addr_len);
        if(result == -1)
        {
            res = WSAGetLastError();
            if (res == WSAEINPROGRESS || res == WSAEWOULDBLOCK) {
                FD_SET set;
                int ready, sock_err;
                socklen_t len = sizeof(sock_err);
                TIMEVAL timeout_timeval;

                /* Split the millisecond timeout into seconds plus a sub-second
                 * remainder: tv_usec must stay below 1000000 (select() may fail
                 * with EINVAL otherwise), and timeout*1000 would overflow a
                 * DWORD for long timeouts. */
                timeout_timeval.tv_sec = timeout / 1000;
                timeout_timeval.tv_usec = (timeout % 1000) * 1000;

                FD_ZERO(&set);
                FD_SET(netconn->socket, &set);
                /* Writability signals that the pending connect has finished. */
                ready = select(netconn->socket+1, NULL, &set, NULL, &timeout_timeval);
                if(!ready || ready == SOCKET_ERROR) {
                    closesocket(netconn->socket);
                    netconn->socket = -1;
                    return ERROR_INTERNET_CANNOT_CONNECT;
                }
                /* The connect outcome is reported through SO_ERROR. */
                if (!getsockopt(netconn->socket, SOL_SOCKET, SO_ERROR, (void *)&sock_err, &len) && !sock_err)
                    result = 0;
            }
        }
        if(result == -1)
        {
            closesocket(netconn->socket);
            netconn->socket = -1;
        }
    }
    if(result == -1)
        return ERROR_INTERNET_CANNOT_CONNECT;

    flag = 1;
    result = setsockopt(netconn->socket, IPPROTO_TCP, TCP_NODELAY, (void*)&flag, sizeof(flag));
    if(result < 0)
        WARN("setsockopt(TCP_NODELAY) failed\n");

    return ERROR_SUCCESS;
}
Пример #6
0
/******************************************************************************
 * NETCON_recv
 * Receives up to len bytes into buf.  Plain connections call sock_recv()
 * directly; secure connections first use buffered peek data and then decrypt
 * SSL data via read_ssl_chunk().
 * mode: BLOCKING_ALLOW may wait, BLOCKING_DISALLOW does not
 * (WINE_MSG_DONTWAIT), BLOCKING_WAITALL waits for len bytes (MSG_WAITALL on
 * plain sockets; on secure ones a read loop that also stops on EOF or error).
 * The number of bytes received is stored in *recvd (0 when len is 0; -1 when
 * sock_recv() failed on a plain connection).
 * Returns ERROR_SUCCESS or an error code.
 */
DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t mode, int *recvd)
{
    *recvd = 0;
    if (!len)
        return ERROR_SUCCESS;

    if (!connection->secure)
    {
        int flags = 0;

        switch(mode) {
        case BLOCKING_ALLOW:
            break;
        case BLOCKING_DISALLOW:
            flags = WINE_MSG_DONTWAIT;
            break;
        case BLOCKING_WAITALL:
            flags = MSG_WAITALL;
            break;
        }

        set_socket_blocking(connection->socket, mode);
        *recvd = sock_recv(connection->socket, buf, len, flags);
        return *recvd == -1 ? sock_get_error(errno) :  ERROR_SUCCESS;
    }
    else
    {
        SIZE_T size = 0, cread;
        BOOL eof;
        DWORD res;

        if(connection->peek_msg) {
            size = min(len, connection->peek_len);
            memcpy(buf, connection->peek_msg, size);
            connection->peek_len -= size;
            connection->peek_msg += size;

            if(!connection->peek_len) {
                heap_free(connection->peek_msg_mem);
                connection->peek_msg_mem = connection->peek_msg = NULL;
            }
            /* check if we have enough data from the peek buffer */
            if(mode != BLOCKING_WAITALL || size == len) {
                *recvd = size;
                return ERROR_SUCCESS;
            }

            /* Peek data was returned: top it up only with what is available now. */
            mode = BLOCKING_DISALLOW;
        }

        do {
            res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, mode, &cread, &eof);
            if(res != ERROR_SUCCESS) {
                if(res == WSAEWOULDBLOCK) {
                    /* Having some data counts as success; otherwise report would-block. */
                    if(size)
                        res = ERROR_SUCCESS;
                }else {
                    WARN("read_ssl_chunk failed\n");
                }
                break;
            }

            if(eof) {
                TRACE("EOF\n");
                break;
            }

            size += cread;
        }while(!size || (mode == BLOCKING_WAITALL && size < len));

        /* SIZE_T is unsigned: print with %zu rather than the signed %zd. */
        TRACE("received %zu bytes\n", size);
        *recvd = size;
        return res;
    }
}
Пример #7
0
/******************************************************************************
 * read_ssl_chunk
 * Reads encrypted bytes from the socket into conn->ssl_buf, first prepending
 * any leftover bytes saved in conn->extra_buf, and decrypts them with
 * DecryptMessage().  Plaintext is copied into buf (at most buf_size bytes);
 * plaintext that does not fit is saved in conn->peek_msg.  Encrypted bytes
 * following the decrypted message are saved in conn->extra_buf.
 * *ret_size receives the number of plaintext bytes copied into buf (0 if none).
 * *eof is set when recv yields no data at all or the context has expired.
 * Returns ERROR_SUCCESS, WSAEWOULDBLOCK, ERROR_NOT_ENOUGH_MEMORY or
 * ERROR_INTERNET_CONNECTION_ABORTED.
 * NOTE(review): declared BOOL but returns DWORD error codes; the visible
 * caller stores the result in a DWORD.
 */
static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, blocking_mode_t mode, SIZE_T *ret_size, BOOL *eof)
{
    const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer;
    SecBuffer bufs[4];
    SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
    SSIZE_T size, buf_len = 0;
    blocking_mode_t tmp_mode;
    int i;
    SECURITY_STATUS res;

    assert(conn->extra_len < ssl_buf_size);

    /* Only plaintext actually copied to buf is reported; the count of raw
     * encrypted bytes must never leak out as a received size. */
    *ret_size = 0;

    /* BLOCKING_WAITALL is handled by caller */
    if(mode == BLOCKING_WAITALL)
        mode = BLOCKING_ALLOW;

    /* Start from encrypted bytes left over by the previous call. */
    if(conn->extra_len) {
        memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len);
        buf_len = conn->extra_len;
        conn->extra_len = 0;
        heap_free(conn->extra_buf);
        conn->extra_buf = NULL;
    }

    /* With leftover data already buffered, only top up without blocking. */
    tmp_mode = buf_len ? BLOCKING_DISALLOW : mode;
    set_socket_blocking(conn->socket, tmp_mode);
    size = sock_recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, tmp_mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT);
    if(size < 0) {
        if(!buf_len) {
            if(errno == EAGAIN || errno == EWOULDBLOCK) {
                TRACE("would block\n");
                return WSAEWOULDBLOCK;
            }
            WARN("recv failed\n");
            return ERROR_INTERNET_CONNECTION_ABORTED;
        }
        /* recv failed but leftover data exists: try to decrypt that. */
    }else {
        buf_len += size;
    }

    if(!buf_len) {
        *eof = TRUE;
        return ERROR_SUCCESS;
    }

    *eof = FALSE;

    do {
        memset(bufs, 0, sizeof(bufs));
        bufs[0].BufferType = SECBUFFER_DATA;
        bufs[0].cbBuffer = buf_len;
        bufs[0].pvBuffer = conn->ssl_buf;

        res = DecryptMessage(&conn->ssl_ctx, &buf_desc, 0, NULL);
        switch(res) {
        case SEC_E_OK:
            break;
        case SEC_I_CONTEXT_EXPIRED:
            TRACE("context expired\n");
            *eof = TRUE;
            return ERROR_SUCCESS;
        case SEC_E_INCOMPLETE_MESSAGE:
            assert(buf_len < ssl_buf_size);

            set_socket_blocking(conn->socket, mode);
            size = sock_recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT);
            if(size < 1) {
                if(size < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
                    TRACE("would block\n");

                    /* Save the partial message so the next call can resume. */
                    /* FIXME: Optimize extra_buf usage. */
                    conn->extra_buf = heap_alloc(buf_len);
                    if(!conn->extra_buf)
                        return ERROR_NOT_ENOUGH_MEMORY;

                    conn->extra_len = buf_len;
                    memcpy(conn->extra_buf, conn->ssl_buf, conn->extra_len);
                    return WSAEWOULDBLOCK;
                }

                return ERROR_INTERNET_CONNECTION_ABORTED;
            }

            buf_len += size;
            continue;
        default:
            WARN("failed: %08x\n", res);
            return ERROR_INTERNET_CONNECTION_ABORTED;
        }
    } while(res != SEC_E_OK);

    /* Deliver decrypted data; anything beyond buf_size goes to the peek buffer. */
    for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
        if(bufs[i].BufferType == SECBUFFER_DATA) {
            size = min(buf_size, bufs[i].cbBuffer);
            memcpy(buf, bufs[i].pvBuffer, size);
            if(size < bufs[i].cbBuffer) {
                assert(!conn->peek_len);
                conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size);
                if(!conn->peek_msg)
                    return ERROR_NOT_ENOUGH_MEMORY;
                conn->peek_len = bufs[i].cbBuffer-size;
                memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len);
            }

            *ret_size = size;
        }
    }

    /* Keep encrypted bytes past this message for the next call. */
    for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
        if(bufs[i].BufferType == SECBUFFER_EXTRA) {
            conn->extra_buf = heap_alloc(bufs[i].cbBuffer);
            if(!conn->extra_buf)
                return ERROR_NOT_ENOUGH_MEMORY;

            conn->extra_len = bufs[i].cbBuffer;
            memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len);
        }
    }

    return ERROR_SUCCESS;
}
Пример #8
0
// Convenience overload: puts the socket referred to by a SocketPtr into
// blocking mode by forwarding the result of its fd() to the descriptor-based
// set_socket_blocking overload.
inline void set_socket_blocking( SocketPtr const &sock_ptr )
{
    set_socket_blocking( sock_ptr->fd() );
}
Пример #9
0
// Convenience overload: puts a Socket into blocking mode by forwarding the
// result of its fd() to the descriptor-based set_socket_blocking overload.
inline void set_socket_blocking( Socket const &sock )
{
    set_socket_blocking( sock.fd() );
}
Пример #10
0
/******************************************************************************
 * netcon_secure_connect_setup
 * Runs the client side of the SSL/TLS handshake over connection->socket with
 * InitializeSecurityContextW, verifies the server certificate through
 * netconn_verify_cert() and allocates connection->ssl_buf.
 * compat_mode: use compat_cred_handle instead of cred_handle (fails when no
 * compat credentials are available).
 * On success the connection is marked secure, its security_flags are set from
 * the cipher strength, and ERROR_SUCCESS is returned.  Otherwise returns
 * ERROR_OUTOFMEMORY, the error from netconn_verify_cert(), or
 * ERROR_INTERNET_SECURITY_CHANNEL_ERROR.
 */
static DWORD netcon_secure_connect_setup(netconn_t *connection, BOOL compat_mode)
{
    SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
    SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
    SecHandle *cred = &cred_handle;
    BYTE *read_buf;
    SIZE_T read_buf_size = 2048;
    ULONG attrs = 0;
    CtxtHandle ctx;
    SSIZE_T size;
    int bits;
    const CERT_CONTEXT *cert;
    SECURITY_STATUS status;
    DWORD res = ERROR_SUCCESS;

    const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
        |ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;

    if(!ensure_cred_handle())
        return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;

    if(compat_mode) {
        if(!have_compat_cred_handle)
            return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
        cred = &compat_cred_handle;
    }

    read_buf = heap_alloc(read_buf_size);
    if(!read_buf)
        return ERROR_OUTOFMEMORY;

    /* First call has no input token; any token it returns is sent at the top of the loop. */
    status = InitializeSecurityContextW(cred, NULL, connection->server->name, isc_req_flags, 0, 0, NULL, 0,
            &ctx, &out_desc, &attrs, NULL);

    assert(status != SEC_E_OK);

    set_socket_blocking(connection, TRUE);

    while(status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE) {
        if(out_buf.cbBuffer) {
            assert(status == SEC_I_CONTINUE_NEEDED);

            TRACE("sending %u bytes\n", out_buf.cbBuffer);

            size = sock_send(connection->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
            if(size != out_buf.cbBuffer) {
                ERR("send failed\n");
                /* Report through res; status is a SECURITY_STATUS, not a DWORD error. */
                res = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
                break;
            }

            FreeContextBuffer(out_buf.pvBuffer);
            out_buf.pvBuffer = NULL;
            out_buf.cbBuffer = 0;
        }

        if(status == SEC_I_CONTINUE_NEEDED) {
            /* Move input not consumed by the last call (size in in_bufs[1]) to the front. */
            assert(in_bufs[1].cbBuffer < read_buf_size);

            memmove(read_buf, (BYTE*)in_bufs[0].pvBuffer+in_bufs[0].cbBuffer-in_bufs[1].cbBuffer, in_bufs[1].cbBuffer);
            in_bufs[0].cbBuffer = in_bufs[1].cbBuffer;

            in_bufs[1].BufferType = SECBUFFER_EMPTY;
            in_bufs[1].cbBuffer = 0;
            in_bufs[1].pvBuffer = NULL;
        }

        assert(in_bufs[0].BufferType == SECBUFFER_TOKEN);
        assert(in_bufs[1].BufferType == SECBUFFER_EMPTY);

        /* Keep at least 1024 bytes of free space for the next recv. */
        if(in_bufs[0].cbBuffer + 1024 > read_buf_size) {
            BYTE *new_read_buf;

            new_read_buf = heap_realloc(read_buf, read_buf_size + 1024);
            if(!new_read_buf) {
                res = ERROR_OUTOFMEMORY;
                break;
            }

            in_bufs[0].pvBuffer = read_buf = new_read_buf;
            read_buf_size += 1024;
        }

        size = sock_recv(connection->socket, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
        if(size < 1) {
            WARN("recv error\n");
            res = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
            break;
        }

        TRACE("recv %lu bytes\n", size);

        in_bufs[0].cbBuffer += size;
        in_bufs[0].pvBuffer = read_buf;
        status = InitializeSecurityContextW(cred, &ctx, connection->server->name,  isc_req_flags, 0, 0, &in_desc,
                0, NULL, &out_desc, &attrs, NULL);
        TRACE("InitializeSecurityContext ret %08x\n", status);

        if(status == SEC_E_OK) {
            /* Handshake done: ctx replaces (and deletes) any previous connection->ssl_ctx. */
            if(SecIsValidHandle(&connection->ssl_ctx))
                DeleteSecurityContext(&connection->ssl_ctx);
            connection->ssl_ctx = ctx;

            if(in_bufs[1].BufferType == SECBUFFER_EXTRA)
                FIXME("SECBUFFER_EXTRA not supported\n");

            status = QueryContextAttributesW(&ctx, SECPKG_ATTR_STREAM_SIZES, &connection->ssl_sizes);
            if(status != SEC_E_OK) {
                WARN("Could not get sizes\n");
                break;
            }

            status = QueryContextAttributesW(&ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
            if(status == SEC_E_OK) {
                res = netconn_verify_cert(connection, cert, cert->hCertStore);
                CertFreeCertificateContext(cert);
                if(res != ERROR_SUCCESS) {
                    WARN("cert verify failed: %u\n", res);
                    break;
                }
            }else {
                WARN("Could not get cert\n");
                break;
            }

            connection->ssl_buf = heap_alloc(connection->ssl_sizes.cbHeader + connection->ssl_sizes.cbMaximumMessage
                    + connection->ssl_sizes.cbTrailer);
            if(!connection->ssl_buf) {
                /* Do not rely on GetLastError(): HeapAlloc-style allocators do not set
                 * it on failure, and a 0 here would let the handshake "succeed" with a
                 * NULL ssl_buf. */
                res = ERROR_OUTOFMEMORY;
                break;
            }
        }
    }

    heap_free(read_buf);

    /* Release a schannel-allocated token (ISC_REQ_ALLOCATE_MEMORY) the loop did
     * not send, e.g. after a failed send or a failing InitializeSecurityContextW. */
    /* NOTE(review): a token returned together with SEC_E_OK is freed, not sent;
     * confirm the peer never needs it. */
    if(out_buf.pvBuffer)
        FreeContextBuffer(out_buf.pvBuffer);

    if(status != SEC_E_OK || res != ERROR_SUCCESS) {
        WARN("Failed to establish SSL connection: %08x (%u)\n", status, res);
        heap_free(connection->ssl_buf);
        connection->ssl_buf = NULL;
        return res ? res : ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
    }

    TRACE("established SSL connection\n");
    connection->secure = TRUE;
    connection->security_flags |= SECURITY_FLAG_SECURE;

    bits = NETCON_GetCipherStrength(connection);
    if (bits >= 128)
        connection->security_flags |= SECURITY_FLAG_STRENGTH_STRONG;
    else if (bits >= 56)
        connection->security_flags |= SECURITY_FLAG_STRENGTH_MEDIUM;
    else
        connection->security_flags |= SECURITY_FLAG_STRENGTH_WEAK;

    if(connection->mask_errors)
        connection->server->security_flags = connection->security_flags;
    return ERROR_SUCCESS;
}