Example #1
0
/*
 * LpcpFreeDataInfoMessage
 *     Finds the first message on the port's data-info chain whose MessageId
 *     and ClientId match the arguments, unlinks it and returns it to the port
 *     zone.
 *
 * For port types above LPCP_UNCONNECTED_PORT the chain searched is the one
 * on the port's ConnectionPort; if there is no connection port, nothing is
 * done. CallbackId is accepted but not compared.
 *
 * The message is freed with LPCP_LOCK_HELD, which presumably means the caller
 * must already own the LPC lock -- confirm against callers.
 */
VOID
NTAPI
LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
                        IN ULONG MessageId,
                        IN ULONG CallbackId,
                        IN CLIENT_ID ClientId)
{
    PLIST_ENTRY Head, Link;

    /* Connected port types keep their data-info chain on the connection port */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT)
    {
        Port = Port->ConnectionPort;
        if (Port == NULL) return;
    }

    Head = &Port->LpcDataInfoChainHead;
    for (Link = Head->Flink; Link != Head; Link = Link->Flink)
    {
        PLPCP_MESSAGE Candidate = CONTAINING_RECORD(Link, LPCP_MESSAGE, Entry);

        /* Skip anything that is not the requested message from this client */
        if ((Candidate->Request.MessageId != MessageId) ||
            (Candidate->Request.ClientId.UniqueThread != ClientId.UniqueThread) ||
            (Candidate->Request.ClientId.UniqueProcess != ClientId.UniqueProcess))
        {
            continue;
        }

        /* Found it: unlink, leave the entry self-linked, and free it */
        RemoveEntryList(&Candidate->Entry);
        InitializeListHead(&Candidate->Entry);
        LpcpFreeToPortZone(Candidate, LPCP_LOCK_HELD);
        return;
    }
}
Example #2
0
VOID
LpcExitThread(
    PETHREAD Thread
    )
{
    PLPCP_MESSAGE ReplyMsg;

    //
    // Detach an exiting thread from LPC: unlink it from any reply chain,
    // flag it as exiting, and release the reply message it still holds.
    // All of this happens under LpcpLock, and LpcReplyMessage is cleared
    // before the lock is dropped so no other thread processes it.
    //

    ExAcquireFastMutex( &LpcpLock );

    if (!IsListEmpty( &Thread->LpcReplyChain )) {
        RemoveEntryList( &Thread->LpcReplyChain );
        }

    Thread->LpcExitThreadCalled = TRUE;
    Thread->LpcReplyMessageId = 0;

    ReplyMsg = Thread->LpcReplyMessage;
    if (ReplyMsg == NULL) {
        //
        // Nothing left to clean up.
        //
        ExReleaseFastMutex( &LpcpLock );
        return;
        }

    Thread->LpcReplyMessage = NULL;

    //
    // Drop the reference on the thread this message was replied to, if any.
    //
    if (ReplyMsg->RepliedToThread != NULL) {
        ObDereferenceObject( ReplyMsg->RepliedToThread );
        ReplyMsg->RepliedToThread = NULL;
        }

    LpcpTrace(( "Cleanup Msg %lx (%d) for Thread %lx allocated\n", ReplyMsg, IsListEmpty( &ReplyMsg->Entry ), Thread ));

    //
    // TRUE presumably tells the zone allocator the LPC lock is already
    // held -- confirm.
    //
    LpcpFreeToPortZone( ReplyMsg, TRUE );

    ExReleaseFastMutex( &LpcpLock );
}
Example #3
0
/*
 * @implemented
 *
 * NtRequestWaitReplyPort
 *     Sends LpcRequest through the port referenced by PortHandle, blocks on
 *     the current thread's LpcReplySemaphore until a reply is delivered, and
 *     copies that reply into LpcReply.
 *
 * PortHandle - handle to an LPC port object (referenced with PreviousMode).
 * LpcRequest - request header followed by its data. A message type of
 *              LPC_REQUEST selects the (unimplemented) callback path; a
 *              message type of 0 gets LPC_REQUEST OR-ed into the caller's
 *              header; any other type fails with STATUS_INVALID_PARAMETER.
 * LpcReply   - receives the reply message.
 *
 * NOTE(review): in this listing the function stops after the
 * "Status == STATUS_SUCCESS" branch below; the remaining cleanup (e.g.
 * dereferencing Port/ConnectionPort and returning Status) is not visible.
 */
NTSTATUS
NTAPI
NtRequestWaitReplyPort(IN HANDLE PortHandle,
                       IN PPORT_MESSAGE LpcRequest,
                       IN OUT PPORT_MESSAGE LpcReply)
{
    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
    NTSTATUS Status;
    PLPCP_MESSAGE Message;
    PETHREAD Thread = PsGetCurrentThread();
    BOOLEAN Callback;
    PKSEMAPHORE Semaphore;
    ULONG MessageType;
    PAGED_CODE();
    LPCTRACE(LPC_SEND_DEBUG,
             "Handle: %lx. Messages: %p/%p. Type: %lx\n",
             PortHandle,
             LpcRequest,
             LpcReply,
             LpcpGetMessageType(LpcRequest));

    /* Check if the thread is dying */
    if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;

    /*
     * NOTE(review): LpcRequest is read here and in the length checks below
     * (and written for the type-0 case) directly, outside any SEH block and
     * without ProbeForRead, although PreviousMode may be UserMode. An
     * invalid user pointer, or a header changed by another thread after it
     * was validated, is not handled -- consider capturing the header first.
     */

    /* Check if this is an LPC Request */
    if (LpcpGetMessageType(LpcRequest) == LPC_REQUEST)
    {
        /* Then it's a callback */
        Callback = TRUE;
    }
    else if (LpcpGetMessageType(LpcRequest))
    {
        /* This is a not kernel-mode message */
        return STATUS_INVALID_PARAMETER;
    }
    else
    {
        /* This is a kernel-mode message without a callback */
        /* Note: this modifies the caller's own header */
        LpcRequest->u2.s2.Type |= LPC_REQUEST;
        Callback = FALSE;
    }

    /* Get the message type */
    MessageType = LpcRequest->u2.s2.Type;

    /* Validate the length: header plus data must fit in TotalLength */
    if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
            (ULONG)LpcRequest->u1.s1.TotalLength)
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Reference the object */
    Status = ObReferenceObjectByHandle(PortHandle,
                                       0,
                                       LpcPortObjectType,
                                       PreviousMode,
                                       (PVOID*)&Port,
                                       NULL);
    if (!NT_SUCCESS(Status)) return Status;

    /* Validate the message length against the port's maximum */
    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
            ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
    {
        /* Fail */
        ObDereferenceObject(Port);
        return STATUS_PORT_MESSAGE_TOO_LONG;
    }

    /* Allocate a message from the port zone */
    Message = LpcpAllocateFromPortZone();
    if (!Message)
    {
        /* Fail if we couldn't allocate a message */
        ObDereferenceObject(Port);
        return STATUS_NO_MEMORY;
    }

    /* Check if this is a callback */
    if (Callback)
    {
        /*
         * NOTE(review): unimplemented. Semaphore stays NULL and is still
         * passed to LpcpCompleteWait below, KeLeaveCriticalRegion runs
         * without a matching KeEnterCriticalRegion, and the Message
         * allocated above is neither queued nor freed. Only the ASSERT
         * guards this, in checked builds.
         */
        /* FIXME: TODO */
        Semaphore = NULL; // we'd use the Thread Semaphore here
        ASSERT(FALSE);
    }
    else
    {
        /* No callback, just copy the message */
        _SEH2_TRY
        {
            /* Copy it */
            LpcpMoveMessage(&Message->Request,
            LpcRequest,
            LpcRequest + 1,
            MessageType,
            &Thread->Cid);
        }
        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
        {
            /* Fail */
            LpcpFreeToPortZone(Message, 0);
            ObDereferenceObject(Port);
            _SEH2_YIELD(return _SEH2_GetExceptionCode());
        }
        _SEH2_END;

        /*
         * Acquire the LPC lock. The failure paths below return right after
         * LpcpFreeToPortZone(Message, 3) without releasing it, so flag 3
         * presumably also releases the lock -- confirm.
         */
        KeAcquireGuardedMutex(&LpcpLock);

        /* Right now clear the port context */
        Message->PortContext = NULL;

        /* Check if this is a not connection port */
        if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
        {
            /* We want the connected port */
            QueuePort = Port->ConnectedPort;
            if (!QueuePort)
            {
                /* We have no connected port, fail */
                LpcpFreeToPortZone(Message, 3);
                ObDereferenceObject(Port);
                return STATUS_PORT_DISCONNECTED;
            }

            /* This will be the rundown port */
            ReplyPort = QueuePort;

            /* Check if this is a communication port */
            if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
            {
                /* Copy the port context and use the connection port */
                Message->PortContext = QueuePort->PortContext;
                ConnectionPort = QueuePort = Port->ConnectionPort;
                if (!ConnectionPort)
                {
                    /* Fail */
                    LpcpFreeToPortZone(Message, 3);
                    ObDereferenceObject(Port);
                    return STATUS_PORT_DISCONNECTED;
                }
            }
            else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
                     LPCP_COMMUNICATION_PORT)
            {
                /* Use the connection port for anything but communication ports */
                ConnectionPort = QueuePort = Port->ConnectionPort;
                if (!ConnectionPort)
                {
                    /* Fail */
                    LpcpFreeToPortZone(Message, 3);
                    ObDereferenceObject(Port);
                    return STATUS_PORT_DISCONNECTED;
                }
            }

            /* Reference the connection port if it exists */
            if (ConnectionPort) ObReferenceObject(ConnectionPort);
        }
        else
        {
            /* Otherwise, for a connection port, use the same port object */
            QueuePort = ReplyPort = Port;
        }

        /* No reply thread */
        Message->RepliedToThread = NULL;
        Message->SenderPort = Port;

        /* Generate the Message ID and set it (0 is skipped on wrap-around) */
        Message->Request.MessageId =  LpcpNextMessageId++;
        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
        Message->Request.CallbackId = 0;

        /* Set the message ID for our thread now */
        Thread->LpcReplyMessageId = Message->Request.MessageId;
        Thread->LpcReplyMessage = NULL;

        /* Queue the message on the receiver and this thread on the reply chain */
        InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
        InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
        LpcpSetPortToThread(Thread, Port);

        /* Release the lock and get the semaphore we'll use later */
        KeEnterCriticalRegion();
        KeReleaseGuardedMutex(&LpcpLock);
        Semaphore = QueuePort->MsgQueue.Semaphore;

        /* If this is a waitable port, wake it up */
        if (QueuePort->Flags & LPCP_WAITABLE_PORT)
        {
            /* Wake it */
            KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
        }
    }

    /* Now release the semaphore */
    LpcpCompleteWait(Semaphore);
    KeLeaveCriticalRegion();

    /*
     * And let's wait for the reply.
     * NOTE(review): Status is tested below; presumably LpcpReplyWait (a
     * macro) stores the wait result in Status -- confirm, otherwise Status
     * still holds the ObReferenceObjectByHandle result.
     */
    LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);

    /* Acquire the LPC lock */
    KeAcquireGuardedMutex(&LpcpLock);

    /* Get the LPC Message and clear our thread's reply data */
    Message = LpcpGetMessageFromThread(Thread);
    Thread->LpcReplyMessage = NULL;
    Thread->LpcReplyMessageId = 0;

    /* Check if we have anything on the reply chain*/
    if (!IsListEmpty(&Thread->LpcReplyChain))
    {
        /* Remove this thread and reinitialize the list */
        RemoveEntryList(&Thread->LpcReplyChain);
        InitializeListHead(&Thread->LpcReplyChain);
    }

    /* Release the lock */
    KeReleaseGuardedMutex(&LpcpLock);

    /* Check if we got a reply */
    if (Status == STATUS_SUCCESS)
    {
        /* Check if we have a valid message */
        if (Message)
        {
            LPCTRACE(LPC_SEND_DEBUG,
                     "Reply Messages: %p/%p\n",
                     &Message->Request,
                     (&Message->Request) + 1);

            /* Move the message into the caller's reply buffer, under SEH */
            _SEH2_TRY
            {
                LpcpMoveMessage(LpcReply,
                &Message->Request,
                (&Message->Request) + 1,
                0,
                NULL);
            }
            _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
            {
                Status = _SEH2_GetExceptionCode();
            }
            _SEH2_END;

            /* Request messages carrying data-info are kept rather than freed */
            if ((LpcpGetMessageType(&Message->Request) == LPC_REQUEST) &&
                    (Message->Request.u2.s2.DataInfoOffset))
            {
                /* Save the data information */
                LpcpSaveDataInfoMessage(Port, Message, 0);
            }
            else
            {
                /* Otherwise, just free it */
                LpcpFreeToPortZone(Message, 0);
            }
        }
        else
        {
            /* We don't have a reply */
            Status = STATUS_LPC_REPLY_LOST;
        }
    }
