static unsigned long amd_get_page_size(void *client_context)
{
	unsigned long page_size;
	int result;
	struct amd_mem_context *mem_context =
		(struct amd_mem_context *)client_context;

	MSG_DBG("get_page_size: context: %p\n", client_context);
	MSG_DBG("get_page_size: pid 0x%p, address 0x%llx, size:0x%llx\n",
			mem_context->pid,
			mem_context->va,
			mem_context->size);

	result = rdma_interface->get_page_size(
				mem_context->va,
				mem_context->size,
				mem_context->pid,
				&page_size);

	if (result) {
		MSG_ERR("Could not get page size. %d", result);
		/* If we failed to get the page size then we do not know
		 * what to do. Return a default value.
		 */
		return 4096;
	}

	return page_size;
}
static void amd_put_pages(struct sg_table *sg_head, void *client_context)
{
	int ret = 0;
	struct amd_mem_context *mem_context =
		(struct amd_mem_context *)client_context;

	MSG_DBG("put_pages: sg_head %p client_context: 0x%p\n",
			sg_head, client_context);
	MSG_DBG("put_pages: pid 0x%p, address 0x%llx, size:0x%llx\n",
			mem_context->pid,
			mem_context->va,
			mem_context->size);
	MSG_DBG("put_pages: mem_context->p2p_info %p\n",
			mem_context->p2p_info);

	if (ACCESS_ONCE(mem_context->free_callback_called)) {
		MSG_DBG("put_pages: free callback was called\n");
		return;
	}

	if (mem_context->p2p_info) {
		ret = rdma_interface->put_pages(&mem_context->p2p_info);
		mem_context->p2p_info = NULL;

		if (ret)
			MSG_ERR("put_pages failure: %d (callback status %d)",
					ret, mem_context->free_callback_called);
	} else
		MSG_ERR("put_pages: Pointer to p2p info is null\n");
}
static int amd_acquire(unsigned long addr, size_t size,
			void *peer_mem_private_data,
			char *peer_mem_name, void **client_context)
{
	int ret;
	struct amd_mem_context *mem_context;
	struct pid *pid;

	/* Get a reference to the pid of the current process */
	pid = get_task_pid(current, PIDTYPE_PID);

	MSG_DBG("acquire: addr:0x%lx,size:0x%x, pid 0x%p\n",
			addr, (unsigned int)size, pid);

	/* Check whether the address is handled by the AMD GPU driver */
	ret = rdma_interface->is_gpu_address(addr, pid);

	if (!ret) {
		MSG_DBG("acquire: Not a GPU address\n");
		/* This is not a GPU address */
		return 0;
	}

	MSG_DBG("acquire: GPU address\n");

	/* Initialize the context used for operations on the given address */
	mem_context = kzalloc(sizeof(struct amd_mem_context), GFP_KERNEL);

	if (!mem_context) {
		MSG_ERR("failure to allocate memory for mem_context\n");
		/* The error case is handled as "not a GPU address" */
		return 0;
	}

	mem_context->free_callback_called = 0;
	mem_context->va   = addr;
	mem_context->size = size;

	/* Save the pid. acquire() is guaranteed to be called in the correct
	 * process context, as opposed to the other callbacks.
	 */
	mem_context->pid = pid;

	MSG_DBG("acquire: Client context %p\n", mem_context);

	/* Return a pointer to the allocated context */
	*client_context = mem_context;

	/* Increase the counter to prevent module unloading */
	__module_get(THIS_MODULE);

	/* Return 1 to report that this address will be handled
	 * by the AMD GPU driver
	 */
	return 1;
}
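/*
 * Illustrative sketch (not part of the original source): a plausible layout
 * of struct amd_mem_context, inferred from how the callbacks in this file
 * use it. Field types and the p2p_info type name are assumptions based on
 * the format specifiers and the rdma_interface calls; the real definition
 * lives elsewhere and may differ.
 */
struct amd_mem_context {
	u64 va;				/* GPU virtual address passed to acquire() */
	u64 size;			/* size of the registered range */
	struct pid *pid;		/* owning process, taken in acquire() */
	struct amd_p2p_info *p2p_info;	/* filled by rdma_interface->get_pages();
					 * assumed to expose ->pages, a struct
					 * sg_table * consumed by dma_map()
					 */
	void *core_context;		/* opaque IB core cookie used on invalidation */
	int free_callback_called;	/* set by free_callback() */
};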
static int amd_get_pages(unsigned long addr, size_t size, int write, int force,
			struct sg_table *sg_head,
			void *client_context, void *core_context)
{
	int ret;
	struct amd_mem_context *mem_context =
		(struct amd_mem_context *)client_context;

	MSG_DBG("get_pages: addr:0x%lx,size:0x%x, core_context:%p\n",
			addr, (unsigned int)size, core_context);

	if (!mem_context) {
		MSG_WARN("get_pages: Invalid client context");
		return -EINVAL;
	}

	MSG_DBG("get_pages: pid :0x%p\n", mem_context->pid);

	if (addr != mem_context->va) {
		MSG_WARN("get_pages: Context address (0x%llx) is not the same",
				mem_context->va);
		return -EINVAL;
	}

	if (size != mem_context->size) {
		MSG_WARN("get_pages: Context size (0x%llx) is not the same",
				mem_context->size);
		return -EINVAL;
	}

	ret = rdma_interface->get_pages(addr, size,
					mem_context->pid,
					&mem_context->p2p_info,
					free_callback, mem_context);

	if (ret || !mem_context->p2p_info) {
		MSG_ERR("rdma_interface->get_pages failure: %d", ret);
		/* Do not report success if no p2p_info was returned */
		return ret ? ret : -EINVAL;
	}

	mem_context->core_context = core_context;

	/* Note: At this stage it is OK not to fill sg_table */
	return 0;
}
static void amd_release(void *client_context)
{
	struct amd_mem_context *mem_context =
		(struct amd_mem_context *)client_context;

	MSG_DBG("release: context: 0x%p\n", client_context);
	MSG_DBG("release: pid 0x%p, address 0x%llx, size:0x%llx\n",
			mem_context->pid,
			mem_context->va,
			mem_context->size);

	kfree(mem_context);

	/* Decrease the counter to allow module unloading */
	module_put(THIS_MODULE);
}
static int amd_dma_unmap(struct sg_table *sg_head, void *client_context,
			   struct device *dma_device)
{
	struct amd_mem_context *mem_context =
		(struct amd_mem_context *)client_context;

	MSG_DBG("dma_unmap: Context 0x%p, sg_table 0x%p\n",
			client_context, sg_head);
	MSG_DBG("dma_unmap: pid 0x%p, address 0x%llx, size:0x%llx\n",
			mem_context->pid,
			mem_context->va,
			mem_context->size);

	/* Assume success */
	return 0;
}
static int amd_dma_map(struct sg_table *sg_head, void *client_context,
			struct device *dma_device, int dmasync, int *nmap)
{
	/*
	 * NOTE/TODO:
	 * There are potentially three cases for the real memory location:
	 *	- all memory is in local (device) memory
	 *	- all memory is in system memory
	 *	- memory is spread (s/g) between local and system memory.
	 *
	 * In the case where all memory is in system memory we could use the
	 * IOMMU driver to build DMA addresses, but not in the case of local
	 * memory, because the IOMMU driver currently does not handle
	 * local/device memory addresses (it requires "struct page").
	 *
	 * Accordingly, the assumption here is that IOMMU functionality is
	 * disabled, so the sg_table already contains DMA addresses.
	 */
	struct amd_mem_context *mem_context =
		(struct amd_mem_context *)client_context;

	MSG_DBG("dma_map: Context 0x%p, sg_head 0x%p\n",
			client_context, sg_head);
	MSG_DBG("dma_map: pid 0x%p, address 0x%llx, size:0x%llx\n",
			mem_context->pid,
			mem_context->va,
			mem_context->size);

	if (!mem_context->p2p_info) {
		MSG_ERR("dma_map: No sg table was allocated\n");
		return -EINVAL;
	}

	/* Copy information about the previously allocated sg_table */
	*sg_head = *mem_context->p2p_info->pages;

	/* Return the number of pages */
	*nmap = mem_context->p2p_info->pages->nents;

	return 0;
}
static void free_callback(void *client_priv)
{
	struct amd_mem_context *mem_context =
		(struct amd_mem_context *)client_priv;

	MSG_DBG("free_callback: data 0x%p\n", mem_context);

	if (!mem_context) {
		MSG_WARN("free_callback: Invalid client context");
		return;
	}

	MSG_DBG("mem_context->core_context 0x%p\n", mem_context->core_context);

	/* Call back into the IB stack, asking it to invalidate the memory */
	(*ib_invalidate_callback) (ib_reg_handle, mem_context->core_context);

	/* amdkfd will free the resources when we return from this callback.
	 * Set the flag to record that there is nothing to do on "put_pages", etc.
	 */
	ACCESS_ONCE(mem_context->free_callback_called) = 1;
}
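/*
 * Illustrative sketch (not part of the original source): how callbacks such
 * as the ones above are typically bundled into a peer-memory client and
 * registered with the IB core via the PeerDirect interface. The
 * struct peer_memory_client fields and ib_register_peer_memory_client()
 * follow the OFED peer_mem API, but the exact prototypes, the client name,
 * and the init function below are assumptions. The two globals correspond
 * to ib_reg_handle and ib_invalidate_callback used by free_callback() above.
 */
static void *ib_reg_handle;
static int (*ib_invalidate_callback)(void *reg_handle, void *core_context);

static struct peer_memory_client amd_mem_client = {
	.name		= "amd_peer_mem",	/* hypothetical client name */
	.acquire	= amd_acquire,
	.get_pages	= amd_get_pages,
	.dma_map	= amd_dma_map,
	.dma_unmap	= amd_dma_unmap,
	.put_pages	= amd_put_pages,
	.get_page_size	= amd_get_page_size,
	.release	= amd_release,
};

static int __init amd_peer_mem_init(void)
{
	/* Registration returns an opaque handle and hands back the
	 * invalidate callback that free_callback() forwards to.
	 */
	ib_reg_handle = ib_register_peer_memory_client(&amd_mem_client,
						&ib_invalidate_callback);
	return ib_reg_handle ? 0 : -EINVAL;
}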
MSG_ERR_T msgi_copy_to_sbp(int buf_idx, void *data, int sizeb)
{
    volatile msgi_sbp_meta_t *local_meta_ptr = NULL;
    volatile long local_buffer;
    unsigned char *p;
    int bcount;
    int index;
    int i;
    int level = 3;

    MSG_DBG(level, MESSAGE, "In msgi_copy_to_sbp()");

    local_meta_ptr = (msgi_sbp_meta_t *)(uintptr_t)msgi_sbp_meta[local_rank][buf_idx];
    local_buffer = msgi_sbp[local_rank][buf_idx];

    MSG_DBG(level, MESSAGE, "(before) producer = %d", local_meta_ptr->producer);

    if (sizeb > block_size) /* currently block_size = 16KB */
    {
        p = (unsigned char *)data;

        /* Block count */
        bcount = msgi_count_blocks(sizeb);
        MSG_DBG(level, MESSAGE, "bcount = %d", bcount);

        /* Example: data_size = 120K, block_size = 16K, buffer_size = 64K.
         * Then bcount = 8: entries [0-6] carry 16K each and entry [7]
         * carries 120K - (8 - 1) * 16K = 8K.
         * This loop handles entries [0-6]. */
        for (i = 0; i < (bcount - 1); i++)
        {
            index = (local_meta_ptr->producer + 1) % MSGI_MAX_SBP_BLOCKS_PER_BUFFER;

            /* Spin while the ring is full (producer is faster than the consumer) */
            while (index == local_meta_ptr->consumer) {}

            MSG_DBG(level, MESSAGE, "addr = %ld",
                    (long)GET_ADDR(local_buffer, local_meta_ptr->producer, block_size));

            memcpy(GET_ADDR(local_buffer, local_meta_ptr->producer, block_size),
                   p, block_size);
            local_meta_ptr->producer = index;
            p += block_size;
        }

        /* Handle the remaining data, e.g. entry [7] */
        index = (local_meta_ptr->producer + 1) % MSGI_MAX_SBP_BLOCKS_PER_BUFFER;
        while (index == local_meta_ptr->consumer) {}

        MSG_DBG(level, MESSAGE, "addr = %ld",
                (long)GET_ADDR(local_buffer, local_meta_ptr->producer, block_size));

        memcpy(GET_ADDR(local_buffer, local_meta_ptr->producer, block_size),
               p, sizeb - (bcount - 1) * block_size);
        local_meta_ptr->producer = index;

        /* Mark the entry as BUSY */
        local_meta_ptr->control_flag = BUSY;
    }
    else
    {
        index = (local_meta_ptr->producer + 1) % MSGI_MAX_SBP_BLOCKS_PER_BUFFER;

        /* Spin while the ring is full */
        while (index == local_meta_ptr->consumer) {}

        MSG_DBG(level, MESSAGE, "addr = %ld",
                (long)GET_ADDR(local_buffer, local_meta_ptr->producer, block_size));

        memcpy(GET_ADDR(local_buffer, local_meta_ptr->producer, block_size), data, sizeb);
        local_meta_ptr->producer = index;

        /* Mark the entry as BUSY */
        local_meta_ptr->control_flag = BUSY;
    }

    MSG_DBG(level, MESSAGE, "(after) producer = %d", local_meta_ptr->producer);
    MSG_DBG(level, MESSAGE, "Leaving msgi_copy_to_sbp() ...");

    return MSG_SUCCESS;
}
MSG_ERR_T msgi_copy_from_sbp(int src_rank, int buf_idx, void *data, int sizeb)
{
    volatile msgi_sbp_meta_t *remote_meta_ptr = NULL;
    volatile long remote_buffer;
    volatile unsigned char *p;
    volatile int index;
    int bcount;
    int i;
    int level = 3;

    MSG_DBG(level, MESSAGE, "In msgi_copy_from_sbp()");

    remote_meta_ptr = (msgi_sbp_meta_t *)(uintptr_t)msgi_sbp_meta[src_rank][buf_idx];
    remote_buffer = msgi_sbp[src_rank][buf_idx];

    MSG_DBG(level, MESSAGE, "(before) consumer = %d", remote_meta_ptr->consumer);

    if (sizeb > block_size) /* currently block_size = 16KB */
    {
        p = (unsigned char *)data;

        bcount = msgi_count_blocks(sizeb);
        MSG_DBG(level, MESSAGE, "bcount = %d", bcount);

        for (i = 0; i < (bcount - 1); i++)
        {
            /* Spin while the shared memory queue is empty */
            while (remote_meta_ptr->consumer == remote_meta_ptr->producer) {}

            memcpy((void *)p,
                   GET_ADDR(remote_buffer, remote_meta_ptr->consumer, block_size),
                   block_size);
            index = (remote_meta_ptr->consumer + 1) % MSGI_MAX_SBP_BLOCKS_PER_BUFFER;
            remote_meta_ptr->consumer = index;
            p += block_size;
        }

        /* Spin while the shared memory queue is empty */
        while (remote_meta_ptr->consumer == remote_meta_ptr->producer) {}

        memcpy((void *)p,
               GET_ADDR(remote_buffer, remote_meta_ptr->consumer, block_size),
               sizeb - (bcount - 1) * block_size);
        index = (remote_meta_ptr->consumer + 1) % MSGI_MAX_SBP_BLOCKS_PER_BUFFER;
        remote_meta_ptr->consumer = index;

        /* Mark the entry as FREE */
        remote_meta_ptr->control_flag = FREE;
    }
    else
    {
        /* Spin while the shared memory queue is empty */
        while (remote_meta_ptr->consumer == remote_meta_ptr->producer) {}

        memcpy(data, GET_ADDR(remote_buffer, remote_meta_ptr->consumer, block_size), sizeb);
        index = (remote_meta_ptr->consumer + 1) % MSGI_MAX_SBP_BLOCKS_PER_BUFFER;
        remote_meta_ptr->consumer = index;

        /* Mark the entry as FREE */
        remote_meta_ptr->control_flag = FREE;
    }

    MSG_DBG(level, MESSAGE, "(after) consumer = %d", remote_meta_ptr->consumer);
    MSG_DBG(level, MESSAGE, "Leaving msgi_copy_from_sbp() ...");

    return MSG_SUCCESS;
}
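/*
 * Illustrative sketch (not part of the original source): plausible
 * definitions of the helpers the two routines above rely on. GET_ADDR is
 * assumed to resolve the address of block slot `idx` inside a buffer whose
 * base is stored as a long, and msgi_count_blocks() is assumed to be a
 * ceiling division of the message size by block_size; the real macro and
 * function may differ. Note that because "producer + 1 == consumer" is
 * treated as full, the ring holds at most
 * MSGI_MAX_SBP_BLOCKS_PER_BUFFER - 1 in-flight blocks.
 */
#include <stdint.h>

#define GET_ADDR(buf_base, idx, bsize) \
    ((void *)((uintptr_t)(buf_base) + (size_t)(idx) * (size_t)(bsize)))

static inline int msgi_count_blocks(int sizeb)
{
    /* Number of block_size chunks needed to hold sizeb bytes, rounded up;
     * e.g. 120K with 16K blocks gives 8. */
    return (sizeb + block_size - 1) / block_size;
}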