Пример #1
0
/*
 * Captures a client security context for a named pipe request.
 *
 * A context is only created when the request does not come from the server
 * end of the pipe AND the CCB's client QoS uses dynamic context tracking.
 * In every other case no context is created and *Context receives NULL.
 *
 * NamedPipeEnd - End of the pipe the request was issued on.
 * Ccb          - CCB whose ClientQos is handed to SeCreateClientSecurity.
 * Thread       - Thread handed to SeCreateClientSecurity.
 * Context      - Receives the new context, or NULL. It is NOT written when
 *                the pool allocation fails.
 *
 * Returns STATUS_SUCCESS, STATUS_INSUFFICIENT_RESOURCES when the allocation
 * fails, or the failure status returned by SeCreateClientSecurity.
 */
NTSTATUS
NTAPI
NpGetClientSecurityContext(IN ULONG NamedPipeEnd,
                           IN PNP_CCB Ccb,
                           IN PETHREAD Thread,
                           IN PSECURITY_CLIENT_CONTEXT *Context)
{
    PSECURITY_CLIENT_CONTEXT Captured = NULL;
    NTSTATUS Status = STATUS_SUCCESS;
    PAGED_CODE();

    /* Only a non-server end with dynamic tracking needs a fresh capture */
    if ((NamedPipeEnd != FILE_PIPE_SERVER_END) &&
        (Ccb->ClientQos.ContextTrackingMode == SECURITY_DYNAMIC_TRACKING))
    {
        Captured = ExAllocatePoolWithQuotaTag(PagedPool | POOL_QUOTA_FAIL_INSTEAD_OF_RAISE,
                                              sizeof(*Captured),
                                              NPFS_CLIENT_SEC_CTX_TAG);
        if (Captured == NULL)
        {
            /* Bail out right away; the caller's pointer is left untouched */
            return STATUS_INSUFFICIENT_RESOURCES;
        }

        Status = SeCreateClientSecurity(Thread, &Ccb->ClientQos, 0, Captured);
        if (!NT_SUCCESS(Status))
        {
            /* Capture failed, do not hand out a half-built context */
            ExFreePool(Captured);
            Captured = NULL;
        }
    }

    *Context = Captured;
    return Status;
}
Пример #2
0
/*
 * Initializes the client security state stored in a CCB.
 *
 * The client QoS is copied from SecurityQos, or filled with defaults
 * (impersonation level, dynamic tracking, effective-only) when SecurityQos
 * is NULL. NpUninitializeSecurity is then called on the CCB (presumably to
 * drop a previously stored context - confirm against its definition).
 *
 * With dynamic tracking no context is stored and Ccb->ClientContext is set
 * to NULL. Otherwise a context is allocated, stored in the CCB and filled
 * in by SeCreateClientSecurity.
 *
 * Ccb         - CCB to initialize.
 * SecurityQos - Optional QoS to copy into the CCB.
 * Thread      - Thread handed to SeCreateClientSecurity.
 *
 * Returns STATUS_SUCCESS, STATUS_INSUFFICIENT_RESOURCES when the allocation
 * fails, or the failure status returned by SeCreateClientSecurity (in both
 * failure cases Ccb->ClientContext ends up NULL).
 */