Example #4
0
/*
 * @implemented
 *
 * NtRequestPort
 *     Sends a datagram (no reply is waited for) through the port referenced
 *     by PortHandle.
 *
 * PortHandle - handle to an LPC port object (referenced with PreviousMode).
 * LpcRequest - message header immediately followed by its data. When the
 *              previous mode is not KernelMode this is an untrusted caller
 *              pointer: it is probed, and the header is captured exactly once
 *              so that validation and copy see the same values.
 *
 * Returns STATUS_SUCCESS once the message is queued, the exception code if
 * the caller's buffer is inaccessible, STATUS_INVALID_PARAMETER for a
 * malformed header, STATUS_PORT_MESSAGE_TOO_LONG, STATUS_NO_MEMORY or
 * STATUS_PORT_DISCONNECTED.
 */
NTSTATUS
NTAPI
NtRequestPort(IN HANDLE PortHandle,
              IN PPORT_MESSAGE LpcRequest)
{
    PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
    NTSTATUS Status;
    PLPCP_MESSAGE Message;
    PETHREAD Thread = PsGetCurrentThread();
    PORT_MESSAGE CapturedLpcRequest;
    PKSEMAPHORE Semaphore;
    ULONG MessageType;
    PAGED_CODE();

    /*
     * Capture the header once. Dereferencing LpcRequest directly would fault
     * on an invalid user pointer outside any exception handler, and would let
     * another user thread change DataLength/TotalLength between validation
     * and the copy. Every check below uses the captured copy.
     */
    if (PreviousMode != KernelMode)
    {
        _SEH2_TRY
        {
            /* The header must be readable user memory */
            ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
            CapturedLpcRequest = *LpcRequest;
        }
        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
        {
            /* Bad user buffer, fail with the exception code */
            _SEH2_YIELD(return _SEH2_GetExceptionCode());
        }
        _SEH2_END;
    }
    else
    {
        /* Kernel-mode callers pass kernel buffers, which are not probed */
        CapturedLpcRequest = *LpcRequest;
    }

    LPCTRACE(LPC_SEND_DEBUG,
             "Handle: %lx. Message: %p. Type: %lx\n",
             PortHandle,
             LpcRequest,
             LpcpGetMessageType(&CapturedLpcRequest));

    /* Whatever type the caller set, this is sent as a datagram */
    MessageType = CapturedLpcRequest.u2.s2.Type | LPC_DATAGRAM;

    /* Can't have data information on this type of call */
    if (CapturedLpcRequest.u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;

    /* The header plus the data must fit inside the total length */
    if (((ULONG)CapturedLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
        (ULONG)CapturedLpcRequest.u1.s1.TotalLength)
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Reference the object */
    Status = ObReferenceObjectByHandle(PortHandle,
                                       0,
                                       LpcPortObjectType,
                                       PreviousMode,
                                       (PVOID*)&Port,
                                       NULL);
    if (!NT_SUCCESS(Status)) return Status;

    /* The message must not exceed what this port accepts */
    if (((ULONG)CapturedLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
        ((ULONG)CapturedLpcRequest.u1.s1.TotalLength <= (ULONG)CapturedLpcRequest.u1.s1.DataLength))
    {
        /* Fail */
        ObDereferenceObject(Port);
        return STATUS_PORT_MESSAGE_TOO_LONG;
    }

    /* Allocate a message from the port zone */
    Message = LpcpAllocateFromPortZone();
    if (!Message)
    {
        /* Fail if we couldn't allocate a message */
        ObDereferenceObject(Port);
        return STATUS_NO_MEMORY;
    }

    /*
     * Copy the message. The header comes from the captured copy; only the
     * data is still read from the caller's buffer, and only under SEH.
     * NOTE(review): this assumes LpcpMoveMessage sizes the data copy from
     * the Origin header it is given -- confirm.
     */
    _SEH2_TRY
    {
        if (PreviousMode != KernelMode)
        {
            /* The data following the header must be user memory as well */
            ProbeForRead(LpcRequest,
                         CapturedLpcRequest.u1.s1.TotalLength,
                         sizeof(ULONG));
        }

        LpcpMoveMessage(&Message->Request,
                        &CapturedLpcRequest,
                        LpcRequest + 1,
                        MessageType,
                        &Thread->Cid);
    }
    _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
    {
        /* Fail */
        LpcpFreeToPortZone(Message, 0);
        ObDereferenceObject(Port);
        _SEH2_YIELD(return _SEH2_GetExceptionCode());
    }
    _SEH2_END;

    /*
     * Acquire the LPC lock. The failure paths below return right after
     * LpcpFreeToPortZone(Message, 3) without releasing it, so flag 3
     * presumably also releases the lock -- confirm.
     */
    KeAcquireGuardedMutex(&LpcpLock);

    /* Right now clear the port context */
    Message->PortContext = NULL;

    /* Check if this is a not connection port */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
    {
        /* We want the connected port */
        QueuePort = Port->ConnectedPort;
        if (!QueuePort)
        {
            /* We have no connected port, fail */
            LpcpFreeToPortZone(Message, 3);
            ObDereferenceObject(Port);
            return STATUS_PORT_DISCONNECTED;
        }

        /* Check if this is a communication port */
        if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
        {
            /* Copy the port context and use the connection port */
            Message->PortContext = QueuePort->PortContext;
            ConnectionPort = QueuePort = Port->ConnectionPort;
            if (!ConnectionPort)
            {
                /* Fail */
                LpcpFreeToPortZone(Message, 3);
                ObDereferenceObject(Port);
                return STATUS_PORT_DISCONNECTED;
            }
        }
        else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
                 LPCP_COMMUNICATION_PORT)
        {
            /* Use the connection port for anything but communication ports */
            ConnectionPort = QueuePort = Port->ConnectionPort;
            if (!ConnectionPort)
            {
                /* Fail */
                LpcpFreeToPortZone(Message, 3);
                ObDereferenceObject(Port);
                return STATUS_PORT_DISCONNECTED;
            }
        }

        /* Reference the connection port if it exists */
        if (ConnectionPort) ObReferenceObject(ConnectionPort);
    }
    else
    {
        /* Otherwise, for a connection port, use the same port object */
        QueuePort = Port;
    }

    /* Reference QueuePort if we have it */
    if (QueuePort && ObReferenceObjectSafe(QueuePort))
    {
        /* Set sender's port */
        Message->SenderPort = Port;

        /* Generate the Message ID and set it (0 is skipped on wrap-around) */
        Message->Request.MessageId = LpcpNextMessageId++;
        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
        Message->Request.CallbackId = 0;

        /* No Message ID for the thread: no reply is expected */
        Thread->LpcReplyMessageId = 0;

        /* Insert the message in our chain */
        InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);

        /* Release the lock and get the semaphore we'll use later */
        KeEnterCriticalRegion();
        KeReleaseGuardedMutex(&LpcpLock);

        /* Now release the semaphore */
        Semaphore = QueuePort->MsgQueue.Semaphore;
        LpcpCompleteWait(Semaphore);

        /* If this is a waitable port, wake it up */
        if (QueuePort->Flags & LPCP_WAITABLE_PORT)
        {
            /* Wake it */
            KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
        }

        KeLeaveCriticalRegion();

        /* Dereference objects */
        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
        ObDereferenceObject(QueuePort);
        ObDereferenceObject(Port);
        LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
        return STATUS_SUCCESS;
    }

    Status = STATUS_PORT_DISCONNECTED;

    /* All done with a failure*/
    LPCTRACE(LPC_SEND_DEBUG,
             "Port: %p. Status: %p\n",
             Port,
             Status);

    /* The queue port is going away, free the message (this also drops the lock) */
    if (Message) LpcpFreeToPortZone(Message, 3);

    ObDereferenceObject(Port);
    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
    return Status;
}
Example #5
0
/*
 * @implemented
 *
 * LpcRequestPort
 *     Queues a datagram on PortObject without waiting for a reply. The
 *     message is copied into a zone message, and the receiving port's
 *     semaphore (and wait event, for waitable ports) is signalled.
 *
 * PortObject - LPC port object pointer (not a handle).
 * LpcMessage - message header followed by its data. A Type of 0 sends a
 *              plain LPC_DATAGRAM; otherwise the message type must lie in
 *              the LPC_DATAGRAM..LPC_CLIENT_DIED range. The
 *              LPC_KERNELMODE_MESSAGE flag is only kept when the previous
 *              mode is KernelMode.
 *
 * Returns STATUS_SUCCESS once queued, STATUS_INVALID_PARAMETER (bad type,
 * data-info offset, or inconsistent lengths), STATUS_PORT_MESSAGE_TOO_LONG,
 * STATUS_NO_MEMORY or STATUS_PORT_DISCONNECTED.
 */
NTSTATUS
NTAPI
LpcRequestPort(IN PVOID PortObject,
               IN PPORT_MESSAGE LpcMessage)
{
    PLPCP_PORT_OBJECT Port = PortObject, QueuePort, ConnectionPort = NULL;
    ULONG MessageType;
    PLPCP_MESSAGE Message;
    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
    PAGED_CODE();
    LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", Port, LpcMessage);

    /* Check if this is a non-datagram message */
    if (LpcMessage->u2.s2.Type)
    {
        /* Get the message type */
        MessageType = LpcpGetMessageType(LpcMessage);

        /* Validate it */
        if ((MessageType < LPC_DATAGRAM) || (MessageType > LPC_CLIENT_DIED))
        {
            /* Fail */
            return STATUS_INVALID_PARAMETER;
        }

        /* Mark this as a kernel-mode message only if we really came from it */
        if ((PreviousMode == KernelMode) &&
            (LpcMessage->u2.s2.Type & LPC_KERNELMODE_MESSAGE))
        {
            /* We did, this is a kernel mode message */
            MessageType |= LPC_KERNELMODE_MESSAGE;
        }
    }
    else
    {
        /* This is a datagram */
        MessageType = LPC_DATAGRAM;
    }

    /* Can't have data information on this type of call */
    if (LpcMessage->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;

    /*
     * The header plus DataLength bytes must fit in TotalLength, as
     * NtRequestPort also requires. The checks below alone still allow
     * DataLength up to TotalLength - 1, i.e. more data than fits after the
     * header in a TotalLength-sized message.
     */
    if (((ULONG)LpcMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
        (ULONG)LpcMessage->u1.s1.TotalLength)
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Validate message sizes against the port's maximum */
    if (((ULONG)LpcMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
        ((ULONG)LpcMessage->u1.s1.TotalLength <= (ULONG)LpcMessage->u1.s1.DataLength))
    {
        /* Fail */
        return STATUS_PORT_MESSAGE_TOO_LONG;
    }

    /* Allocate a new message */
    Message = LpcpAllocateFromPortZone();
    if (!Message) return STATUS_NO_MEMORY;

    /* Clear the context */
    Message->RepliedToThread = NULL;
    Message->PortContext = NULL;

    /* Copy the message */
    LpcpMoveMessage(&Message->Request,
                    LpcMessage,
                    LpcMessage + 1,
                    MessageType,
                    &PsGetCurrentThread()->Cid);

    /*
     * Acquire the LPC lock. The failure paths below return right after
     * LpcpFreeToPortZone(Message, 3) without releasing it, so flag 3
     * presumably also releases the lock -- confirm.
     */
    KeAcquireGuardedMutex(&LpcpLock);

    /* Check if this is anything but a connection port */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
    {
        /* The queue port is the connected port */
        QueuePort = Port->ConnectedPort;
        if (QueuePort)
        {
            /* Check if this is a client port */
            if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
            {
                /* Then copy the context */
                Message->PortContext = QueuePort->PortContext;
                ConnectionPort = QueuePort = Port->ConnectionPort;
                if (!ConnectionPort)
                {
                    /* Fail */
                    LpcpFreeToPortZone(Message, 3);
                    return STATUS_PORT_DISCONNECTED;
                }
            }
            else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
            {
                /* Any other kind of port, use the connection port */
                ConnectionPort = QueuePort = Port->ConnectionPort;
                if (!ConnectionPort)
                {
                    /* Fail */
                    LpcpFreeToPortZone(Message, 3);
                    return STATUS_PORT_DISCONNECTED;
                }
            }

            /* If we have a connection port, reference it */
            if (ConnectionPort) ObReferenceObject(ConnectionPort);
        }
    }
    else
    {
        /* For connection ports, use the port itself */
        QueuePort = PortObject;
    }

    /* Make sure we have a port */
    if (QueuePort)
    {
        /* Generate the Message ID and set it (0 is skipped on wrap-around) */
        Message->Request.MessageId = LpcpNextMessageId++;
        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
        Message->Request.CallbackId = 0;

        /* No Message ID for the thread: no reply is expected */
        PsGetCurrentThread()->LpcReplyMessageId = 0;

        /* Insert the message in our chain */
        InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);

        /* Release the lock and release the semaphore */
        KeEnterCriticalRegion();
        KeReleaseGuardedMutex(&LpcpLock);
        LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);

        /* If this is a waitable port, wake it up */
        if (QueuePort->Flags & LPCP_WAITABLE_PORT)
        {
            /* Wake it */
            KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
        }

        /* We're done */
        KeLeaveCriticalRegion();
        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
        LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
        return STATUS_SUCCESS;
    }

    /* No connected port: free the message (this also drops the lock) and fail */
    LpcpFreeToPortZone(Message, 3);
    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
    return STATUS_PORT_DISCONNECTED;
}
Example #6
0
/*
 * @implemented
 *
 * LpcRequestWaitReplyPort
 *     Object-pointer counterpart of NtRequestWaitReplyPort: sends LpcRequest
 *     through PortObject, waits on the current thread's LpcReplySemaphore for
 *     the reply, and copies the reply into LpcReply.
 *
 * PortObject - LPC port object pointer (not a handle).
 * LpcRequest - request header followed by its data. Its Type field is
 *              rewritten: a message type of 0 becomes LPC_REQUEST.
 * LpcReply   - receives the reply message.
 *
 * Returns the wait/reply status, STATUS_THREAD_IS_TERMINATING,
 * STATUS_INVALID_PARAMETER, STATUS_PORT_MESSAGE_TOO_LONG, STATUS_NO_MEMORY,
 * STATUS_NOT_IMPLEMENTED (callback requests), STATUS_PORT_DISCONNECTED or
 * STATUS_LPC_REPLY_LOST.
 */
NTSTATUS
NTAPI
LpcRequestWaitReplyPort(IN PVOID PortObject,
                        IN PPORT_MESSAGE LpcRequest,
                        OUT PPORT_MESSAGE LpcReply)
{
    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
    NTSTATUS Status = STATUS_SUCCESS;
    PLPCP_MESSAGE Message;
    PETHREAD Thread = PsGetCurrentThread();
    BOOLEAN Callback = FALSE;
    PKSEMAPHORE Semaphore;
    USHORT MessageType;
    PAGED_CODE();

    Port = (PLPCP_PORT_OBJECT)PortObject;

    LPCTRACE(LPC_SEND_DEBUG,
             "Port: %p. Messages: %p/%p. Type: %lx\n",
             Port,
             LpcRequest,
             LpcReply,
             LpcpGetMessageType(LpcRequest));

    /* Check if the thread is dying */
    if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;

    /* Check if this is an LPC Request */
    MessageType = LpcpGetMessageType(LpcRequest);
    switch (MessageType)
    {
    /* No type */
    case 0:

        /* Assume LPC request */
        MessageType = LPC_REQUEST;
        break;

    /* LPC request callback */
    case LPC_REQUEST:

        /* This is a callback */
        Callback = TRUE;
        break;

    /* Anything else */
    case LPC_CLIENT_DIED:
    case LPC_PORT_CLOSED:
    case LPC_EXCEPTION:
    case LPC_DEBUG_EVENT:
    case LPC_ERROR_EVENT:

        /* Nothing to do */
        break;

    default:

        /* Invalid message type */
        return STATUS_INVALID_PARAMETER;
    }

    /* Set the request type */
    LpcRequest->u2.s2.Type = MessageType;

    /*
     * The header plus DataLength bytes must fit in TotalLength (the same
     * check NtRequestWaitReplyPort performs); the check below alone still
     * allows DataLength up to TotalLength - 1.
     */
    if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
        (ULONG)LpcRequest->u1.s1.TotalLength)
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Validate the message length against the port's maximum */
    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
        ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
    {
        /* Fail */
        return STATUS_PORT_MESSAGE_TOO_LONG;
    }

    /* Allocate a message from the port zone */
    Message = LpcpAllocateFromPortZone();
    if (!Message)
    {
        /* Fail if we couldn't allocate a message */
        return STATUS_NO_MEMORY;
    }

    /* Check if this is a callback */
    if (Callback)
    {
        /* FIXME: TODO -- callbacks (waiting on the thread semaphore) are not implemented */
        ASSERT(FALSE);

        /* Don't leak the zone message allocated above */
        LpcpFreeToPortZone(Message, 0);
        return STATUS_NOT_IMPLEMENTED;
    }
    else
    {
        /* No callback, just copy the message */
        LpcpMoveMessage(&Message->Request,
                        LpcRequest,
                        LpcRequest + 1,
                        0,
                        &Thread->Cid);

        /*
         * Acquire the LPC lock. The failure paths below return right after
         * LpcpFreeToPortZone(Message, 3) without releasing it, so flag 3
         * presumably also releases the lock -- confirm.
         */
        KeAcquireGuardedMutex(&LpcpLock);

        /* Right now clear the port context */
        Message->PortContext = NULL;

        /* Check if this is a not connection port */
        if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
        {
            /* We want the connected port */
            QueuePort = Port->ConnectedPort;
            if (!QueuePort)
            {
                /* We have no connected port, fail */
                LpcpFreeToPortZone(Message, 3);
                return STATUS_PORT_DISCONNECTED;
            }

            /* This will be the rundown port */
            ReplyPort = QueuePort;

            /* Check if this is a communication port */
            if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
            {
                /* Copy the port context and use the connection port */
                Message->PortContext = QueuePort->PortContext;
                ConnectionPort = QueuePort = Port->ConnectionPort;
                if (!ConnectionPort)
                {
                    /* Fail */
                    LpcpFreeToPortZone(Message, 3);
                    return STATUS_PORT_DISCONNECTED;
                }
            }
            else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
                     LPCP_COMMUNICATION_PORT)
            {
                /* Use the connection port for anything but communication ports */
                ConnectionPort = QueuePort = Port->ConnectionPort;
                if (!ConnectionPort)
                {
                    /* Fail */
                    LpcpFreeToPortZone(Message, 3);
                    return STATUS_PORT_DISCONNECTED;
                }
            }

            /* Reference the connection port if it exists */
            if (ConnectionPort) ObReferenceObject(ConnectionPort);
        }
        else
        {
            /* Otherwise, for a connection port, use the same port object */
            QueuePort = ReplyPort = Port;
        }

        /* No reply thread */
        Message->RepliedToThread = NULL;
        Message->SenderPort = Port;

        /* Generate the Message ID and set it (0 is skipped on wrap-around) */
        Message->Request.MessageId = LpcpNextMessageId++;
        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
        Message->Request.CallbackId = 0;

        /* Set the message ID for our thread now */
        Thread->LpcReplyMessageId = Message->Request.MessageId;
        Thread->LpcReplyMessage = NULL;

        /* Queue the message on the receiver and this thread on the reply chain */
        InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
        InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
        LpcpSetPortToThread(Thread, Port);

        /* Release the lock and get the semaphore we'll use later */
        KeEnterCriticalRegion();
        KeReleaseGuardedMutex(&LpcpLock);
        Semaphore = QueuePort->MsgQueue.Semaphore;

        /* If this is a waitable port, wake it up */
        if (QueuePort->Flags & LPCP_WAITABLE_PORT)
        {
            /* Wake it */
            KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
        }
    }

    /* Now release the semaphore */
    LpcpCompleteWait(Semaphore);
    KeLeaveCriticalRegion();

    /*
     * And let's wait for the reply. Status is tested below; presumably
     * LpcpReplyWait (a macro) stores the wait result in Status -- confirm.
     */
    LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);

    /* Acquire the LPC lock */
    KeAcquireGuardedMutex(&LpcpLock);

    /* Get the LPC Message and clear our thread's reply data */
    Message = LpcpGetMessageFromThread(Thread);
    Thread->LpcReplyMessage = NULL;
    Thread->LpcReplyMessageId = 0;

    /* Check if we have anything on the reply chain */
    if (!IsListEmpty(&Thread->LpcReplyChain))
    {
        /* Remove this thread and reinitialize the list */
        RemoveEntryList(&Thread->LpcReplyChain);
        InitializeListHead(&Thread->LpcReplyChain);
    }

    /* Release the lock */
    KeReleaseGuardedMutex(&LpcpLock);

    /* Check if we got a reply */
    if (Status == STATUS_SUCCESS)
    {
        /* Check if we have a valid message */
        if (Message)
        {
            LPCTRACE(LPC_SEND_DEBUG,
                     "Reply Messages: %p/%p\n",
                     &Message->Request,
                     (&Message->Request) + 1);

            /* Move the message */
            LpcpMoveMessage(LpcReply,
                            &Message->Request,
                            (&Message->Request) + 1,
                            0,
                            NULL);

            /* Acquire the lock */
            KeAcquireGuardedMutex(&LpcpLock);

            /* Check if we replied to a thread */
            if (Message->RepliedToThread)
            {
                /* Dereference */
                ObDereferenceObject(Message->RepliedToThread);
                Message->RepliedToThread = NULL;
            }

            /* Free the message (flag 3: the lock taken above is held) */
            LpcpFreeToPortZone(Message, 3);
        }
        else
        {
            /* We don't have a reply */
            Status = STATUS_LPC_REPLY_LOST;
        }
    }
    else
    {
        /* The wait failed, free the message */
        if (Message) LpcpFreeToPortZone(Message, 0);
    }

    /* All done */
    LPCTRACE(LPC_SEND_DEBUG,
             "Port: %p. Status: %p\n",
             Port,
             Status);

    /* Dereference the connection port */
    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
    return Status;
}
Example #7
0
VOID
LpcExitThread (
    PETHREAD Thread
    )

