From d5650c7e56431cda2f59032698dfb5d462215efc Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 14 Jul 2018 05:55:25 +0000 Subject: memalign: check alignment on all public functions And rework our internal_memalign interface to mimic the posix_memalign API; since that's what Ruby favors and avoids touching errno on errors. --- ext/mwrap/mwrap.c | 74 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/ext/mwrap/mwrap.c b/ext/mwrap/mwrap.c index 7ae892c..6109070 100644 --- a/ext/mwrap/mwrap.c +++ b/ext/mwrap/mwrap.c @@ -42,11 +42,12 @@ void *__malloc(size_t); void __free(void *); static void *(*real_malloc)(size_t) = __malloc; static void (*real_free)(void *) = __free; -# define RETURN_IF_NOT_READY() do {} while (0) /* nothing */ +static const int ready = 1; #else static int ready; static void *(*real_malloc)(size_t); static void (*real_free)(void *); +#endif /* !FreeBSD */ /* * we need to fake an OOM condition while dlsym is running, @@ -60,8 +61,6 @@ static void (*real_free)(void *); } \ } while (0) -#endif /* !FreeBSD */ - static size_t generation; static size_t page_size; static struct cds_lfht *totals; @@ -388,57 +387,72 @@ static void *ptr_align(void *ptr, size_t alignment) return (void *)(((uintptr_t)ptr + (alignment - 1)) & ~(alignment - 1)); } -static void *internal_memalign(size_t alignment, size_t size, uintptr_t caller) +static bool is_power_of_two(size_t n) { return (n & (n - 1)) == 0; } + +static int +internal_memalign(void **pp, size_t alignment, size_t size, uintptr_t caller) { struct src_loc *l; struct alloc_hdr *h; - void *p, *real; + void *real; size_t asize; + size_t d = alignment / sizeof(void*); + size_t r = alignment % sizeof(void*); - RETURN_IF_NOT_READY(); - if (alignment <= ASSUMED_MALLOC_ALIGNMENT) - return malloc(size); + if (!ready) return ENOMEM; + + if (r != 0 || d == 0 || !is_power_of_two(d)) + return EINVAL; + + if (alignment <= ASSUMED_MALLOC_ALIGNMENT) { + void *p = malloc(size); + if (!p) return ENOMEM; + *pp = p; + return 0; + } for (; alignment < sizeof(struct alloc_hdr); alignment *= 2) ; /* double alignment until >= sizeof(struct alloc_hdr) */ if (__builtin_add_overflow(size, alignment, &asize) || __builtin_add_overflow(asize, sizeof(struct alloc_hdr), &asize)) - return 0; + return ENOMEM; /* assert(asize == (alignment + size + sizeof(struct alloc_hdr))); */ rcu_read_lock(); l = update_stats_rcu(size, caller); - p = real = real_malloc(asize); + real = real_malloc(asize); if (real) { - p = hdr2ptr(real); + void *p = hdr2ptr(real); if (!ptr_is_aligned(p, alignment)) p = ptr_align(p, alignment); h = ptr2hdr(p); alloc_insert_rcu(l, h, size, real); + *pp = p; } rcu_read_unlock(); - return p; + return real ? 0 : ENOMEM; } -void *memalign(size_t alignment, size_t size) +static void * +memalign_result(int err, void *p) { - void *p = internal_memalign(alignment, size, RETURN_ADDRESS(0)); - if (caa_unlikely(!p)) errno = ENOMEM; + if (caa_unlikely(err)) { + errno = err; + return 0; + } return p; } -static bool is_power_of_two(size_t n) { return (n & (n - 1)) == 0; } +void *memalign(size_t alignment, size_t size) +{ + void *p; + int err = internal_memalign(&p, alignment, size, RETURN_ADDRESS(0)); + return memalign_result(err, p); +} int posix_memalign(void **p, size_t alignment, size_t size) { - size_t d = alignment / sizeof(void*); - size_t r = alignment % sizeof(void*); - - if (r != 0 || d == 0 || !is_power_of_two(d)) - return EINVAL; - - *p = internal_memalign(alignment, size, RETURN_ADDRESS(0)); - return *p ? 0 : ENOMEM; + return internal_memalign(p, alignment, size, RETURN_ADDRESS(0)); } void *aligned_alloc(size_t, size_t) __attribute__((alias("memalign"))); @@ -446,9 +460,9 @@ void cfree(void *) __attribute__((alias("free"))); void *valloc(size_t size) { - void *p = internal_memalign(page_size, size, RETURN_ADDRESS(0)); - if (caa_unlikely(!p)) errno = ENOMEM; - return p; + void *p; + int err = internal_memalign(&p, page_size, size, RETURN_ADDRESS(0)); + return memalign_result(err, p); } #if __GNUC__ < 7 @@ -465,15 +479,15 @@ void *pvalloc(size_t size) { size_t alignment = page_size; void *p; + int err; if (add_overflow_p(size, alignment)) { errno = ENOMEM; return 0; } size = size_align(size, alignment); - p = internal_memalign(alignment, size, RETURN_ADDRESS(0)); - if (caa_unlikely(!p)) errno = ENOMEM; - return p; + err = internal_memalign(&p, alignment, size, RETURN_ADDRESS(0)); + return memalign_result(err, p); } void *malloc(size_t size) -- cgit v1.2.3-24-ge0c7