NTSTATUS
NTAPI
NpInitializeSecurity(IN PNP_CCB Ccb,
                     IN PSECURITY_QUALITY_OF_SERVICE SecurityQos,
                     IN PETHREAD Thread)
{
    PSECURITY_CLIENT_CONTEXT StaticContext;
    NTSTATUS Status;
    PAGED_CODE();

    if (SecurityQos == NULL)
    {
        /* Nothing supplied by the caller, fall back to the defaults */
        Ccb->ClientQos.Length = sizeof(Ccb->ClientQos);
        Ccb->ClientQos.ImpersonationLevel = SecurityImpersonation;
        Ccb->ClientQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
        Ccb->ClientQos.EffectiveOnly = TRUE;
    }
    else
    {
        Ccb->ClientQos = *SecurityQos;
    }

    NpUninitializeSecurity(Ccb);

    /* Dynamic tracking keeps no stored context */
    if (Ccb->ClientQos.ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
    {
        Ccb->ClientContext = NULL;
        return STATUS_SUCCESS;
    }

    StaticContext = ExAllocatePoolWithQuotaTag(PagedPool | POOL_QUOTA_FAIL_INSTEAD_OF_RAISE,
                                               sizeof(*StaticContext),
                                               NPFS_CLIENT_SEC_CTX_TAG);

    /* Stored before the check, so a failed allocation leaves NULL in the CCB */
    Ccb->ClientContext = StaticContext;
    if (StaticContext == NULL)
    {
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    Status = SeCreateClientSecurity(Thread, &Ccb->ClientQos, 0, StaticContext);
    if (NT_SUCCESS(Status))
    {
        return Status;
    }

    /* Capture failed: release the buffer and clear the CCB's pointer */
    ExFreePool(StaticContext);
    Ccb->ClientContext = NULL;
    return Status;
}
Пример #3
0
NTSTATUS
NtImpersonateThread(
    __in HANDLE ServerThreadHandle,
    __in HANDLE ClientThreadHandle,
    __in PSECURITY_QUALITY_OF_SERVICE SecurityQos
    )

/*++

Routine Description:

    Makes the server thread impersonate the client thread, using the
    quality of service settings supplied by the caller.

Arguments:

    ServerThreadHandle - Handle to the server thread (the thread that will
        be doing the impersonation).  It is referenced with
        THREAD_IMPERSONATE access.

    ClientThreadHandle - Handle to the client thread (the thread being
        impersonated).  It is referenced with THREAD_DIRECT_IMPERSONATION
        access.

    SecurityQos - Pointer to the security quality of service information
        describing the form of impersonation.  It is probed (for callers
        that are not kernel mode) and copied before use.

Return Value:

    STATUS_SUCCESS - The call completed successfully.

    Otherwise the exception code raised while capturing SecurityQos, or
    the failure status of either handle lookup, of SeCreateClientSecurity
    or of SeImpersonateClientEx.

--*/

{
    KPROCESSOR_MODE Mode;
    NTSTATUS Status;
    PETHREAD Impersonator;
    PETHREAD Impersonatee;
    SECURITY_QUALITY_OF_SERVICE Qos;
    SECURITY_CLIENT_CONTEXT Context;

    //
    // Capture the QoS structure, probing it first when the caller is not
    // running in kernel mode.
    //

    Mode = KeGetPreviousMode();

    try {

        if (Mode != KernelMode) {
            ProbeForReadSmallStructure (SecurityQos,
                                        sizeof (SECURITY_QUALITY_OF_SERVICE),
                                        sizeof (ULONG));
        }
        Qos = *SecurityQos;

    } except (ExSystemExceptionFilter ()) {
        return GetExceptionCode ();
    }

    //
    // Reference the client thread (the one being impersonated).
    //

    Status = ObReferenceObjectByHandle (ClientThreadHandle,
                                        THREAD_DIRECT_IMPERSONATION,
                                        PsThreadType,
                                        Mode,
                                        &Impersonatee,
                                        NULL);

    if (!NT_SUCCESS (Status)) {
        return Status;
    }

    //
    // Reference the server thread (the one doing the impersonation).
    //

    Status = ObReferenceObjectByHandle (ServerThreadHandle,
                                        THREAD_IMPERSONATE,
                                        PsThreadType,
                                        Mode,
                                        &Impersonator,
                                        NULL);

    if (!NT_SUCCESS (Status)) {
        goto ReleaseClient;
    }

    //
    // Build the client's security context and, if that worked, impersonate
    // with it.  The context is only needed for the duration of the call.
    //

    Status = SeCreateClientSecurity (Impersonatee,
                                     &Qos,
                                     FALSE,
                                     &Context);

    if (NT_SUCCESS (Status)) {
        Status = SeImpersonateClientEx (&Context, Impersonator);
        SeDeleteClientSecurity (&Context);
    }

    //
    // Drop the thread references: server first, then client.
    //

    ObDereferenceObject (Impersonator);

ReleaseClient:

    ObDereferenceObject (Impersonatee);

    return Status;
}
Пример #4
0
/*
 * @implemented
 *
 * Connects the calling thread to a named LPC connection port: a client port
 * object is created, a connection request message is queued on the server's
 * connection port, and the call then waits for the reply.
 *
 * PortHandle         - Receives the handle of the new client port when the
 *                      connection was accepted and the port was inserted.
 * PortName           - Name of the connection port to reference.
 * Qos                - Copied into the client port. Unless it requests
 *                      dynamic tracking, a static client security context
 *                      is created from the current thread.
 * ClientView         - Optional. Length must be sizeof(PORT_VIEW). Its
 *                      section is mapped into the current process and the
 *                      view is copied back from the reply on success.
 * ServerSid          - Optional. When given, the port must have a server
 *                      process, otherwise STATUS_SERVER_SID_MISMATCH.
 *                      NOTE: the token query is still a FIXME, so the SID
 *                      is effectively compared against itself.
 * ServerView         - Optional. Length must be sizeof(REMOTE_PORT_VIEW).
 *                      Receives the server view from the reply on success.
 * MaxMessageLength   - Optional. Receives the port's MaxMessageLength.
 * ConnectionInformation / ConnectionInformationLength
 *                    - Optional data sent with the request and overwritten
 *                      with the data of the reply. The length is clamped to
 *                      the port's MaxConnectionInfoLength.
 *
 * Returns the status of the first failing step, STATUS_INVALID_PARAMETER
 * for bad view lengths, STATUS_INVALID_PORT_HANDLE when the named object is
 * not a connection port, STATUS_NO_MEMORY / STATUS_INSUFFICIENT_RESOURCES
 * on allocation failures, STATUS_OBJECT_NAME_NOT_FOUND when the port name
 * was deleted, or STATUS_PORT_CONNECTION_REFUSED when no connection was
 * established.
 *
 * NOTE(review): none of the caller-supplied pointers are probed or captured
 * in this routine - confirm whether that is acceptable for user-mode callers.
 */
NTSTATUS
NTAPI
NtSecureConnectPort(OUT PHANDLE PortHandle,
                    IN PUNICODE_STRING PortName,
                    IN PSECURITY_QUALITY_OF_SERVICE Qos,
                    IN OUT PPORT_VIEW ClientView OPTIONAL,
                    IN PSID ServerSid OPTIONAL,
                    IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL,
                    OUT PULONG MaxMessageLength OPTIONAL,
                    IN OUT PVOID ConnectionInformation OPTIONAL,
                    IN OUT PULONG ConnectionInformationLength OPTIONAL)
{
    ULONG ConnectionInfoLength = 0;
    PLPCP_PORT_OBJECT Port, ClientPort;
    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
    NTSTATUS Status = STATUS_SUCCESS;
    HANDLE Handle;
    PVOID SectionToMap;
    PLPCP_MESSAGE Message;
    PLPCP_CONNECTION_MESSAGE ConnectMessage;
    PETHREAD Thread = PsGetCurrentThread();
    ULONG PortMessageLength;
    LARGE_INTEGER SectionOffset;
    PTOKEN Token;
    PTOKEN_USER TokenUserInfo;
    PAGED_CODE();
    LPCTRACE(LPC_CONNECT_DEBUG,
             "Name: %wZ. Qos: %p. Views: %p/%p. Sid: %p\n",
             PortName,
             Qos,
             ClientView,
             ServerView,
             ServerSid);

    /* Validate client view */
    if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Validate server view */
    if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
    {
        /* Fail */
        return STATUS_INVALID_PARAMETER;
    }

    /* Check if caller sent connection information length */
    if (ConnectionInformationLength)
    {
        /* Retrieve the input length */
        ConnectionInfoLength = *ConnectionInformationLength;
    }

    /* Get the port */
    Status = ObReferenceObjectByName(PortName,
                                     0,
                                     NULL,
                                     PORT_CONNECT,
                                     LpcPortObjectType,
                                     PreviousMode,
                                     NULL,
                                     (PVOID *)&Port);
    if (!NT_SUCCESS(Status))
    {
        DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status);
        return Status;
    }

    /* This has to be a connection port */
    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
    {
        /* It isn't, so fail */
        ObDereferenceObject(Port);
        return STATUS_INVALID_PORT_HANDLE;
    }

    /* Check if we have a SID */
    if (ServerSid)
    {
        /* Make sure that we have a server */
        if (Port->ServerProcess)
        {
            /* Get its token and query user information */
            Token = PsReferencePrimaryToken(Port->ServerProcess);
            //Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
            // FIXME: Need SeQueryInformationToken
            TokenUserInfo = ExAllocatePoolWithTag(PagedPool, sizeof(TOKEN_USER), TAG_SE);
            if (TokenUserInfo)
            {
                /* Placeholder until the token can be queried for real */
                Status = STATUS_SUCCESS;
                TokenUserInfo->User.Sid = ServerSid;
            }
            else
            {
                /* Out of pool: fail instead of writing through NULL */
                Status = STATUS_INSUFFICIENT_RESOURCES;
            }
            PsDereferencePrimaryToken(Token);

            /* Check for success */
            if (NT_SUCCESS(Status))
            {
                /* Compare the SIDs */
                if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid))
                {
                    /* Fail */
                    Status = STATUS_SERVER_SID_MISMATCH;
                }

                /* Free token information */
                ExFreePoolWithTag(TokenUserInfo, TAG_SE);
            }
        }
        else
        {
            /* Invalid SID */
            Status = STATUS_SERVER_SID_MISMATCH;
        }

        /* Check if SID failed */
        if (!NT_SUCCESS(Status))
        {
            /* Quit */
            ObDereferenceObject(Port);
            return Status;
        }
    }

    /* Create the client port */
    Status = ObCreateObject(PreviousMode,
                            LpcPortObjectType,
                            NULL,
                            PreviousMode,
                            NULL,
                            sizeof(LPCP_PORT_OBJECT),
                            0,
                            0,
                            (PVOID *)&ClientPort);
    if (!NT_SUCCESS(Status))
    {
        /* Failed, dereference the server port and return */
        ObDereferenceObject(Port);
        return Status;
    }

    /*
     * Setup the client port. From here on the client port holds the
     * connection port pointer, and the failure paths below release the
     * client port only (presumably its deletion drops the connection port
     * reference - confirm against the port delete routine).
     */
    RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
    ClientPort->Flags = LPCP_CLIENT_PORT;
    ClientPort->ConnectionPort = Port;
    ClientPort->MaxMessageLength = Port->MaxMessageLength;
    ClientPort->SecurityQos = *Qos;
    InitializeListHead(&ClientPort->LpcReplyChainHead);
    InitializeListHead(&ClientPort->LpcDataInfoChainHead);

    /* Check if we have dynamic security */
    if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
    {
        /* Remember that */
        ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
    }
    else
    {
        /* Create our own client security */
        Status = SeCreateClientSecurity(Thread,
                                        Qos,
                                        FALSE,
                                        &ClientPort->StaticSecurity);
        if (!NT_SUCCESS(Status))
        {
            /* Security failed, dereference and return */
            ObDereferenceObject(ClientPort);
            return Status;
        }
    }

    /* Initialize the port queue */
    Status = LpcpInitializePortQueue(ClientPort);
    if (!NT_SUCCESS(Status))
    {
        /* Failed */
        ObDereferenceObject(ClientPort);
        return Status;
    }

    /* Check if we have a client view */
    if (ClientView)
    {
        /* Get the section handle */
        Status = ObReferenceObjectByHandle(ClientView->SectionHandle,
                                           SECTION_MAP_READ |
                                           SECTION_MAP_WRITE,
                                           MmSectionObjectType,
                                           PreviousMode,
                                           (PVOID*)&SectionToMap,
                                           NULL);
        if (!NT_SUCCESS(Status))
        {
            /*
             * Fail. Release the client port, like every other failure path
             * after its creation, so that it is not leaked.
             */
            ObDereferenceObject(ClientPort);
            return Status;
        }

        /* Set the section offset */
        SectionOffset.QuadPart = ClientView->SectionOffset;

        /* Map it */
        Status = MmMapViewOfSection(SectionToMap,
                                    PsGetCurrentProcess(),
                                    &ClientPort->ClientSectionBase,
                                    0,
                                    0,
                                    &SectionOffset,
                                    &ClientView->ViewSize,
                                    ViewUnmap,
                                    0,
                                    PAGE_READWRITE);

        /* Update the offset */
        ClientView->SectionOffset = SectionOffset.LowPart;

        /* Check for failure */
        if (!NT_SUCCESS(Status))
        {
            /* Fail: drop the section and the client port (not leaked) */
            ObDereferenceObject(SectionToMap);
            ObDereferenceObject(ClientPort);
            return Status;
        }

        /* Update the base */
        ClientView->ViewBase = ClientPort->ClientSectionBase;

        /* Reference and remember the process */
        ClientPort->MappingProcess = PsGetCurrentProcess();
        ObReferenceObject(ClientPort->MappingProcess);
    }
    else
    {
        /* No section */
        SectionToMap = NULL;
    }

    /* Normalize connection information */
    if (ConnectionInfoLength > Port->MaxConnectionInfoLength)
    {
        /* Use the port's maximum allowed value */
        ConnectionInfoLength = Port->MaxConnectionInfoLength;
    }

    /* Allocate a message from the port zone */
    Message = LpcpAllocateFromPortZone();
    if (!Message)
    {
        /* Fail if we couldn't allocate a message */
        if (SectionToMap) ObDereferenceObject(SectionToMap);
        ObDereferenceObject(ClientPort);
        return STATUS_NO_MEMORY;
    }

    /* Set pointer to the connection message and fill in the CID */
    ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
    Message->Request.ClientId = Thread->Cid;

    /* Check if we have a client view */
    if (ClientView)
    {
        /* Set the view size */
        Message->Request.ClientViewSize = ClientView->ViewSize;

        /* Copy the client view and clear the server view */
        RtlCopyMemory(&ConnectMessage->ClientView,
                      ClientView,
                      sizeof(PORT_VIEW));
        RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
    }
    else
    {
        /* Set the size to 0 and clear the connect message */
        Message->Request.ClientViewSize = 0;
        RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE));
    }

    /* Set the section and client port. Port is NULL for now */
    ConnectMessage->ClientPort = NULL;
    ConnectMessage->SectionToMap = SectionToMap;

    /* Set the data for the connection request message */
    Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
                                         sizeof(LPCP_CONNECTION_MESSAGE);
    Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
                                         Message->Request.u1.s1.DataLength;
    Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST;

    /* Check if we have connection information */
    if (ConnectionInformation)
    {
        /* Copy it in, right behind the connection message */
        RtlCopyMemory(ConnectMessage + 1,
                      ConnectionInformation,
                      ConnectionInfoLength);
    }

    /* Acquire the port lock */
    KeAcquireGuardedMutex(&LpcpLock);

    /* Check if someone already deleted the port name */
    if (Port->Flags & LPCP_NAME_DELETED)
    {
        /* Fail the request */
        Status = STATUS_OBJECT_NAME_NOT_FOUND;
    }
    else
    {
        /* Associate no thread yet */
        Message->RepliedToThread = NULL;

        /* Generate the Message ID and set it (0 is skipped on wrap-around) */
        Message->Request.MessageId =  LpcpNextMessageId++;
        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
        Thread->LpcReplyMessageId = Message->Request.MessageId;

        /* Insert the message into the queue and thread chain */
        InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
        InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
        Thread->LpcReplyMessage = Message;

        /* Now we can finally reference the client port and link it*/
        ObReferenceObject(ClientPort);
        ConnectMessage->ClientPort = ClientPort;

        /* Enter a critical region */
        KeEnterCriticalRegion();
    }

    /*
     * Add another reference to the port; it is released right before this
     * routine returns, on both the normal and the Cleanup path.
     */
    ObReferenceObject(Port);

    /* Release the lock */
    KeReleaseGuardedMutex(&LpcpLock);

    /* Check for success */
    if (NT_SUCCESS(Status))
    {
        LPCTRACE(LPC_CONNECT_DEBUG,
                 "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
                 Message,
                 ConnectMessage,
                 Port,
                 ClientPort,
                 Status);

        /* If this is a waitable port, set the event */
        if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
                                                         1,
                                                         FALSE);

        /* Release the queue semaphore and leave the critical region */
        LpcpCompleteWait(Port->MsgQueue.Semaphore);
        KeLeaveCriticalRegion();

        /* Now wait for a reply */
        LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
    }

    /* Check for failure */
    if (!NT_SUCCESS(Status)) goto Cleanup;

    /* Free the connection message */
    SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);

    /* Check if we got a message back */
    if (Message)
    {
        /* Check for new return length */
        if ((Message->Request.u1.s1.DataLength -
             sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength)
        {
            /* Set new normalized connection length */
            ConnectionInfoLength = Message->Request.u1.s1.DataLength -
                                   sizeof(LPCP_CONNECTION_MESSAGE);
        }

        /* Check if we had connection information */
        if (ConnectionInformation)
        {
            /* Check if we had a length pointer */
            if (ConnectionInformationLength)
            {
                /* Return the length */
                *ConnectionInformationLength = ConnectionInfoLength;
            }

            /* Return the connection information */
            RtlCopyMemory(ConnectionInformation,
                          ConnectMessage + 1,
                          ConnectionInfoLength );
        }

        /* Make sure we had a connected port */
        if (ClientPort->ConnectedPort)
        {
            /* Get the message length before the port might get killed */
            PortMessageLength = Port->MaxMessageLength;

            /* Insert the client port */
            Status = ObInsertObject(ClientPort,
                                    NULL,
                                    PORT_ALL_ACCESS,
                                    0,
                                    (PVOID *)NULL,
                                    &Handle);
            if (NT_SUCCESS(Status))
            {
                /* Return the handle */
                *PortHandle = Handle;
                LPCTRACE(LPC_CONNECT_DEBUG,
                         "Handle: %p. Length: %lx\n",
                         Handle,
                         PortMessageLength);

                /* Check if maximum length was requested */
                if (MaxMessageLength) *MaxMessageLength = PortMessageLength;

                /* Check if we had a client view */
                if (ClientView)
                {
                    /* Copy it back */
                    RtlCopyMemory(ClientView,
                                  &ConnectMessage->ClientView,
                                  sizeof(PORT_VIEW));
                }

                /* Check if we had a server view */
                if (ServerView)
                {
                    /* Copy it back */
                    RtlCopyMemory(ServerView,
                                  &ConnectMessage->ServerView,
                                  sizeof(REMOTE_PORT_VIEW));
                }
            }
        }
        else
        {
            /* No connection port, we failed */
            if (SectionToMap) ObDereferenceObject(SectionToMap);

            /* Acquire the lock */
            KeAcquireGuardedMutex(&LpcpLock);

            /* Check if it's because the name got deleted */
            if (!(ClientPort->ConnectionPort) ||
                (Port->Flags & LPCP_NAME_DELETED))
            {
                /* Set the correct status */
                Status = STATUS_OBJECT_NAME_NOT_FOUND;
            }
            else
            {
                /* Otherwise, the caller refused us */
                Status = STATUS_PORT_CONNECTION_REFUSED;
            }

            /* Release the lock */
            KeReleaseGuardedMutex(&LpcpLock);

            /* Kill the port */
            ObDereferenceObject(ClientPort);
        }

        /* Free the message */
        LpcpFreeToPortZone(Message, 0);
    }
    else
    {
        /* No reply message, fail */
        if (SectionToMap) ObDereferenceObject(SectionToMap);
        ObDereferenceObject(ClientPort);
        Status = STATUS_PORT_CONNECTION_REFUSED;
    }

    /* Return status */
    ObDereferenceObject(Port);
    return Status;

Cleanup:
    /* We failed, free the message */
    SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);

    /* Check if the semaphore got signaled */
    if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
    {
        /* Wait on it */
        KeWaitForSingleObject(&Thread->LpcReplySemaphore,
                              WrExecutive,
                              KernelMode,
                              FALSE,
                              NULL);
    }

    /* Check if we had a message and free it */
    if (Message) LpcpFreeToPortZone(Message, 0);

    /* Dereference other objects */
    if (SectionToMap) ObDereferenceObject(SectionToMap);
    ObDereferenceObject(ClientPort);

    /* Return status */
    ObDereferenceObject(Port);
    return Status;
}
Пример #5
0
//
//	Check IOCTL_VFD_OPEN_IMAGE input parameters
//
//	Rejects the request when an image is already open, when the input
//	buffer is too short for the header plus the file name, when the media
//	type is out of range, or when a file-type disk has no file name.
//	On success a client security context for the calling thread is stored
//	in DeviceExtension->SecurityContext.
//
//	DeviceExtension	- device state; FileHandle / FileBuffer are read and
//					  SecurityContext is (re)created
//	ImageInfo		- image description supplied with the IOCTL
//	InputLength		- size in bytes of the buffer holding ImageInfo
//
//	Returns STATUS_SUCCESS, STATUS_DEVICE_BUSY, STATUS_INVALID_PARAMETER,
//	STATUS_INSUFFICIENT_RESOURCES if the context cannot be allocated, or
//	the failure status of SeCreateClientSecurity.
//
NTSTATUS
VfdOpenCheck(
	PDEVICE_EXTENSION			DeviceExtension,
	PVFD_IMAGE_INFO				ImageInfo,
	ULONG						InputLength)
{
	//	Check media status

	if (DeviceExtension->FileHandle ||
		DeviceExtension->FileBuffer) {

		VFDTRACE(VFDWARN, ("[VFD] image already opened.\n"));

		return STATUS_DEVICE_BUSY;
	}

	//	Check input parameter length
	//	(the header must fit before NameLength can be trusted)

	if (InputLength < sizeof(VFD_IMAGE_INFO) ||
		InputLength < sizeof(VFD_IMAGE_INFO) + ImageInfo->NameLength)
	{
		return STATUS_INVALID_PARAMETER;
	}

	//	Check input parameters

	if (ImageInfo->MediaType == VFD_MEDIA_NONE ||
		ImageInfo->MediaType >= VFD_MEDIA_MAX) {

		VFDTRACE(VFDWARN, ("[VFD] invalid MediaType - %u.\n",
			ImageInfo->MediaType));

		return STATUS_INVALID_PARAMETER;
	}

	if (ImageInfo->DiskType == VFD_DISKTYPE_FILE &&
		ImageInfo->NameLength == 0) {

		VFDTRACE(VFDWARN,
			("[VFD] File name required for VFD_DISKTYPE_FILE.\n"));

		return STATUS_INVALID_PARAMETER;
	}

	//	create a security context to match the calling process' context
	//	the driver thread uses this context to impersonate the client
	//	to open the specified image file
	//	(done for every disk type; the VFD_DISKTYPE_FILE-only condition
	//	below is disabled)

//	if (ImageInfo->DiskType == VFD_DISKTYPE_FILE)
	{
		SECURITY_QUALITY_OF_SERVICE sqos;
		NTSTATUS status;

		if (DeviceExtension->SecurityContext != NULL) {
			//	reuse the existing buffer, releasing what it captured
			SeDeleteClientSecurity(DeviceExtension->SecurityContext);
		}
		else {
			DeviceExtension->SecurityContext =
				(PSECURITY_CLIENT_CONTEXT)ExAllocatePoolWithTag(
				NonPagedPool, sizeof(SECURITY_CLIENT_CONTEXT), VFD_POOL_TAG);

			if (DeviceExtension->SecurityContext == NULL) {

				VFDTRACE(VFDWARN,
					("[VFD] Failed to allocate the security context.\n"));

				return STATUS_INSUFFICIENT_RESOURCES;
			}
		}

		RtlZeroMemory(&sqos, sizeof(SECURITY_QUALITY_OF_SERVICE));

		sqos.Length = sizeof(SECURITY_QUALITY_OF_SERVICE);
		sqos.ImpersonationLevel		= SecurityImpersonation;
		sqos.ContextTrackingMode	= SECURITY_STATIC_TRACKING;
		sqos.EffectiveOnly			= FALSE;

		status = SeCreateClientSecurity(
			PsGetCurrentThread(), &sqos, FALSE,
			DeviceExtension->SecurityContext);

		if (!NT_SUCCESS(status)) {

			//	the buffer holds no valid context now; drop it so that a
			//	later call does not try to delete or impersonate with it

			VFDTRACE(VFDWARN,
				("[VFD] Failed to create the security context.\n"));

			ExFreePoolWithTag(
				DeviceExtension->SecurityContext, VFD_POOL_TAG);

			DeviceExtension->SecurityContext = NULL;

			return status;
		}
	}

	return STATUS_SUCCESS;
}
Пример #6
0
//
// FileDiskDeviceControl
//
// Handler for device I/O control requests sent to a FileDisk device.
//
// DeviceObject - device the request targets; its DeviceExtension holds the
//                per-device state (media flag, file name/size, request list).
// Irp          - the request; buffers are accessed through
//                Irp->AssociatedIrp.SystemBuffer.
//
// Returns STATUS_PENDING for IOCTL_FILE_DISK_OPEN_FILE / CLOSE_FILE, which are
// placed on the device's request list and signalled through request_event
// (presumably consumed by the driver's worker thread - confirm against the
// thread routine). Every other request is completed inline and its final
// status is returned.
//
NTSTATUS
FileDiskDeviceControl (
    IN PDEVICE_OBJECT   DeviceObject,
    IN PIRP             Irp
    )
{
    PDEVICE_EXTENSION   device_extension;
    PIO_STACK_LOCATION  io_stack;
    NTSTATUS            status;

    // Get the device extension.
    device_extension = (PDEVICE_EXTENSION) DeviceObject->DeviceExtension;

    // Get the current I/O stack location.
    io_stack = IoGetCurrentIrpStackLocation(Irp);

    // Fail every request while no media is loaded, except for
    // IOCTL_FILE_DISK_OPEN_FILE, the custom control code that loads the media.
    if (!device_extension->media_in_device &&
        io_stack->Parameters.DeviceIoControl.IoControlCode !=
        IOCTL_FILE_DISK_OPEN_FILE)
    {
        Irp->IoStatus.Status = STATUS_NO_MEDIA_IN_DEVICE;
        Irp->IoStatus.Information = 0;

        IoCompleteRequest(Irp, IO_NO_INCREMENT);

        return STATUS_NO_MEDIA_IN_DEVICE;
    }

    // Dispatch on the control code.
    switch (io_stack->Parameters.DeviceIoControl.IoControlCode)
    {
    case IOCTL_FILE_DISK_OPEN_FILE:
        {
            SECURITY_QUALITY_OF_SERVICE security_quality_of_service;

            if (device_extension->media_in_device)
            {
                KdPrint(("FileDisk: IOCTL_FILE_DISK_OPEN_FILE: Media already opened\n"));

                status = STATUS_INVALID_DEVICE_REQUEST;
                Irp->IoStatus.Information = 0;
                break;
            }

            // The fixed part of OPEN_FILE_INFORMATION must be present before
            // FileNameLength can be read from it.
            if (io_stack->Parameters.DeviceIoControl.InputBufferLength <
                sizeof(OPEN_FILE_INFORMATION))
            {
                status = STATUS_INVALID_PARAMETER;
                Irp->IoStatus.Information = 0;
                break;
            }

            // The buffer must also cover the variable-length file name.
            if (io_stack->Parameters.DeviceIoControl.InputBufferLength <
                sizeof(OPEN_FILE_INFORMATION) +
                ((POPEN_FILE_INFORMATION)Irp->AssociatedIrp.SystemBuffer)->FileNameLength -
                sizeof(UCHAR))
            {
                status = STATUS_INVALID_PARAMETER;
                Irp->IoStatus.Information = 0;
                break;
            }

            // Reuse the existing context storage if there is one (dropping
            // the context it holds), otherwise allocate it now.
            if (device_extension->security_client_context != NULL)
            {
                SeDeleteClientSecurity(device_extension->security_client_context);
            }
            else
            {
                device_extension->security_client_context =
                    ExAllocatePool(NonPagedPool, sizeof(SECURITY_CLIENT_CONTEXT));

                // Without storage there is nothing to capture the caller's
                // security context into; fail instead of passing NULL on.
                if (device_extension->security_client_context == NULL)
                {
                    status = STATUS_INSUFFICIENT_RESOURCES;
                    Irp->IoStatus.Information = 0;
                    break;
                }
            }

            RtlZeroMemory(&security_quality_of_service, sizeof(SECURITY_QUALITY_OF_SERVICE));

            security_quality_of_service.Length = sizeof(SECURITY_QUALITY_OF_SERVICE);
            security_quality_of_service.ImpersonationLevel = SecurityImpersonation;
            security_quality_of_service.ContextTrackingMode = SECURITY_STATIC_TRACKING;
            security_quality_of_service.EffectiveOnly = FALSE;

            // Capture the calling thread's security context.
            status = SeCreateClientSecurity(
                PsGetCurrentThread(),
                &security_quality_of_service,
                FALSE,
                device_extension->security_client_context
                );

            if (!NT_SUCCESS(status))
            {
                // The storage does not hold a valid context. Release it and
                // clear the pointer so that a later open does not call
                // SeDeleteClientSecurity on it, and do not queue the request.
                ExFreePool(device_extension->security_client_context);
                device_extension->security_client_context = NULL;

                Irp->IoStatus.Information = 0;
                break;
            }

            // Queue the request and signal the request event.
            IoMarkIrpPending(Irp);

            ExInterlockedInsertTailList(
                &device_extension->list_head,
                &Irp->Tail.Overlay.ListEntry,
                &device_extension->list_lock
                );

            KeSetEvent(
                &device_extension->request_event,
                (KPRIORITY) 0,
                FALSE
                );

            status = STATUS_PENDING;

            break;
        }

    case IOCTL_FILE_DISK_CLOSE_FILE:
        {
            // Queue the request and signal the request event.
            IoMarkIrpPending(Irp);

            ExInterlockedInsertTailList(
                &device_extension->list_head,
                &Irp->Tail.Overlay.ListEntry,
                &device_extension->list_lock
                );

            KeSetEvent(
                &device_extension->request_event,
                (KPRIORITY) 0,
                FALSE
                );

            status = STATUS_PENDING;

            break;
        }

    case IOCTL_FILE_DISK_QUERY_FILE:
        {
            POPEN_FILE_INFORMATION open_file_information;

            // The output must hold the fixed structure plus the file name.
            if (io_stack->Parameters.DeviceIoControl.OutputBufferLength <
                sizeof(OPEN_FILE_INFORMATION) + device_extension->file_name.Length - sizeof(UCHAR))
            {
                status = STATUS_BUFFER_TOO_SMALL;
                Irp->IoStatus.Information = 0;
                break;
            }

            open_file_information = (POPEN_FILE_INFORMATION) Irp->AssociatedIrp.SystemBuffer;

            open_file_information->FileSize.QuadPart = device_extension->file_size.QuadPart;
            open_file_information->ReadOnly = device_extension->read_only;
            open_file_information->FileNameLength = device_extension->file_name.Length;

            RtlCopyMemory(
                open_file_information->FileName,
                device_extension->file_name.Buffer,
                device_extension->file_name.Length
                );

            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = sizeof(OPEN_FILE_INFORMATION) +
                open_file_information->FileNameLength - sizeof(UCHAR);

            break;
        }

    case IOCTL_DISK_CHECK_VERIFY:
    case IOCTL_CDROM_CHECK_VERIFY:
    case IOCTL_STORAGE_CHECK_VERIFY:
    case IOCTL_STORAGE_CHECK_VERIFY2:
        {
            // Media is present (checked on entry); report success.
            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = 0;
            break;
        }

    case IOCTL_DISK_GET_DRIVE_GEOMETRY:
    case IOCTL_CDROM_GET_DRIVE_GEOMETRY:
        {
            PDISK_GEOMETRY  disk_geometry;
            ULONGLONG       length;

            if (io_stack->Parameters.DeviceIoControl.OutputBufferLength <
                sizeof(DISK_GEOMETRY))
            {
                status = STATUS_BUFFER_TOO_SMALL;
                Irp->IoStatus.Information = 0;
                break;
            }

            disk_geometry = (PDISK_GEOMETRY) Irp->AssociatedIrp.SystemBuffer;

            length = device_extension->file_size.QuadPart;

            // Geometry derived from the file size: 2 tracks per cylinder and
            // 32 sectors per track.
            disk_geometry->Cylinders.QuadPart = length / SECTOR_SIZE / 32 / 2;
            disk_geometry->MediaType = FixedMedia;
            disk_geometry->TracksPerCylinder = 2;
            disk_geometry->SectorsPerTrack = 32;
            disk_geometry->BytesPerSector = SECTOR_SIZE;

            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = sizeof(DISK_GEOMETRY);

            break;
        }

    case IOCTL_DISK_GET_LENGTH_INFO:
        {
            PGET_LENGTH_INFORMATION get_length_information;

            if (io_stack->Parameters.DeviceIoControl.OutputBufferLength <
                sizeof(GET_LENGTH_INFORMATION))
            {
                status = STATUS_BUFFER_TOO_SMALL;
                Irp->IoStatus.Information = 0;
                break;
            }

            get_length_information = (PGET_LENGTH_INFORMATION) Irp->AssociatedIrp.SystemBuffer;

            get_length_information->Length.QuadPart = device_extension->file_size.QuadPart;

            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = sizeof(GET_LENGTH_INFORMATION);

            break;
        }

    case IOCTL_DISK_GET_PARTITION_INFO:
        {
            PPARTITION_INFORMATION  partition_information;
            ULONGLONG               length;

            if (io_stack->Parameters.DeviceIoControl.OutputBufferLength <
                sizeof(PARTITION_INFORMATION))
            {
                status = STATUS_BUFFER_TOO_SMALL;
                Irp->IoStatus.Information = 0;
                break;
            }

            partition_information = (PPARTITION_INFORMATION) Irp->AssociatedIrp.SystemBuffer;

            length = device_extension->file_size.QuadPart;

            // Clear the whole structure first so that bytes not assigned
            // below (padding) are not returned with stale buffer contents.
            RtlZeroMemory(partition_information, sizeof(PARTITION_INFORMATION));

            partition_information->StartingOffset.QuadPart = 0;
            partition_information->PartitionLength.QuadPart = length;
            partition_information->HiddenSectors = 1;
            partition_information->PartitionNumber = 0;
            partition_information->PartitionType = 0;
            partition_information->BootIndicator = FALSE;
            partition_information->RecognizedPartition = FALSE;
            partition_information->RewritePartition = FALSE;

            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = sizeof(PARTITION_INFORMATION);

            break;
        }

    case IOCTL_DISK_GET_PARTITION_INFO_EX:
        {
            PPARTITION_INFORMATION_EX   partition_information_ex;
            ULONGLONG                   length;

            if (io_stack->Parameters.DeviceIoControl.OutputBufferLength <
                sizeof(PARTITION_INFORMATION_EX))
            {
                status = STATUS_BUFFER_TOO_SMALL;
                Irp->IoStatus.Information = 0;
                break;
            }

            partition_information_ex = (PPARTITION_INFORMATION_EX) Irp->AssociatedIrp.SystemBuffer;

            length = device_extension->file_size.QuadPart;

            // Clear the whole structure first: only the Mbr member of the
            // union is filled in below, and the full structure size is
            // reported back to the caller.
            RtlZeroMemory(partition_information_ex, sizeof(PARTITION_INFORMATION_EX));

            partition_information_ex->PartitionStyle = PARTITION_STYLE_MBR;
            partition_information_ex->StartingOffset.QuadPart = 0;
            partition_information_ex->PartitionLength.QuadPart = length;
            partition_information_ex->PartitionNumber = 0;
            partition_information_ex->RewritePartition = FALSE;
            partition_information_ex->Mbr.PartitionType = 0;
            partition_information_ex->Mbr.BootIndicator = FALSE;
            partition_information_ex->Mbr.RecognizedPartition = FALSE;
            partition_information_ex->Mbr.HiddenSectors = 1;

            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = sizeof(PARTITION_INFORMATION_EX);

            break;
        }

    case IOCTL_DISK_IS_WRITABLE:
        {
            if (!device_extension->read_only)
            {
                status = STATUS_SUCCESS;
            }
            else
            {
                status = STATUS_MEDIA_WRITE_PROTECTED;
            }
            Irp->IoStatus.Information = 0;
            break;
        }

    case IOCTL_DISK_MEDIA_REMOVAL:
    case IOCTL_STORAGE_MEDIA_REMOVAL:
        {
            // Accepted without any action.
            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = 0;
            break;
        }

    case IOCTL_CDROM_READ_TOC:
        {
            PCDROM_TOC cdrom_toc;

            if (io_stack->Parameters.DeviceIoControl.OutputBufferLength <
                sizeof(CDROM_TOC))
            {
                status = STATUS_BUFFER_TOO_SMALL;
                Irp->IoStatus.Information = 0;
                break;
            }

            cdrom_toc = (PCDROM_TOC) Irp->AssociatedIrp.SystemBuffer;

            RtlZeroMemory(cdrom_toc, sizeof(CDROM_TOC));

            // Report a single data track.
            cdrom_toc->FirstTrack = 1;
            cdrom_toc->LastTrack = 1;
            cdrom_toc->TrackData[0].Control = TOC_DATA_TRACK;

            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = sizeof(CDROM_TOC);

            break;
        }

    case IOCTL_DISK_SET_PARTITION_INFO:
        {
            if (device_extension->read_only)
            {
                status = STATUS_MEDIA_WRITE_PROTECTED;
                Irp->IoStatus.Information = 0;
                break;
            }

            if (io_stack->Parameters.DeviceIoControl.InputBufferLength <
                sizeof(SET_PARTITION_INFORMATION))
            {
                status = STATUS_INVALID_PARAMETER;
                Irp->IoStatus.Information = 0;
                break;
            }

            // The input is validated for size only; nothing is stored.
            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = 0;

            break;
        }

    case IOCTL_DISK_VERIFY:
        {
            PVERIFY_INFORMATION verify_information;

            if (io_stack->Parameters.DeviceIoControl.InputBufferLength <
                sizeof(VERIFY_INFORMATION))
            {
                status = STATUS_INVALID_PARAMETER;
                Irp->IoStatus.Information = 0;
                break;
            }

            verify_information = (PVERIFY_INFORMATION) Irp->AssociatedIrp.SystemBuffer;

            // NOTE(review): Information is set to the caller-supplied Length
            // although no output data is produced here; confirm this cannot
            // exceed the caller's output buffer for this transfer method.
            status = STATUS_SUCCESS;
            Irp->IoStatus.Information = verify_information->Length;

            break;
        }

    default:
        {
            KdPrint((
                "FileDisk: Unknown IoControlCode %#x\n",
                io_stack->Parameters.DeviceIoControl.IoControlCode
                ));

            status = STATUS_INVALID_DEVICE_REQUEST;
            Irp->IoStatus.Information = 0;
        }
    }

    // Requests that were not queued are completed here.
    if (status != STATUS_PENDING)
    {
        Irp->IoStatus.Status = status;

        IoCompleteRequest(Irp, IO_NO_INCREMENT);
    }

    return status;
}