/*++

Routine Description:

    Runs on a thread that is terminating to detach it from LPC: the thread
    is unlinked from any reply chain, flagged as exiting, and any reply
    message it still holds is released.

Arguments:

    Thread - The exiting thread; must be the current thread.

Return Value:

    None.

--*/

{
    PLPCP_MESSAGE PendingMsg;

    ASSERT (Thread == PsGetCurrentThread());

    //
    //  Everything below runs under the global LPC lock, and LpcReplyMessage
    //  is cleared before the lock is dropped so nobody else processes it.
    //

    LpcpAcquireLpcpLockByThread(Thread);

    if (!IsListEmpty( &Thread->LpcReplyChain )) {
        RemoveEntryList( &Thread->LpcReplyChain );
    }

    //
    //  Flag the thread as exiting; senders check LpcExitThreadCalled.
    //

    Thread->LpcExitThreadCalled = TRUE;
    Thread->LpcReplyMessageId = 0;

    PendingMsg = LpcpGetThreadMessage(Thread);

    if (PendingMsg == NULL) {

        //
        //  No message to clean up; just drop the global lpc mutex.
        //

        LpcpReleaseLpcpLock();
        return;
    }

    Thread->LpcReplyMessage = NULL;

    //
    //  Drop the reference on the thread this message was replied to, if any.
    //

    if (PendingMsg->RepliedToThread != NULL) {
        ObDereferenceObject( PendingMsg->RepliedToThread );
        PendingMsg->RepliedToThread = NULL;
    }

    LpcpTrace(( "Cleanup Msg %lx (%d) for Thread %lx allocated\n", PendingMsg, IsListEmpty( &PendingMsg->Entry ), Thread ));

    //
    //  With LPCP_MUTEX_RELEASE_ON_RETURN the free routine presumably drops
    //  the global lpc mutex for us (this path never releases it itself).
    //

    LpcpFreeToPortZone( PendingMsg, LPCP_MUTEX_OWNED | LPCP_MUTEX_RELEASE_ON_RETURN );
}
Example #8
0
VOID
LpcpDeletePort (
    IN PVOID Object
    )

