From 09173d353c6a5742dce42f6834bd427fe71b8b1e Mon Sep 17 00:00:00 2001 From: Balazs Gerofi Date: Tue, 16 Mar 2021 09:41:29 +0900 Subject: [PATCH] mcctrl_wakeup_desc: refcount and fix timeouts Change-Id: I14b34f031ffb10bfac6cef07d81f53a8dece767b --- executer/kernel/mcctrl/arch/arm64/archdeps.c | 2 +- executer/kernel/mcctrl/control.c | 24 ++-- executer/kernel/mcctrl/futex.c | 15 +-- executer/kernel/mcctrl/ikc.c | 120 ++++++++++++------- executer/kernel/mcctrl/include/refcount.h | 105 ++++++++++++++++ executer/kernel/mcctrl/mcctrl.h | 6 + executer/kernel/mcctrl/procfs.c | 6 +- executer/kernel/mcctrl/syscall.c | 14 +-- 8 files changed, 200 insertions(+), 92 deletions(-) create mode 100644 executer/kernel/mcctrl/include/refcount.h diff --git a/executer/kernel/mcctrl/arch/arm64/archdeps.c b/executer/kernel/mcctrl/arch/arm64/archdeps.c index 8856527c..05672278 100644 --- a/executer/kernel/mcctrl/arch/arm64/archdeps.c +++ b/executer/kernel/mcctrl/arch/arm64/archdeps.c @@ -916,7 +916,7 @@ int __mcctrl_tof_utofu_release_handler(struct inode *inode, struct file *filp, isp.arg = f2pfd->fd; ret = mcctrl_ikc_send_wait(f2pfd->os, ppd->ikc_target_cpu, - &isp, -20, NULL, NULL, 0); + &isp, -1000, NULL, NULL, 0); if (ret != 0) { pr_err("%s: WARNING: IKC req for PID: %d, fd: %d failed\n", __func__, f2pfd->pid, f2pfd->fd); diff --git a/executer/kernel/mcctrl/control.c b/executer/kernel/mcctrl/control.c index 29254836..9a06568c 100644 --- a/executer/kernel/mcctrl/control.c +++ b/executer/kernel/mcctrl/control.c @@ -414,7 +414,7 @@ static void release_handler(ihk_os_t os, void *param) dprintk("%s: SCD_MSG_CLEANUP_PROCESS, info: %p, cpu: %d\n", __FUNCTION__, info, info->cpu); ret = mcctrl_ikc_send_wait(os, info->cpu, - &isp, -20, NULL, NULL, 0); + &isp, -5000, NULL, NULL, 0); if (ret != 0) { printk("%s: WARNING: failed to send IKC msg: %d\n", __func__, ret); @@ -513,8 +513,6 @@ static DECLARE_WAIT_QUEUE_HEAD(signalq); struct mcctrl_signal_desc { struct mcctrl_signal msig; - struct 
mcctrl_wakeup_desc wakeup; - void *addrs[1]; }; static long mcexec_send_signal(ihk_os_t os, struct signal_desc *sigparam) @@ -554,7 +552,7 @@ static long mcexec_send_signal(ihk_os_t os, struct signal_desc *sigparam) isp.pid = sig.pid; isp.arg = virt_to_phys(msigp); - rc = mcctrl_ikc_send_wait(os, sig.cpu, &isp, 0, &desc->wakeup, + rc = mcctrl_ikc_send_wait(os, sig.cpu, &isp, -1000, NULL, &do_free, 1, desc); if (rc < 0) { printk("mcexec_send_signal: mcctrl_ikc_send ret=%d\n", rc); @@ -2243,8 +2241,6 @@ long mcctrl_perf_num(ihk_os_t os, unsigned long arg) struct mcctrl_perf_ctrl_desc { struct perf_ctrl_desc desc; - struct mcctrl_wakeup_desc wakeup; - void *addrs[1]; }; #define wakeup_desc_of_perf_desc(_desc) \ (&container_of((_desc), struct mcctrl_perf_ctrl_desc, desc)->wakeup) @@ -2310,9 +2306,7 @@ long mcctrl_perf_set(ihk_os_t os, struct ihk_perf_event_attr *__user arg) isp.arg = virt_to_phys(perf_desc); for (j = 0; j < info->n_cpus; j++) { - ret = mcctrl_ikc_send_wait(os, j, &isp, - msecs_to_jiffies(10000), - wakeup_desc_of_perf_desc(perf_desc), + ret = mcctrl_ikc_send_wait(os, j, &isp, 10000, NULL, &need_free, 1, perf_desc); if (ret < 0) { pr_warn("%s: mcctrl_ikc_send_wait ret=%d\n", @@ -2382,9 +2376,7 @@ long mcctrl_perf_get(ihk_os_t os, unsigned long *__user arg) isp.arg = virt_to_phys(perf_desc); for (j = 0; j < info->n_cpus; j++) { - ret = mcctrl_ikc_send_wait(os, j, &isp, - msecs_to_jiffies(10000), - wakeup_desc_of_perf_desc(perf_desc), + ret = mcctrl_ikc_send_wait(os, j, &isp, 10000, NULL, &need_free, 1, perf_desc); if (ret < 0) { pr_warn("%s: mcctrl_ikc_send_wait ret=%d\n", @@ -2454,9 +2446,8 @@ long mcctrl_perf_enable(ihk_os_t os) return -EINVAL; } for (j = 0; j < info->n_cpus; j++) { - ret = mcctrl_ikc_send_wait(os, j, &isp, 0, - wakeup_desc_of_perf_desc(perf_desc), - &need_free, 1, perf_desc); + ret = mcctrl_ikc_send_wait(os, j, &isp, 0, NULL, + &need_free, 1, perf_desc); if (ret < 0) { pr_warn("%s: mcctrl_ikc_send_wait ret=%d\n", @@ -2522,8 +2513,7 @@ 
long mcctrl_perf_disable(ihk_os_t os) return -EINVAL; } for (j = 0; j < info->n_cpus; j++) { - ret = mcctrl_ikc_send_wait(os, j, &isp, 0, - wakeup_desc_of_perf_desc(perf_desc), + ret = mcctrl_ikc_send_wait(os, j, &isp, 0, NULL, &need_free, 1, perf_desc); if (ret < 0) { pr_warn("%s: mcctrl_ikc_send_wait ret=%d\n", diff --git a/executer/kernel/mcctrl/futex.c b/executer/kernel/mcctrl/futex.c index 9d887f9d..7a510d4e 100644 --- a/executer/kernel/mcctrl/futex.c +++ b/executer/kernel/mcctrl/futex.c @@ -182,8 +182,6 @@ static int uti_remote_page_fault(struct mcctrl_usrdata *usrdata, struct mcctrl_per_proc_data *ppd, int tid, int cpu) { int error; - struct mcctrl_wakeup_desc *desc; - int do_frees = 1; struct ikc_scd_packet packet; /* Request page fault */ @@ -192,20 +190,9 @@ static int uti_remote_page_fault(struct mcctrl_usrdata *usrdata, packet.fault_reason = reason; packet.fault_tid = tid; - /* we need to alloc desc ourselves because GFP_ATOMIC */ -retry_alloc: - desc = kmalloc(sizeof(*desc), GFP_ATOMIC); - if (!desc) { - pr_warn("WARNING: coudln't alloc remote page fault wait desc, retrying..\n"); - goto retry_alloc; - } - /* packet->target_cpu was set in rus_vm_fault if a thread was found */ error = mcctrl_ikc_send_wait(usrdata->os, cpu, &packet, - 0, desc, &do_frees, 0); - if (do_frees) { - kfree(desc); - } + 0, NULL, NULL, 0); if (error < 0) { pr_warn("%s: WARNING: failed to request uti remote page fault :%d\n", __func__, error); diff --git a/executer/kernel/mcctrl/ikc.c b/executer/kernel/mcctrl/ikc.c index 1bc31cb6..f2b73e44 100644 --- a/executer/kernel/mcctrl/ikc.c +++ b/executer/kernel/mcctrl/ikc.c @@ -58,23 +58,41 @@ void mcctrl_os_read_write_cpu_response(ihk_os_t os, struct ikc_scd_packet *pisp); void mcctrl_eventfd(ihk_os_t os, struct ikc_scd_packet *pisp); -/* Assumes usrdata->wakeup_descs_lock taken */ -static void mcctrl_wakeup_desc_cleanup(ihk_os_t os, - struct mcctrl_wakeup_desc *desc) +static void mcctrl_wakeup_desc_put(struct mcctrl_wakeup_desc *desc, + 
struct mcctrl_usrdata *usrdata, int free_addrs) { + unsigned long irqflags; int i; - list_del(&desc->chain); - - for (i = 0; i < desc->free_addrs_count; i++) { - kfree(desc->free_addrs[i]); + if (!refcount_dec_and_test(&desc->count)) { + return; } + + spin_lock_irqsave(&usrdata->wakeup_descs_lock, irqflags); + list_del(&desc->chain); + spin_unlock_irqrestore(&usrdata->wakeup_descs_lock, irqflags); + + if (free_addrs) { + for (i = 0; i < desc->free_addrs_count; i++) { + kfree(desc->free_addrs[i]); + } + } + + if (desc->free_at_put) + kfree(desc); } static void mcctrl_wakeup_cb(ihk_os_t os, struct ikc_scd_packet *packet) { struct mcctrl_wakeup_desc *desc = packet->reply; + struct mcctrl_usrdata *usrdata = ihk_host_os_get_usrdata(os); + /* destroy_ikc_channels must have cleaned up descs */ + if (!usrdata) { + pr_err("%s: error: mcctrl_usrdata not found\n", + __func__); + return; + } WRITE_ONCE(desc->err, packet->err); @@ -85,29 +103,25 @@ static void mcctrl_wakeup_cb(ihk_os_t os, struct ikc_scd_packet *packet) * wake up opportunistically between this set and the wake_up call. * * If the other side is no longer waiting, free the memory that was - * left for us. + * left for us. The caller has been notified not to free. */ if (cmpxchg(&desc->status, 0, 1)) { - struct mcctrl_usrdata *usrdata = ihk_host_os_get_usrdata(os); - unsigned long flags; - - /* destroy_ikc_channels must have cleaned up descs */ - if (!usrdata) { - pr_err("%s: error: mcctrl_usrdata not found\n", - __func__); - return; - } - - spin_lock_irqsave(&usrdata->wakeup_descs_lock, flags); - mcctrl_wakeup_desc_cleanup(os, desc); - spin_unlock_irqrestore(&usrdata->wakeup_descs_lock, flags); + mcctrl_wakeup_desc_put(desc, usrdata, 1); return; } + /* + * Notify waiter before dropping reference to make sure + * wait queue is still valid. 
+ */ wake_up_interruptible(&desc->wq); + mcctrl_wakeup_desc_put(desc, usrdata, 0); } -/* do_frees: 1 when caller should free free_addrs[], 0 otherwise */ +/* + * do_frees: 1 when caller should free free_addrs[], 0 otherwise + * timeout: timeout in milliseconds + */ int mcctrl_ikc_send_wait(ihk_os_t os, int cpu, struct ikc_scd_packet *pisp, long int timeout, struct mcctrl_wakeup_desc *desc, int *do_frees, int free_addrs_count, ...) @@ -115,35 +129,60 @@ int mcctrl_ikc_send_wait(ihk_os_t os, int cpu, struct ikc_scd_packet *pisp, int ret, i; int alloc_desc = (desc == NULL); va_list ap; + unsigned long flags; + struct mcctrl_usrdata *usrdata = ihk_host_os_get_usrdata(os); + + if (!usrdata) { + pr_err("%s: error: mcctrl_usrdata not found\n", + __func__); + return -EINVAL; + } if (free_addrs_count) *do_frees = 1; + if (alloc_desc) desc = kmalloc(sizeof(struct mcctrl_wakeup_desc) + (free_addrs_count + 1) * sizeof(void *), - GFP_KERNEL); + GFP_ATOMIC); if (!desc) { pr_warn("%s: Could not allocate wakeup descriptor", __func__); return -ENOMEM; } + pisp->reply = desc; va_start(ap, free_addrs_count); for (i = 0; i < free_addrs_count; i++) { desc->free_addrs[i] = va_arg(ap, void*); } va_end(ap); - if (alloc_desc) - desc->free_addrs[free_addrs_count++] = desc; desc->free_addrs_count = free_addrs_count; + + /* Only free at put time if allocated internally */ + desc->free_at_put = 0; + if (alloc_desc) + desc->free_at_put = 1; + init_waitqueue_head(&desc->wq); + + /* One for the caller and one for the call-back */ + refcount_set(&desc->count, 2); + + /* XXX: make this a hash-table? 
*/ + spin_lock_irqsave(&usrdata->wakeup_descs_lock, flags); + list_add(&desc->chain, &usrdata->wakeup_descs_list); + spin_unlock_irqrestore(&usrdata->wakeup_descs_lock, flags); + WRITE_ONCE(desc->err, 0); WRITE_ONCE(desc->status, 0); ret = mcctrl_ikc_send(os, cpu, pisp); if (ret < 0) { pr_warn("%s: mcctrl_ikc_send failed: %d\n", __func__, ret); - if (alloc_desc) - kfree(desc); + /* Failed to send msg, put twice */ + mcctrl_wakeup_desc_put(desc, usrdata, 0); + mcctrl_wakeup_desc_put(desc, usrdata, 0); + return ret; } @@ -180,28 +219,16 @@ int mcctrl_ikc_send_wait(ihk_os_t os, int cpu, struct ikc_scd_packet *pisp, * the callback it will need to free things for us */ if (!cmpxchg(&desc->status, 0, 1)) { - struct mcctrl_usrdata *usrdata = ihk_host_os_get_usrdata(os); - unsigned long flags; + mcctrl_wakeup_desc_put(desc, usrdata, 0); - if (!usrdata) { - pr_err("%s: error: mcctrl_usrdata not found\n", - __func__); - ret = ret < 0 ? ret : -EINVAL; - goto out; - } - - spin_lock_irqsave(&usrdata->wakeup_descs_lock, flags); - list_add(&desc->chain, &usrdata->wakeup_descs_list); - spin_unlock_irqrestore(&usrdata->wakeup_descs_lock, flags); if (do_frees) *do_frees = 0; return ret < 0 ? 
ret : -ETIME; } ret = READ_ONCE(desc->err); -out: - if (alloc_desc) - kfree(desc); + + mcctrl_wakeup_desc_put(desc, usrdata, 0); return ret; } @@ -605,10 +632,15 @@ void destroy_ikc_channels(ihk_os_t os) ihk_ikc_destroy_channel(usrdata->ikc2linux[i]); } } + spin_lock_irqsave(&usrdata->wakeup_descs_lock, flags); list_for_each_entry_safe(mwd_entry, mwd_next, - &usrdata->wakeup_descs_list, chain) { - mcctrl_wakeup_desc_cleanup(os, mwd_entry); + &usrdata->wakeup_descs_list, chain) { + list_del(&mwd_entry->chain); + + for (i = 0; i < mwd_entry->free_addrs_count; i++) { + kfree(mwd_entry->free_addrs[i]); + } } spin_unlock_irqrestore(&usrdata->wakeup_descs_lock, flags); diff --git a/executer/kernel/mcctrl/include/refcount.h b/executer/kernel/mcctrl/include/refcount.h new file mode 100644 index 00000000..8c6242df --- /dev/null +++ b/executer/kernel/mcctrl/include/refcount.h @@ -0,0 +1,105 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _LINUX_REFCOUNT_H +#define _LINUX_REFCOUNT_H + +#include <linux/atomic.h> +#include <linux/compiler.h> +#include <linux/spinlock_types.h> + +struct mutex; + +/** + * struct refcount_t - variant of atomic_t specialized for reference counts + * @refs: atomic_t counter field + * + * The counter saturates at UINT_MAX and will not move once + * there. This avoids wrapping the counter and causing 'spurious' + * use-after-free bugs. 
+ */ +typedef struct refcount_struct { + atomic_t refs; +} refcount_t; + +#define REFCOUNT_INIT(n) { .refs = ATOMIC_INIT(n), } + +/** + * refcount_set - set a refcount's value + * @r: the refcount + * @n: value to which the refcount will be set + */ +static inline void refcount_set(refcount_t *r, unsigned int n) +{ + atomic_set(&r->refs, n); +} + +/** + * refcount_read - get a refcount's value + * @r: the refcount + * + * Return: the refcount's value + */ +static inline unsigned int refcount_read(const refcount_t *r) +{ + return atomic_read(&r->refs); +} + +#ifdef CONFIG_REFCOUNT_FULL +extern __must_check bool refcount_add_not_zero(unsigned int i, refcount_t *r); +extern void refcount_add(unsigned int i, refcount_t *r); + +extern __must_check bool refcount_inc_not_zero(refcount_t *r); +extern void refcount_inc(refcount_t *r); + +extern __must_check bool refcount_sub_and_test(unsigned int i, refcount_t *r); + +extern __must_check bool refcount_dec_and_test(refcount_t *r); +extern void refcount_dec(refcount_t *r); +#else +# ifdef CONFIG_ARCH_HAS_REFCOUNT +# include <asm/refcount.h> +# else +static inline __must_check bool refcount_add_not_zero(unsigned int i, refcount_t *r) +{ + return atomic_add_unless(&r->refs, i, 0); +} + +static inline void refcount_add(unsigned int i, refcount_t *r) +{ + atomic_add(i, &r->refs); +} + +static inline __must_check bool refcount_inc_not_zero(refcount_t *r) +{ + return atomic_add_unless(&r->refs, 1, 0); +} + +static inline void refcount_inc(refcount_t *r) +{ + atomic_inc(&r->refs); +} + +static inline __must_check bool refcount_sub_and_test(unsigned int i, refcount_t *r) +{ + return atomic_sub_and_test(i, &r->refs); +} + +static inline __must_check bool refcount_dec_and_test(refcount_t *r) +{ + return atomic_dec_and_test(&r->refs); +} + +static inline void refcount_dec(refcount_t *r) +{ + atomic_dec(&r->refs); +} +# endif /* !CONFIG_ARCH_HAS_REFCOUNT */ +#endif /* CONFIG_REFCOUNT_FULL */ + +extern __must_check bool refcount_dec_if_one(refcount_t *r); 
+extern __must_check bool refcount_dec_not_one(refcount_t *r); +extern __must_check bool refcount_dec_and_mutex_lock(refcount_t *r, struct mutex *lock); +extern __must_check bool refcount_dec_and_lock(refcount_t *r, spinlock_t *lock); +extern __must_check bool refcount_dec_and_lock_irqsave(refcount_t *r, + spinlock_t *lock, + unsigned long *flags); +#endif /* _LINUX_REFCOUNT_H */ diff --git a/executer/kernel/mcctrl/mcctrl.h b/executer/kernel/mcctrl/mcctrl.h index eefca7ff..28f994f0 100644 --- a/executer/kernel/mcctrl/mcctrl.h +++ b/executer/kernel/mcctrl/mcctrl.h @@ -44,6 +44,10 @@ #include #include #include +#include +#if KERNEL_VERSION(4, 11, 0) > LINUX_VERSION_CODE +#include +#endif #include "sysfs.h" #define SCD_MSG_PREPARE_PROCESS 0x1 @@ -401,6 +405,8 @@ int mcctrl_ikc_is_valid_thread(ihk_os_t os, int cpu); struct mcctrl_wakeup_desc { int status; int err; + refcount_t count; + int free_at_put; wait_queue_head_t wq; struct list_head chain; int free_addrs_count; diff --git a/executer/kernel/mcctrl/procfs.c b/executer/kernel/mcctrl/procfs.c index dc931d57..696deb82 100644 --- a/executer/kernel/mcctrl/procfs.c +++ b/executer/kernel/mcctrl/procfs.c @@ -611,7 +611,7 @@ static ssize_t __mckernel_procfs_read_write( ret = mcctrl_ikc_send_wait(osnum_to_os(e->osnum), (pid > 0) ? ppd->ikc_target_cpu : 0, - &isp, HZ, NULL, &do_free, 1, r); + &isp, 5000, NULL, &do_free, 1, r); if (!do_free && ret >= 0) { ret = -EIO; @@ -879,7 +879,7 @@ static int mckernel_procfs_buff_release(struct inode *inode, struct file *file) rc = -EIO; ret = mcctrl_ikc_send_wait(info->os, 0, - &isp, 5 * HZ, NULL, &do_free, 1, r); + &isp, 5000, NULL, &do_free, 1, r); if (!do_free && ret >= 0) { ret = -EIO; @@ -977,7 +977,7 @@ static ssize_t mckernel_procfs_buff_read(struct file *file, char __user *ubuf, done = 1; ret = mcctrl_ikc_send_wait(os, (pid > 0) ? 
ppd->ikc_target_cpu : 0, - &isp, 5 * HZ, NULL, &do_free, 1, r); + &isp, 5000, NULL, &do_free, 1, r); if (!do_free && ret >= 0) { ret = -EIO; diff --git a/executer/kernel/mcctrl/syscall.c b/executer/kernel/mcctrl/syscall.c index 111cbc49..32e06ff8 100644 --- a/executer/kernel/mcctrl/syscall.c +++ b/executer/kernel/mcctrl/syscall.c @@ -495,8 +495,6 @@ int remote_page_fault(struct mcctrl_usrdata *usrdata, void *fault_addr, struct ikc_scd_packet *packet) { int error; - struct mcctrl_wakeup_desc *desc; - int do_frees = 1; dprintk("%s: tid: %d, fault_addr: %p, reason: %lu\n", __FUNCTION__, task_pid_vnr(current), fault_addr, (unsigned long)reason); @@ -506,19 +504,9 @@ int remote_page_fault(struct mcctrl_usrdata *usrdata, void *fault_addr, packet->fault_address = (unsigned long)fault_addr; packet->fault_reason = reason; - /* we need to alloc desc ourselves because GFP_ATOMIC */ -retry_alloc: - desc = kmalloc(sizeof(*desc), GFP_ATOMIC); - if (!desc) { - pr_warn("WARNING: coudln't alloc remote page fault wait desc, retrying..\n"); - goto retry_alloc; - } - /* packet->target_cpu was set in rus_vm_fault if a thread was found */ error = mcctrl_ikc_send_wait(usrdata->os, packet->target_cpu, packet, - 0, desc, &do_frees, 0); - if (do_frees) - kfree(desc); + 0, NULL, NULL, 0); if (error < 0) { pr_warn("%s: WARNING: failed to request remote page fault PID %d: %d\n", __func__, packet->pid, error);