Reference counting per-thread data

It is accompanied by the following fixes:
(1) Fix put ppd locations in mcexec_wait_syscall()
(2) Move put ptd to end of mcexec_terminate_thread_unsafe() and mcexec_ret_syscall()
(3) Add debug messages for ptd add/get/put
(4) Fix ptd-add/get/put matching in mcexec_wait_syscall()
    * Skip put when woken-up from wait_event_interruptible() by signal

Change-Id: Ib9be3f5e62a7a370197cb36c9fa7c4d79f44c314
This commit is contained in:
Masamichi Takagi
2018-09-03 19:36:28 +09:00
parent a121ffc785
commit b8bacdd2de
3 changed files with 269 additions and 112 deletions

View File

@ -63,6 +63,13 @@
#define dprintk(...)
#endif
//#define DEBUG_PTD
#ifdef DEBUG_PTD
#define pr_ptd(msg, tid, ptd) do { printk("%s: " msg ",tid=%d,refc=%d\n", __FUNCTION__, tid, atomic_read(&ptd->refcount)); } while(0)
#else
#define pr_ptd(msg, tid, ptd) do { } while(0)
#endif
static long pager_call_irq(ihk_os_t os, struct syscall_request *req);
static long pager_call(ihk_os_t os, struct syscall_request *req);
@ -80,75 +87,90 @@ static void print_dma_lastreq(void)
}
#endif
int mcctrl_add_per_thread_data(struct mcctrl_per_proc_data* ppd,
struct task_struct *task, void *data)
void mcctrl_put_per_thread_data_unsafe(struct mcctrl_per_thread_data *ptd)
{
struct mcctrl_per_thread_data *ptd_iter, *ptd = NULL;
struct mcctrl_per_thread_data *ptd_alloc = NULL;
int hash = (((uint64_t)task >> 4) & MCCTRL_PER_THREAD_DATA_HASH_MASK);
int ret = 0;
unsigned long flags;
ptd_alloc = kmalloc(sizeof(*ptd), GFP_ATOMIC);
if (!ptd_alloc) {
kprintf("%s: error allocate per thread data\n", __FUNCTION__);
ret = -ENOMEM;
goto out_noalloc;
}
/* Check if data for this thread exists and add if not */
write_lock_irqsave(&ppd->per_thread_data_hash_lock[hash], flags);
list_for_each_entry(ptd_iter, &ppd->per_thread_data_hash[hash], hash) {
if (ptd_iter->task == task) {
ptd = ptd_iter;
break;
if (!atomic_dec_and_test(&ptd->refcount)) {
int ret = atomic_read(&ptd->refcount);
if (ret < 0) {
printk("%s: ERROR: invalid refcount=%d\n", __FUNCTION__, ret);
}
return;
}
if (unlikely(ptd)) {
ret = -EBUSY;
kfree(ptd_alloc);
goto out;
}
ptd = ptd_alloc;
ptd->task = task;
ptd->data = data;
list_add_tail(&ptd->hash, &ppd->per_thread_data_hash[hash]);
out:
write_unlock_irqrestore(&ppd->per_thread_data_hash_lock[hash], flags);
out_noalloc:
return ret;
list_del(&ptd->hash);
kfree(ptd);
}
int mcctrl_delete_per_thread_data(struct mcctrl_per_proc_data* ppd,
struct task_struct *task)
void mcctrl_put_per_thread_data(struct mcctrl_per_thread_data* _ptd)
{
struct mcctrl_per_proc_data *ppd = _ptd->ppd;
struct mcctrl_per_thread_data *ptd_iter, *ptd = NULL;
int hash = (((uint64_t)task >> 4) & MCCTRL_PER_THREAD_DATA_HASH_MASK);
int ret = 0;
int hash = (((uint64_t)_ptd->task >> 4) & MCCTRL_PER_THREAD_DATA_HASH_MASK);
unsigned long flags;
/* Check if data for this thread exists and delete it */
write_lock_irqsave(&ppd->per_thread_data_hash_lock[hash], flags);
list_for_each_entry(ptd_iter, &ppd->per_thread_data_hash[hash], hash) {
if (ptd_iter->task == task) {
if (ptd_iter->task == _ptd->task) {
ptd = ptd_iter;
break;
}
}
if (!ptd) {
ret = -EINVAL;
printk("%s: ERROR: ptd not found\n", __FUNCTION__);
goto out;
}
list_del(&ptd->hash);
kfree(ptd);
mcctrl_put_per_thread_data_unsafe(ptd);
out:
write_unlock_irqrestore(&ppd->per_thread_data_hash_lock[hash], flags);
}
int mcctrl_add_per_thread_data(struct mcctrl_per_proc_data *ppd, void *data)
{
struct mcctrl_per_thread_data *ptd_iter, *ptd = NULL;
struct mcctrl_per_thread_data *ptd_alloc = NULL;
int hash = (((uint64_t)current >> 4) & MCCTRL_PER_THREAD_DATA_HASH_MASK);
int ret = 0;
unsigned long flags;
ptd_alloc = kmalloc(sizeof(struct mcctrl_per_thread_data), GFP_ATOMIC);
if (!ptd_alloc) {
kprintf("%s: error allocate per thread data\n", __FUNCTION__);
ret = -ENOMEM;
goto out_noalloc;
}
memset(ptd_alloc, 0, sizeof(struct mcctrl_per_thread_data));
/* Check if data for this thread exists and add if not */
write_lock_irqsave(&ppd->per_thread_data_hash_lock[hash], flags);
list_for_each_entry(ptd_iter, &ppd->per_thread_data_hash[hash], hash) {
if (ptd_iter->task == current) {
ptd = ptd_iter;
break;
}
}
if (unlikely(ptd)) {
kprintf("%s: WARNING: ptd of tid: %d exists\n", __FUNCTION__, task_pid_vnr(current));
ret = -EBUSY;
kfree(ptd_alloc);
goto out;
}
ptd = ptd_alloc;
ptd->ppd = ppd;
ptd->task = current;
ptd->tid = task_pid_vnr(current);
ptd->data = data;
atomic_set(&ptd->refcount, 1);
list_add_tail(&ptd->hash, &ppd->per_thread_data_hash[hash]);
out:
write_unlock_irqrestore(&ppd->per_thread_data_hash_lock[hash], flags);
out_noalloc:
return ret;
}
@ -159,7 +181,7 @@ struct mcctrl_per_thread_data *mcctrl_get_per_thread_data(struct mcctrl_per_proc
int hash = (((uint64_t)task >> 4) & MCCTRL_PER_THREAD_DATA_HASH_MASK);
unsigned long flags;
/* Check if data for this thread exists and return it */
/* Check if data for this thread exists */
read_lock_irqsave(&ppd->per_thread_data_hash_lock[hash], flags);
list_for_each_entry(ptd_iter, &ppd->per_thread_data_hash[hash], hash) {
@ -169,8 +191,18 @@ struct mcctrl_per_thread_data *mcctrl_get_per_thread_data(struct mcctrl_per_proc
}
}
if (ptd) {
if (atomic_read(&ptd->refcount) <= 0) {
printk("%s: ERROR: use-after-free detected (%d)", __FUNCTION__, atomic_read(&ptd->refcount));
ptd = NULL;
goto out;
}
atomic_inc(&ptd->refcount);
}
out:
read_unlock_irqrestore(&ppd->per_thread_data_hash_lock[hash], flags);
return ptd ? ptd->data : NULL;
return ptd;
}
#endif /* !POSTK_DEBUG_ARCH_DEP_56 */
@ -299,6 +331,7 @@ long syscall_backward(struct mcctrl_usrdata *usrdata, int num,
struct wait_queue_head_list_node *wqhln;
unsigned long irqflags;
struct mcctrl_per_proc_data *ppd;
struct mcctrl_per_thread_data *ptd;
unsigned long phys;
struct syscall_request _request[2];
struct syscall_request *request;
@ -327,7 +360,14 @@ long syscall_backward(struct mcctrl_usrdata *usrdata, int num,
return -EINVAL;
}
packet = (struct ikc_scd_packet *)mcctrl_get_per_thread_data(ppd, current);
ptd = mcctrl_get_per_thread_data(ppd, current);
if (!ptd) {
printk("%s: ERROR: mcctrl_get_per_thread_data failed\n", __FUNCTION__);
syscall_ret = -ENOENT;
goto no_ptd;
}
pr_ptd("get", task_pid_vnr(current), ptd);
packet = (struct ikc_scd_packet *)ptd->data;
if (!packet) {
syscall_ret = -ENOENT;
printk("%s: no packet registered for TID %d\n",
@ -466,6 +506,9 @@ out:
ihk_device_unmap_memory(ihk_os_to_dev(usrdata->os), phys, sizeof(*resp));
out_put_ppd:
mcctrl_put_per_thread_data(ptd);
pr_ptd("put", task_pid_vnr(current), ptd);
no_ptd:
dprintk("%s: tid: %d, syscall: %d, syscall_ret: %lx\n",
__FUNCTION__, task_pid_vnr(current), num, syscall_ret);
@ -483,6 +526,7 @@ int remote_page_fault(struct mcctrl_usrdata *usrdata, void *fault_addr, uint64_t
struct wait_queue_head_list_node *wqhln;
unsigned long irqflags;
struct mcctrl_per_proc_data *ppd;
struct mcctrl_per_thread_data *ptd;
unsigned long phys;
int retry;
@ -498,11 +542,18 @@ int remote_page_fault(struct mcctrl_usrdata *usrdata, void *fault_addr, uint64_t
return -EINVAL;
}
packet = (struct ikc_scd_packet *)mcctrl_get_per_thread_data(ppd, current);
if (!packet) {
ptd = mcctrl_get_per_thread_data(ppd, current);
if (!ptd) {
printk("%s: ERROR: mcctrl_get_per_thread_data failed\n", __FUNCTION__);
error = -ENOENT;
goto no_ptd;
}
pr_ptd("get", task_pid_vnr(current), ptd);
packet = (struct ikc_scd_packet *)ptd->data;
if (!packet) {
printk("%s: no packet registered for TID %d\n",
__FUNCTION__, task_pid_vnr(current));
error = -ENOENT;
goto out_put_ppd;
}
@ -669,6 +720,9 @@ out:
ihk_device_unmap_memory(ihk_os_to_dev(usrdata->os), phys, sizeof(*resp));
out_put_ppd:
mcctrl_put_per_thread_data(ptd);
pr_ptd("put", task_pid_vnr(current), ptd);
no_ptd:
dprintk("%s: tid: %d, fault_addr: %p, reason: %lu, error: %d\n",
__FUNCTION__, task_pid_vnr(current), fault_addr, (unsigned long)reason, error);
@ -711,6 +765,7 @@ static int rus_vm_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
size_t pix;
#endif
struct mcctrl_per_proc_data *ppd;
struct mcctrl_per_thread_data *ptd;
struct ikc_scd_packet *packet;
int ret = 0;
@ -740,7 +795,14 @@ static int rus_vm_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
goto no_ppd;
}
packet = (struct ikc_scd_packet *)mcctrl_get_per_thread_data(ppd, current);
ptd = mcctrl_get_per_thread_data(ppd, current);
if (!ptd) {
printk("%s: ERROR: mcctrl_get_per_thread_data failed\n", __FUNCTION__);
ret = VM_FAULT_SIGBUS;
goto no_ptd;
}
pr_ptd("get", task_pid_vnr(current), ptd);
packet = (struct ikc_scd_packet *)ptd->data;
if (!packet) {
ret = VM_FAULT_SIGBUS;
printk("%s: no packet registered for TID %d\n",
@ -929,6 +991,9 @@ static int rus_vm_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
ret = VM_FAULT_NOPAGE;
put_and_out:
mcctrl_put_per_thread_data(ptd);
pr_ptd("put", task_pid_vnr(current), ptd);
no_ptd:
mcctrl_put_per_proc_data(ppd);
no_ppd:
return ret;