/*++

Routine Description:

    This routine is the callback used for deleting a port object.  It:

      - drops a dangling client thread reference still held by a server
        communication port (NtAcceptConnectPort without
        NtCompleteConnectPort),
      - sends an LPC_PORT_CLOSED datagram from a client communication port
        so the other side learns it is no longer connected,
      - destroys the port queue,
      - unmaps any client/server section views and releases the mapping
        process reference,
      - frees the data-info messages on the connection port that belong to
        this port (or all of them when the connection port itself dies),
      - releases the connection port and server process references, and
      - frees any static client security context.

Arguments:

    Object - Supplies a pointer to the port object being deleted

Return Value:

    None.

--*/

{
    PETHREAD CurrentThread;
    PLPCP_PORT_OBJECT Port = Object;
    PLPCP_PORT_OBJECT ConnectionPort;
    LPC_CLIENT_DIED_MSG ClientPortClosedDatagram;
    PLPCP_MESSAGE Msg;
    PLIST_ENTRY Head, Next;
    HANDLE CurrentProcessId;
    NTSTATUS Status;
    LARGE_INTEGER RetryInterval = {(ULONG)(-10 * 1000 * 100), -1}; // 100 milliseconds

    PAGED_CODE();

    CurrentThread = PsGetCurrentThread ();

    //
    //  If the port is a server communication port then make sure that if
    //  there is a dangling client thread that we get rid of it.  This
    //  handles the case of someone calling NtAcceptConnectPort and not
    //  calling NtCompleteConnectPort
    //

    if ((Port->Flags & PORT_TYPE) == SERVER_COMMUNICATION_PORT) {

        PETHREAD ClientThread;

        LpcpAcquireLpcpLockByThread(CurrentThread);

        if ((ClientThread = Port->ClientThread) != NULL) {

            Port->ClientThread = NULL;

            LpcpReleaseLpcpLock();

            ObDereferenceObject( ClientThread );

        } else {

            LpcpReleaseLpcpLock();
        }
    }

    //
    //  Send an LPC_PORT_CLOSED datagram to whoever is connected
    //  to this port so they know they are no longer connected.
    //  Keep retrying (every RetryInterval) while the send fails for
    //  lack of memory.
    //

    if ((Port->Flags & PORT_TYPE) == CLIENT_COMMUNICATION_PORT) {

        ClientPortClosedDatagram.PortMsg.u1.s1.TotalLength = sizeof( ClientPortClosedDatagram );
        ClientPortClosedDatagram.PortMsg.u1.s1.DataLength = sizeof( ClientPortClosedDatagram.CreateTime );

        ClientPortClosedDatagram.PortMsg.u2.s2.Type = LPC_PORT_CLOSED;
        ClientPortClosedDatagram.PortMsg.u2.s2.DataInfoOffset = 0;

        ClientPortClosedDatagram.CreateTime = PsGetCurrentProcess()->CreateTime;

        Status = LpcRequestPort( Port, (PPORT_MESSAGE)&ClientPortClosedDatagram );

        while (Status == STATUS_NO_MEMORY) {

            KeDelayExecutionThread(KernelMode, FALSE, &RetryInterval);

            Status = LpcRequestPort( Port, (PPORT_MESSAGE)&ClientPortClosedDatagram );
        }
    }

    //
    //  If connected, disconnect the port, and then scan the message queue
    //  for this port and dereference any messages in the queue.
    //

    LpcpDestroyPortQueue( Port, TRUE );

    //
    //  If we had mapped sections into the server or client communication ports,
    //  we need to unmap them in the context of that process.
    //

    if ( (Port->ClientSectionBase != NULL) ||
         (Port->ServerSectionBase != NULL) ) {

        //
        //  If the client has a port memory view, then unmap it
        //

        if (Port->ClientSectionBase != NULL) {

            MmUnmapViewOfSection( Port->MappingProcess,
                                  Port->ClientSectionBase );

        }

        //
        //  If the server has a port memory view, then unmap it
        //

        if (Port->ServerSectionBase != NULL) {

            MmUnmapViewOfSection( Port->MappingProcess,
                                  Port->ServerSectionBase  );

        }

        //
        //  Removing the reference added while mapping the section
        //

        ObDereferenceObject( Port->MappingProcess );

        Port->MappingProcess = NULL;
    }

    //
    //  Free the data-info messages tied to this port and dereference the
    //  connection port if it is not this port.
    //

    LpcpAcquireLpcpLockByThread(CurrentThread);

    ConnectionPort = Port->ConnectionPort;

    if (ConnectionPort) {

        CurrentProcessId = CurrentThread->Cid.UniqueProcess;

        Head = &ConnectionPort->LpcDataInfoChainHead;
        Next = Head->Flink;

        while (Next != Head) {

            Msg = CONTAINING_RECORD( Next, LPCP_MESSAGE, Entry );
            Next = Next->Flink;

            if (Port == ConnectionPort) {

                //
                //  If the Connection port is going away free all queued messages
                //

                RemoveEntryList( &Msg->Entry );
                InitializeListHead( &Msg->Entry );

                LpcpFreeToPortZone( Msg, LPCP_MUTEX_OWNED );

                //
                //  In LpcpFreeToPortZone the LPC lock is released and reacquired.
                //  Another thread might free the LPC message captured above
                //  in Next. We need to restart the search at the list head.
                //

                Next = Head->Flink;

            } else if ((Msg->Request.ClientId.UniqueProcess == CurrentProcessId)
                    &&
                ((Msg->SenderPort == Port) 
                        || 
                 (Msg->SenderPort == Port->ConnectedPort) 
                        || 
                 (Msg->SenderPort == ConnectionPort))) {

                //
                //  Test whether the message come from the same port and process
                //

                LpcpTrace(( "%s Freeing DataInfo Message %lx (%u.%u)  Port: %lx\n",
                            PsGetCurrentProcess()->ImageFileName,
                            Msg,
                            Msg->Request.MessageId,
                            Msg->Request.CallbackId,
                            ConnectionPort ));

                RemoveEntryList( &Msg->Entry );
                InitializeListHead( &Msg->Entry );

                LpcpFreeToPortZone( Msg, LPCP_MUTEX_OWNED );

                //
                //  In LpcpFreeToPortZone the LPC lock is released and reacquired.
                //  Another thread might free the LPC message captured above
                //  in Next. We need to restart the search at the list head.
                //

                Next = Head->Flink;
            }
        }

        LpcpReleaseLpcpLock();

        if (ConnectionPort != Port) {

            ObDereferenceObject( ConnectionPort );
        }

    } else {

        LpcpReleaseLpcpLock();
    }

    //
    //  For a server connection port, release the reference held on the
    //  server process.  ConnectionPort can be NULL here (the else branch
    //  above), so it must be checked before it is dereferenced.
    //

    if (((Port->Flags & PORT_TYPE) == SERVER_CONNECTION_PORT) &&
        (ConnectionPort != NULL) &&
        (ConnectionPort->ServerProcess != NULL)) {

        ObDereferenceObject( ConnectionPort->ServerProcess );

        ConnectionPort->ServerProcess = NULL;
    }

    //
    //  Free any static client security context
    //

    LpcpFreePortClientSecurity( Port );

    //
    //  And return to our caller
    //

    return;
}
Example #9
0
/*
 * @implemented
 */
/*
 * NtSecureConnectPort
 *
 * Connects the caller to the named LPC connection port. Optionally, it also:
 * verifies the server's SID, maps a caller-supplied section view for sharing,
 * and exchanges connection information with the server.
 *
 * PortHandle                  - Receives a handle to the new client port on success.
 * PortName                    - Name of the connection port to connect to.
 * Qos                         - Security quality of service for the client port.
 * ClientView                  - Optional section view to share; updated with the
 *                               mapped base and offset.
 * ServerSid                   - Optional SID the server process is expected to have.
 * ServerView                  - Optional; receives the server's view on success.
 * MaxMessageLength            - Optional; receives the port's maximum message length.
 * ConnectionInformation       - Optional buffer sent to and returned from the server.
 * ConnectionInformationLength - Optional in/out length of ConnectionInformation.
 *
 * Returns STATUS_SUCCESS, or a failure status. Possible failures include:
 * STATUS_INVALID_PARAMETER, STATUS_INVALID_PORT_HANDLE,
 * STATUS_SERVER_SID_MISMATCH, STATUS_NO_MEMORY,
 * STATUS_OBJECT_NAME_NOT_FOUND, STATUS_PORT_CONNECTION_REFUSED,
 * or a status propagated from the object manager / memory manager.
 *
 * NOTE(review): the caller's pointer arguments (Qos, ClientView, ServerView,
 * PortHandle, MaxMessageLength, ConnectionInformation and its length) are
 * read and written directly, without probing or capturing under SEH, even
 * when PreviousMode is UserMode. This should be addressed separately.
 */
NTSTATUS
NTAPI
NtSecureConnectPort(OUT PHANDLE PortHandle,
                    IN PUNICODE_STRING PortName,
                    IN PSECURITY_QUALITY_OF_SERVICE Qos,
                    IN OUT PPORT_VIEW ClientView OPTIONAL,
                    IN PSID ServerSid OPTIONAL,
                    IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL,
                    OUT PULONG MaxMessageLength OPTIONAL,
                    IN OUT PVOID ConnectionInformation OPTIONAL,
                    IN OUT PULONG ConnectionInformationLength OPTIONAL)
{
    ULONG ConnectionInfoLength = 0;
    PLPCP_PORT_OBJECT Port, ClientPort;
    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
    NTSTATUS Status = STATUS_SUCCESS;
    HANDLE Handle;
    PVOID SectionToMap;
    PLPCP_MESSAGE Message;
    PLPCP_CONNECTION_MESSAGE ConnectMessage;
    PETHREAD Thread = PsGetCurrentThread();
    ULONG PortMessageLength;
    LARGE_INTEGER SectionOffset;
    PTOKEN Token;
    PTOKEN_USER TokenUserInfo;
    PAGED_CODE();
    LPCTRACE(LPC_CONNECT_DEBUG,
             "Name: %wZ. Qos: %p. Views: %p/%p. Sid: %p\n",
             PortName,
             Qos,
             ClientView,
             ServerView,
             ServerSid);

    /* Validate client view */
    if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Validate server view */
    if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Check if caller sent connection information length */
    if (ConnectionInformationLength)
    {
        /* Retrieve the input length */
        ConnectionInfoLength = *ConnectionInformationLength;
    }

    /* Get the port */
    Status = ObReferenceObjectByName(PortName,
                                     0,
                                     NULL,
                                     PORT_CONNECT,
                                     LpcPortObjectType,
                                     PreviousMode,
                                     NULL,
                                     (PVOID *)&Port);
    if (!NT_SUCCESS(Status))
    {
        DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status);
        return Status;
    }

    /* This has to be a connection port */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
    {
        /* It isn't, so fail */
        ObDereferenceObject(Port);
        return STATUS_INVALID_PORT_HANDLE;
    }

    /* Check if we have a SID */
    if (ServerSid)
    {
        /* Make sure that we have a server */
        if (Port->ServerProcess)
        {
            /* Get its token and query user information */
            Token = PsReferencePrimaryToken(Port->ServerProcess);
            //Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
            // FIXME: Need SeQueryInformationToken. Until it exists, the
            // "server" SID below is ServerSid itself, so the comparison
            // cannot fail and this check is effectively a no-op.
            TokenUserInfo = ExAllocatePoolWithTag(PagedPool, sizeof(TOKEN_USER), TAG_SE);
            PsDereferencePrimaryToken(Token);

            /* The allocation can fail: never write through a NULL buffer */
            if (!TokenUserInfo)
            {
                Status = STATUS_NO_MEMORY;
            }
            else
            {
                TokenUserInfo->User.Sid = ServerSid;
                Status = STATUS_SUCCESS;

                /* Compare the SIDs */
                if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid))
                {
                    /* Fail */
                    Status = STATUS_SERVER_SID_MISMATCH;
                }

                /* Free token information */
                ExFreePoolWithTag(TokenUserInfo, TAG_SE);
            }
        }
        else
        {
            /* Invalid SID */
            Status = STATUS_SERVER_SID_MISMATCH;
        }

        /* Check if SID failed */
        if (!NT_SUCCESS(Status))
        {
            /* Quit */
            ObDereferenceObject(Port);
            return Status;
        }
    }

    /* Create the client port */
    Status = ObCreateObject(PreviousMode,
                            LpcPortObjectType,
                            NULL,
                            PreviousMode,
                            NULL,
                            sizeof(LPCP_PORT_OBJECT),
                            0,
                            0,
                            (PVOID *)&ClientPort);
    if (!NT_SUCCESS(Status))
    {
        /* Failed, dereference the server port and return */
        ObDereferenceObject(Port);
        return Status;
    }

    /*
     * Setup the client port. From here on, the Port reference obtained
     * above is stored in ClientPort->ConnectionPort, so every error path
     * below releases it by dereferencing ClientPort (never Port directly).
     */
    RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
    ClientPort->Flags = LPCP_CLIENT_PORT;
    ClientPort->ConnectionPort = Port;
    ClientPort->MaxMessageLength = Port->MaxMessageLength;
    ClientPort->SecurityQos = *Qos;
    InitializeListHead(&ClientPort->LpcReplyChainHead);
    InitializeListHead(&ClientPort->LpcDataInfoChainHead);

    /* Check if we have dynamic security */
    if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
    {
        /* Remember that */
        ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
    }
    else
    {
        /* Create our own client security */
        Status = SeCreateClientSecurity(Thread,
                                        Qos,
                                        FALSE,
                                        &ClientPort->StaticSecurity);
        if (!NT_SUCCESS(Status))
        {
            /* Security failed, dereference and return */
            ObDereferenceObject(ClientPort);
            return Status;
        }
    }

    /* Initialize the port queue */
    Status = LpcpInitializePortQueue(ClientPort);
    if (!NT_SUCCESS(Status))
    {
        /* Failed */
        ObDereferenceObject(ClientPort);
        return Status;
    }

    /* Check if we have a client view */
    if (ClientView)
    {
        /* Get the section handle */
        Status = ObReferenceObjectByHandle(ClientView->SectionHandle,
                                           SECTION_MAP_READ |
                                           SECTION_MAP_WRITE,
                                           MmSectionObjectType,
                                           PreviousMode,
                                           (PVOID*)&SectionToMap,
                                           NULL);
        if (!NT_SUCCESS(Status))
        {
            /* Fail: release the client port (which owns the Port reference) */
            ObDereferenceObject(ClientPort);
            return Status;
        }

        /* Set the section offset */
        SectionOffset.QuadPart = ClientView->SectionOffset;

        /* Map it */
        Status = MmMapViewOfSection(SectionToMap,
                                    PsGetCurrentProcess(),
                                    &ClientPort->ClientSectionBase,
                                    0,
                                    0,
                                    &SectionOffset,
                                    &ClientView->ViewSize,
                                    ViewUnmap,
                                    0,
                                    PAGE_READWRITE);

        /* Update the offset */
        ClientView->SectionOffset = SectionOffset.LowPart;

        /* Check for failure */
        if (!NT_SUCCESS(Status))
        {
            /* Fail: release the section and the client port */
            ObDereferenceObject(SectionToMap);
            ObDereferenceObject(ClientPort);
            return Status;
        }

        /* Update the base */
        ClientView->ViewBase = ClientPort->ClientSectionBase;

        /* Reference and remember the process */
        ClientPort->MappingProcess = PsGetCurrentProcess();
        ObReferenceObject(ClientPort->MappingProcess);
    }
    else
    {
        /* No section */
        SectionToMap = NULL;
    }

    /* Normalize connection information */
    if (ConnectionInfoLength > Port->MaxConnectionInfoLength)
    {
        /* Use the port's maximum allowed value */
        ConnectionInfoLength = Port->MaxConnectionInfoLength;
    }

    /* Allocate a message from the port zone */
    Message = LpcpAllocateFromPortZone();
    if (!Message)
    {
        /* Fail if we couldn't allocate a message */
        if (SectionToMap) ObDereferenceObject(SectionToMap);
        ObDereferenceObject(ClientPort);
        return STATUS_NO_MEMORY;
    }

    /* Set pointer to the connection message and fill in the CID */
    ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
    Message->Request.ClientId = Thread->Cid;

    /* Check if we have a client view */
    if (ClientView)
    {
        /* Set the view size */
        Message->Request.ClientViewSize = ClientView->ViewSize;

        /* Copy the client view and clear the server view */
        RtlCopyMemory(&ConnectMessage->ClientView,
                      ClientView,
                      sizeof(PORT_VIEW));
        RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
    }
    else
    {
        /* Set the size to 0 and clear the connect message */
        Message->Request.ClientViewSize = 0;
        RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE));
    }

    /* Set the section and client port. Port is NULL for now */
    ConnectMessage->ClientPort = NULL;
    ConnectMessage->SectionToMap = SectionToMap;

    /* Set the data for the connection request message */
    Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
                                         sizeof(LPCP_CONNECTION_MESSAGE);
    Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
                                         Message->Request.u1.s1.DataLength;
    Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST;

    /* Check if we have connection information */
    if (ConnectionInformation)
    {
        /* Copy it in */
        RtlCopyMemory(ConnectMessage + 1,
                      ConnectionInformation,
                      ConnectionInfoLength);
    }

    /* Acquire the port lock */
    KeAcquireGuardedMutex(&LpcpLock);

    /* Check if someone already deleted the port name */
    if (Port->Flags & LPCP_NAME_DELETED)
    {
        /* Fail the request */
        Status = STATUS_OBJECT_NAME_NOT_FOUND;
    }
    else
    {
        /* Associate no thread yet */
        Message->RepliedToThread = NULL;

        /* Generate the Message ID and set it, skipping 0 on wrap-around */
        Message->Request.MessageId =  LpcpNextMessageId++;
        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
        Thread->LpcReplyMessageId = Message->Request.MessageId;

        /* Insert the message into the queue and thread chain */
        InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
        InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
        Thread->LpcReplyMessage = Message;

        /* Now we can finally reference the client port and link it*/
        ObReferenceObject(ClientPort);
        ConnectMessage->ClientPort = ClientPort;

        /* Enter a critical region */
        KeEnterCriticalRegion();
    }

    /* Add another reference to the port; released on every exit below */
    ObReferenceObject(Port);

    /* Release the lock */
    KeReleaseGuardedMutex(&LpcpLock);

    /* Check for success */
    if (NT_SUCCESS(Status))
    {
        LPCTRACE(LPC_CONNECT_DEBUG,
                 "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
                 Message,
                 ConnectMessage,
                 Port,
                 ClientPort,
                 Status);

        /* If this is a waitable port, set the event */
        if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
                                                         1,
                                                         FALSE);

        /* Release the queue semaphore and leave the critical region */
        LpcpCompleteWait(Port->MsgQueue.Semaphore);
        KeLeaveCriticalRegion();

        /* Now wait for a reply */
        LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
    }

    /* Check for failure */
    if (!NT_SUCCESS(Status)) goto Cleanup;

    /* Free the connection message */
    SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);

    /* Check if we got a message back */
    if (Message)
    {
        /* Check for new return length */
        if ((Message->Request.u1.s1.DataLength -
             sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength)
        {
            /* Set new normalized connection length */
            ConnectionInfoLength = Message->Request.u1.s1.DataLength -
                                   sizeof(LPCP_CONNECTION_MESSAGE);
        }

        /* Check if we had connection information */
        if (ConnectionInformation)
        {
            /* Check if we had a length pointer */
            if (ConnectionInformationLength)
            {
                /* Return the length */
                *ConnectionInformationLength = ConnectionInfoLength;
            }

            /* Return the connection information */
            RtlCopyMemory(ConnectionInformation,
                          ConnectMessage + 1,
                          ConnectionInfoLength );
        }

        /* Make sure we had a connected port */
        if (ClientPort->ConnectedPort)
        {
            /* Get the message length before the port might get killed */
            PortMessageLength = Port->MaxMessageLength;

            /* Insert the client port */
            Status = ObInsertObject(ClientPort,
                                    NULL,
                                    PORT_ALL_ACCESS,
                                    0,
                                    (PVOID *)NULL,
                                    &Handle);
            if (NT_SUCCESS(Status))
            {
                /* Return the handle */
                *PortHandle = Handle;
                LPCTRACE(LPC_CONNECT_DEBUG,
                         "Handle: %p. Length: %lx\n",
                         Handle,
                         PortMessageLength);

                /* Check if maximum length was requested */
                if (MaxMessageLength) *MaxMessageLength = PortMessageLength;

                /* Check if we had a client view */
                if (ClientView)
                {
                    /* Copy it back */
                    RtlCopyMemory(ClientView,
                                  &ConnectMessage->ClientView,
                                  sizeof(PORT_VIEW));
                }

                /* Check if we had a server view */
                if (ServerView)
                {
                    /* Copy it back */
                    RtlCopyMemory(ServerView,
                                  &ConnectMessage->ServerView,
                                  sizeof(REMOTE_PORT_VIEW));
                }
            }
        }
        else
        {
            /* No connection port, we failed */
            if (SectionToMap) ObDereferenceObject(SectionToMap);

            /* Acquire the lock */
            KeAcquireGuardedMutex(&LpcpLock);

            /* Check if it's because the name got deleted */
            if (!(ClientPort->ConnectionPort) ||
                (Port->Flags & LPCP_NAME_DELETED))
            {
                /* Set the correct status */
                Status = STATUS_OBJECT_NAME_NOT_FOUND;
            }
            else
            {
                /* Otherwise, the caller refused us */
                Status = STATUS_PORT_CONNECTION_REFUSED;
            }

            /* Release the lock */
            KeReleaseGuardedMutex(&LpcpLock);

            /* Kill the port */
            ObDereferenceObject(ClientPort);
        }

        /* Free the message */
        LpcpFreeToPortZone(Message, 0);
    }
    else
    {
        /* No reply message, fail */
        if (SectionToMap) ObDereferenceObject(SectionToMap);
        ObDereferenceObject(ClientPort);
        Status = STATUS_PORT_CONNECTION_REFUSED;
    }

    /* Return status */
    ObDereferenceObject(Port);
    return Status;

Cleanup:
    /* We failed, free the message */
    SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);

    /* Check if the semaphore got signaled */
    if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
    {
        /* Wait on it */
        KeWaitForSingleObject(&Thread->LpcReplySemaphore,
                              WrExecutive,
                              KernelMode,
                              FALSE,
                              NULL);
    }

    /* Check if we had a message and free it */
    if (Message) LpcpFreeToPortZone(Message, 0);

    /* Dereference other objects */
    if (SectionToMap) ObDereferenceObject(SectionToMap);
    ObDereferenceObject(ClientPort);

    /* Return status */
    ObDereferenceObject(Port);
    return Status;
}
Example #10
0
/*
 * LpcpDeletePort
 *
 * Delete routine for LPC port objects. It performs these steps:
 *   - releases a client thread still referenced by a communication port;
 *   - notifies the peer when a client port closes;
 *   - destroys the port queue and unmaps any section views;
 *   - frees matching data-info messages on the connection port;
 *   - drops the connection-port and server-process references;
 *   - frees the static client security.
 *
 * ObjectBody - the LPCP_PORT_OBJECT being deleted.
 */
VOID
NTAPI
LpcpDeletePort(IN PVOID ObjectBody)
{
    PLPCP_PORT_OBJECT Port = (PLPCP_PORT_OBJECT)ObjectBody;
    PLPCP_PORT_OBJECT ConnectionPort;
    PLPCP_MESSAGE Message;
    PLIST_ENTRY ListHead, NextEntry;
    PETHREAD ClientThread;
    HANDLE Pid;
    CLIENT_DIED_MSG ClientDiedMsg;
    LARGE_INTEGER RetryDelay;

    PAGED_CODE();
    LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags);

    /* A communication port may still reference a client thread */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_COMMUNICATION_PORT)
    {
        /* Detach it under the lock, drop the reference outside of it */
        KeAcquireGuardedMutex(&LpcpLock);
        ClientThread = Port->ClientThread;
        Port->ClientThread = NULL;
        KeReleaseGuardedMutex(&LpcpLock);

        if (ClientThread) ObDereferenceObject(ClientThread);
    }

    /* A closing client port tells the other side it went away */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
    {
        ClientDiedMsg.h.u1.s1.TotalLength = sizeof(ClientDiedMsg);
        ClientDiedMsg.h.u1.s1.DataLength = sizeof(ClientDiedMsg.CreateTime);
        ClientDiedMsg.h.u2.ZeroInit = 0;
        ClientDiedMsg.h.u2.s2.Type = LPC_PORT_CLOSED;
        ClientDiedMsg.CreateTime = PsGetCurrentProcess()->CreateTime;

        /* Keep retrying, with a relative delay, while the send runs out of memory */
        RetryDelay.QuadPart = -1000000;
        while (LpcRequestPort(Port, &ClientDiedMsg.h) == STATUS_NO_MEMORY)
        {
            KeDelayExecutionThread(KernelMode, FALSE, &RetryDelay);
        }
    }

    /* Flush and free the port queue */
    LpcpDestroyPortQueue(Port, TRUE);

    /* Unmap whichever views were mapped and release the mapping process */
    if ((Port->ClientSectionBase) || (Port->ServerSectionBase))
    {
        if (Port->ClientSectionBase)
        {
            MmUnmapViewOfSection(Port->MappingProcess, Port->ClientSectionBase);
        }

        if (Port->ServerSectionBase)
        {
            MmUnmapViewOfSection(Port->MappingProcess, Port->ServerSectionBase);
        }

        ObDereferenceObject(Port->MappingProcess);
        Port->MappingProcess = NULL;
    }

    KeAcquireGuardedMutex(&LpcpLock);

    ConnectionPort = Port->ConnectionPort;
    if (!ConnectionPort)
    {
        /* Nothing linked to this port */
        KeReleaseGuardedMutex(&LpcpLock);
    }
    else
    {
        Pid = PsGetCurrentProcessId();
        ListHead = &ConnectionPort->LpcDataInfoChainHead;

        for (NextEntry = ListHead->Flink; NextEntry != ListHead; )
        {
            Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
            NextEntry = NextEntry->Flink;

            /*
             * When the connection port itself dies every queued message goes;
             * otherwise only messages sent by this process through this port,
             * its peer, or the connection port.
             */
            if ((Port == ConnectionPort) ||
                ((Message->Request.ClientId.UniqueProcess == Pid) &&
                 ((Message->SenderPort == Port) ||
                  (Message->SenderPort == Port->ConnectedPort) ||
                  (Message->SenderPort == ConnectionPort))))
            {
                RemoveEntryList(&Message->Entry);
                InitializeListHead(&Message->Entry);
                LpcpFreeToPortZone(Message, LPCP_LOCK_HELD);

                /* Rescan from the head after each free */
                NextEntry = ListHead->Flink;
            }
        }

        KeReleaseGuardedMutex(&LpcpLock);

        /* Drop the connection-port reference unless it is this very port */
        if (ConnectionPort != Port) ObDereferenceObject(ConnectionPort);

        /* A connection port also holds a reference on its server process */
        if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) &&
            (ConnectionPort->ServerProcess))
        {
            ObDereferenceObject(ConnectionPort->ServerProcess);
            ConnectionPort->ServerProcess = NULL;
        }
    }

    LpcpFreePortClientSecurity(Port);
    LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p deleted\n", Port);
}
Example #11
0
/*
 * LpcpDestroyPortQueue
 *
 * Disconnects a port from its peer and wakes threads waiting on it. It then
 * frees the reply messages of those threads and every message still in the
 * port's receive queue. When Destroy is TRUE, it also frees the port's
 * message queue allocation.
 *
 * Port    - the port being torn down.
 * Destroy - TRUE to also free the queue memory located via
 *           Port->MsgQueue.Semaphore.
 */
VOID
NTAPI
LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
                     IN BOOLEAN Destroy)
{
    PLIST_ENTRY ListHead, NextEntry;
    PETHREAD Thread;
    PLPCP_MESSAGE Message;
    PLPCP_PORT_OBJECT ConnectionPort = NULL;
    PLPCP_CONNECTION_MESSAGE ConnectMessage;
    PLPCP_NONPAGED_PORT_QUEUE MessageQueue;

    PAGED_CODE();
    LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags);

    /* Hold the lock */
    KeAcquireGuardedMutex(&LpcpLock);

    /* Check if we have a connected port */
    if (((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_UNCONNECTED_PORT) &&
        (Port->ConnectedPort))
    {
        /* Disconnect it */
        Port->ConnectedPort->ConnectedPort = NULL;
        /* Take over the peer's connection-port pointer; it is dereferenced
           below, after the lock has been released */
        ConnectionPort = Port->ConnectedPort->ConnectionPort;
        if (ConnectionPort)
        {
            /* Clear connection port */
            Port->ConnectedPort->ConnectionPort = NULL;
        }
    }

    /* Check if this is a connection port */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT)
    {
        /* Delete the name */
        Port->Flags |= LPCP_NAME_DELETED;
    }

    /* Walk all the threads waiting and signal them */
    ListHead = &Port->LpcReplyChainHead;
    NextEntry = ListHead->Flink;
    while ((NextEntry) && (NextEntry != ListHead))
    {
        /* Get the Thread */
        Thread = CONTAINING_RECORD(NextEntry, ETHREAD, LpcReplyChain);

        /* Make sure we're not in exit */
        /* NOTE(review): this stops the whole walk at the first thread that
           has LpcExitThreadCalled set; threads queued after it are not
           signaled here, and the list head is reinitialized below anyway.
           Presumably intentional (an exiting thread's LpcReplyChain may no
           longer be a valid list entry); confirm before changing. */
        if (Thread->LpcExitThreadCalled) break;

        /* Move to the next entry */
        NextEntry = NextEntry->Flink;

        /* Remove and reinitialize the List */
        RemoveEntryList(&Thread->LpcReplyChain);
        InitializeListHead(&Thread->LpcReplyChain);

        /* Check if someone is waiting */
        /* (only threads whose reply semaphore is not already signaled are
           treated as still waiting for a reply) */
        if (!KeReadStateSemaphore(&Thread->LpcReplySemaphore))
        {
            /* Get the message */
            Message = LpcpGetMessageFromThread(Thread);
            if (Message)
            {
                /* Check if it's a connection request */
                if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
                {
                    /* Get the connection message */
                    ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);

                    /* Check if it had a section */
                    if (ConnectMessage->SectionToMap)
                    {
                        /* Dereference it */
                        ObDereferenceObject(ConnectMessage->SectionToMap);
                    }
                }

                /* Clear the reply message */
                Thread->LpcReplyMessage = NULL;

                /* And remove the message from the port zone */
                LpcpFreeToPortZone(Message, LPCP_LOCK_HELD);
                /* Restart from the head of the reply chain after freeing,
                   rather than trusting the NextEntry captured above */
                NextEntry = Port->LpcReplyChainHead.Flink;
            }

            /* Release the semaphore and reset message id count */
            Thread->LpcReplyMessageId = 0;
            KeReleaseSemaphore(&Thread->LpcReplySemaphore, 0, 1, FALSE);
        }
    }

    /* Reinitialize the list head */
    /* (any entries not visited because of the early break above are no
       longer reachable from the port after this) */
    InitializeListHead(&Port->LpcReplyChainHead);

    /* Loop queued messages */
    /* (the Flink NULL check guards a receive list that was never
       initialized) */
    while ((Port->MsgQueue.ReceiveHead.Flink) &&
           !(IsListEmpty(&Port->MsgQueue.ReceiveHead)))
    {
        /* Get the message */
        Message = CONTAINING_RECORD(Port->MsgQueue.ReceiveHead.Flink,
                                    LPCP_MESSAGE,
                                    Entry);

        /* Free and reinitialize it's list head */
        RemoveEntryList(&Message->Entry);
        InitializeListHead(&Message->Entry);

        /* Remove it from the port zone */
        LpcpFreeToPortZone(Message, LPCP_LOCK_HELD);
    }

    /* Release the lock */
    KeReleaseGuardedMutex(&LpcpLock);

    /* Dereference the connection port */
    /* (done outside the lock; this is the pointer taken from the peer above) */
    if (ConnectionPort) ObDereferenceObject(ConnectionPort);

    /* Check if we have to free the port entirely */
    if (Destroy)
    {
        /* Check if the semaphore exists */
        if (Port->MsgQueue.Semaphore)
        {
            /* Use the semaphore to find the port queue and free it */
            /* (the semaphore is embedded in an LPCP_NONPAGED_PORT_QUEUE,
               so CONTAINING_RECORD recovers the allocation to free) */
            MessageQueue = CONTAINING_RECORD(Port->MsgQueue.Semaphore,
                                             LPCP_NONPAGED_PORT_QUEUE,
                                             Semaphore);
            ExFreePoolWithTag(MessageQueue, 'troP');
        }
    }
}
Example #12
0
VOID
LpcpDeletePort(
    IN PVOID Object
    )

/*++

Routine Description:

    Delete callback for an LPC port object.  If the port is a client
    communication port, an LPC_PORT_CLOSED datagram is sent through it so
    the peer learns the connection is gone.  The port queue is then torn
    down, any client/server section views are unmapped from the current
    process, DataInfo messages on the connection port that belong to the
    current process are freed, the connection-port reference is dropped
    (unless the port is its own connection port), and any static client
    security context is released.

Arguments:

    Object - Supplies a pointer to the port object being deleted.

Return Value:

    None.

--*/

{
    PLPCP_PORT_OBJECT Port = Object;
    PLPCP_PORT_OBJECT ConnectionPort;
    LPC_CLIENT_DIED_MSG ClientPortClosedDatagram;
    PLPCP_MESSAGE Msg;
    PLIST_ENTRY Head, Next;
    HANDLE CurrentProcessId;

    PAGED_CODE();

    //
    // Send an LPC_PORT_CLOSED datagram to whoever is connected
    // to this port so they know they are no longer connected.
    //

    if ((Port->Flags & PORT_TYPE) == CLIENT_COMMUNICATION_PORT) {
        ClientPortClosedDatagram.PortMsg.u1.s1.TotalLength = sizeof( ClientPortClosedDatagram );
        ClientPortClosedDatagram.PortMsg.u1.s1.DataLength = sizeof( ClientPortClosedDatagram.CreateTime );
        ClientPortClosedDatagram.PortMsg.u2.s2.Type = LPC_PORT_CLOSED;
        ClientPortClosedDatagram.PortMsg.u2.s2.DataInfoOffset = 0;
        ClientPortClosedDatagram.CreateTime = PsGetCurrentProcess()->CreateTime;
        LpcRequestPort( Port, (PPORT_MESSAGE)&ClientPortClosedDatagram );
        }

    //
    // If connected, disconnect the port, and then scan the message queue
    // for this port and dereference any messages in the queue.
    //

    LpcpDestroyPortQueue( Port, TRUE );

    //
    // If the client has a port memory view, then unmap it.  This is a
    // best-effort cleanup during object deletion; the status is not used.
    //

    if (Port->ClientSectionBase != NULL) {
        ZwUnmapViewOfSection( NtCurrentProcess(),
                              Port->ClientSectionBase
                            );
        }

    //
    // If the server has a port memory view, then unmap it (best effort).
    //

    if (Port->ServerSectionBase != NULL) {
        ZwUnmapViewOfSection( NtCurrentProcess(),
                              Port->ServerSectionBase
                            );
        }

    //
    // Free DataInfo messages that the current process left on the
    // connection port, then dereference the connection port if it is
    // not this port.
    //

    ConnectionPort = Port->ConnectionPort;
    if (ConnectionPort != NULL) {
        CurrentProcessId = PsGetCurrentThread()->Cid.UniqueProcess;
        ExAcquireFastMutex( &LpcpLock );
        Head = &ConnectionPort->LpcDataInfoChainHead;
        Next = Head->Flink;
        while (Next != Head) {
            Msg = CONTAINING_RECORD( Next, LPCP_MESSAGE, Entry );

            //
            // Advance before the current entry can be unlinked and freed.
            //

            Next = Next->Flink;
            if (Msg->Request.ClientId.UniqueProcess == CurrentProcessId) {
                LpcpTrace(( "%s Freeing DataInfo Message %p (%u.%u)  Port: %p\n",
                            PsGetCurrentProcess()->ImageFileName,
                            Msg,
                            Msg->Request.MessageId,
                            Msg->Request.CallbackId,
                            ConnectionPort
                         ));

                //
                // Unlink the message and leave its entry self-linked so it
                // no longer carries stale Flink/Blink pointers to its former
                // neighbours.  NOTE(review): LpcpFreeToPortZone may inspect
                // or unlink Msg->Entry itself; on a self-linked entry a
                // second RemoveEntryList is a no-op.
                //

                RemoveEntryList( &Msg->Entry );
                InitializeListHead( &Msg->Entry );
                LpcpFreeToPortZone( Msg, TRUE );
                }
            }
        ExReleaseFastMutex( &LpcpLock );

        if (ConnectionPort != Port) {
            ObDereferenceObject( ConnectionPort );
            }
        }

    //
    // Free any static client security context
    //

    LpcpFreePortClientSecurity( Port );
}
Example #13
0
VOID
LpcpDeletePort (
    IN PVOID Object
    )

/*++

Routine Description:

    This routine is the callback used for deleting a port object.

    In order, it:
      - destroys any port extra data (LpcpPortExtraDataDestroy);
      - for a server communication port, drops the reference on a
        dangling ClientThread (NtAcceptConnectPort without
        NtCompleteConnectPort);
      - for a client communication port, sends an LPC_PORT_CLOSED datagram
        so the peer learns the connection is gone;
      - destroys the port queue and unmaps any client/server section views
        from the current process;
      - frees DataInfo messages on the connection port that belong to the
        current process, and dereferences the connection port if it is
        not this port;
      - for a server connection port, releases the ServerProcess reference;
      - frees any static client security context.

Arguments:

    Object - Supplies a pointer to the port object being deleted

Return Value:

    None.

--*/

{
    PLPCP_PORT_OBJECT Port = Object;
    PLPCP_PORT_OBJECT ConnectionPort;
    LPC_CLIENT_DIED_MSG ClientPortClosedDatagram;
    PLPCP_MESSAGE Msg;
    PLIST_ENTRY Head, Next;
    HANDLE CurrentProcessId;

    PAGED_CODE();

    //
    //  If the port is a server communication port then make sure that if
    //  there is a dangling client thread that we get rid of it.  This
    //  handles the case of someone calling NtAcceptConnectPort and not
    //  calling NtCompleteConnectPort
    //

    LpcpPortExtraDataDestroy( Port );

    if ((Port->Flags & PORT_TYPE) == SERVER_COMMUNICATION_PORT) {

        PETHREAD ClientThread;

        LpcpAcquireLpcpLock();

        if ((ClientThread = Port->ClientThread) != NULL) {

            //
            //  Detach the thread under the lock, but drop the reference
            //  only after the lock has been released.
            //

            Port->ClientThread = NULL;

            LpcpReleaseLpcpLock();

            ObDereferenceObject( ClientThread );

        } else {

            LpcpReleaseLpcpLock();
        }
    }

    //
    //  Send an LPC_PORT_CLOSED datagram to whoever is connected
    //  to this port so they know they are no longer connected.
    //

    if ((Port->Flags & PORT_TYPE) == CLIENT_COMMUNICATION_PORT) {

        ClientPortClosedDatagram.PortMsg.u1.s1.TotalLength = sizeof( ClientPortClosedDatagram );
        ClientPortClosedDatagram.PortMsg.u1.s1.DataLength = sizeof( ClientPortClosedDatagram.CreateTime );

        ClientPortClosedDatagram.PortMsg.u2.s2.Type = LPC_PORT_CLOSED;
        ClientPortClosedDatagram.PortMsg.u2.s2.DataInfoOffset = 0;

        ClientPortClosedDatagram.CreateTime = PsGetCurrentProcess()->CreateTime;

        LpcRequestPort( Port, (PPORT_MESSAGE)&ClientPortClosedDatagram );
    }

    //
    //  If connected, disconnect the port, and then scan the message queue
    //  for this port and dereference any messages in the queue.
    //

    LpcpDestroyPortQueue( Port, TRUE );

    //
    //  If the client has a port memory view, then unmap it
    //

    if (Port->ClientSectionBase != NULL) {

        MmUnmapViewOfSection( PsGetCurrentProcess(),
                              Port->ClientSectionBase );

    }

    //
    //  If the server has a port memory view, then unmap it
    //

    if (Port->ServerSectionBase != NULL) {

        MmUnmapViewOfSection( PsGetCurrentProcess(),
                              Port->ServerSectionBase  );

    }

    //
    //  Free DataInfo messages that the current process left on the
    //  connection port, then dereference the connection port if it is
    //  not this port.
    //

    ConnectionPort = Port->ConnectionPort;

    if (ConnectionPort != NULL) {

        CurrentProcessId = PsGetCurrentThread()->Cid.UniqueProcess;

        LpcpAcquireLpcpLock();

        Head = &ConnectionPort->LpcDataInfoChainHead;
        Next = Head->Flink;

        while (Next != Head) {

            Msg = CONTAINING_RECORD( Next, LPCP_MESSAGE, Entry );

            //
            //  Advance before the current entry can be unlinked and freed.
            //

            Next = Next->Flink;

            if (Msg->Request.ClientId.UniqueProcess == CurrentProcessId) {

                LpcpTrace(( "%s Freeing DataInfo Message %p (%u.%u)  Port: %p\n",
                            PsGetCurrentProcess()->ImageFileName,
                            Msg,
                            Msg->Request.MessageId,
                            Msg->Request.CallbackId,
                            ConnectionPort ));

                //
                //  Unlink the message and leave its entry self-linked so it
                //  no longer carries stale Flink/Blink pointers to its former
                //  neighbours.  NOTE(review): LpcpFreeToPortZone may inspect
                //  or unlink Msg->Entry itself; on a self-linked entry a
                //  second RemoveEntryList is a no-op instead of a stale
                //  unlink that list-integrity checks would reject.
                //

                RemoveEntryList( &Msg->Entry );

                InitializeListHead( &Msg->Entry );

                LpcpFreeToPortZone( Msg, TRUE );
            }
        }

        LpcpReleaseLpcpLock();

        if (ConnectionPort != Port) {

            ObDereferenceObject( ConnectionPort );
        }
    }

    //
    //  Release the server process reference held by a server connection
    //  port.  NOTE(review): assumes a server connection port's
    //  ConnectionPort is the port itself, so it was not dereferenced
    //  above; the NULL check guards a port whose ConnectionPort was never
    //  set.
    //

    if (((Port->Flags & PORT_TYPE) == SERVER_CONNECTION_PORT) &&
        (ConnectionPort != NULL) &&
        (ConnectionPort->ServerProcess != NULL)) {

        ObDereferenceObject( ConnectionPort->ServerProcess );

        ConnectionPort->ServerProcess = NULL;
    }

    //
    //  Free any static client security context
    //

    LpcpFreePortClientSecurity( Port );

    //
    //  And return to our caller
    //

    return;
}