[PATCH] AArch64: Improve A64FX memcpy

Wilco Dijkstra <Wilco.Dijkstra@arm.com>
Wed Jun 30 15:38:26 GMT 2021


Hi Naohiro,

Since the A64FX memcpy is still quite large, I decided to do a quick cleanup
pass to simplify the code. I believe the code is better overall and a bit
faster, but please let me know what you think. I've left the structure as it
was, but there are likely more tweaks possible.
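
To give an idea of the kind of simplification, the copy loops now use a
flag-setting subtract so the separate compare can be dropped, e.g.

	sub	rest, rest, tmp1
	cmp	rest, tmp1
	b.ge	1b

becomes

	subs	rest, rest, tmp1
	b.hs	1b

with rest biased down by one iteration before entering the loop so that the
unsigned loop-exit condition still works. Here it is: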


Reduce the codesize of the A64FX memcpy by avoiding duplication of code,
and removing redundant instructions. The size for memcpy and memmove
goes down from 1796 bytes to 1080 bytes. Performance is mostly unchanged or
slightly better since the critical loops are identical while fewer
instructions are executed before entering them.

Passes GLIBC regress, OK for commit?

---

diff --git a/sysdeps/aarch64/multiarch/memcpy_a64fx.S b/sysdeps/aarch64/multiarch/memcpy_a64fx.S
index 65528405bb12373731e895c7030ccef23b88c17f..425148300913aadd8b144d17e7ee2b496f65008e 100644
--- a/sysdeps/aarch64/multiarch/memcpy_a64fx.S
+++ b/sysdeps/aarch64/multiarch/memcpy_a64fx.S
@@ -38,7 +38,6 @@
 #define dest_ptr	x7
 #define src_ptr		x8
 #define vector_length	x9
-#define cl_remainder	x10	// CACHE_LINE_SIZE remainder
 
 #if HAVE_AARCH64_SVE_ASM
 # if IS_IN (libc)
@@ -47,14 +46,6 @@
 
 	.arch armv8.2-a+sve
 
-	.macro dc_zva times
-	dc	zva, tmp1
-	add	tmp1, tmp1, CACHE_LINE_SIZE
-	.if \times-1
-	dc_zva "(\times-1)"
-	.endif
-	.endm
-
 	.macro ld1b_unroll8
 	ld1b	z0.b, p0/z, [src_ptr, #0, mul vl]
 	ld1b	z1.b, p0/z, [src_ptr, #1, mul vl]
@@ -106,69 +97,49 @@
 
 	.macro shortcut_for_small_size exit
 	// if rest <= vector_length * 2
-	whilelo	p0.b, xzr, n
+	whilelo p0.b, xzr, n
 	whilelo	p1.b, vector_length, n
-	b.last	1f
 	ld1b	z0.b, p0/z, [src, #0, mul vl]
 	ld1b	z1.b, p1/z, [src, #1, mul vl]
+	b.last	1f
 	st1b	z0.b, p0, [dest, #0, mul vl]
 	st1b	z1.b, p1, [dest, #1, mul vl]
 	ret
+
 1:	// if rest > vector_length * 8
 	cmp	n, vector_length, lsl 3 // vector_length * 8
 	b.hi	\exit
+
 	// if rest <= vector_length * 4
 	lsl	tmp1, vector_length, 1  // vector_length * 2
-	whilelo	p2.b, tmp1, n
-	incb	tmp1
-	whilelo	p3.b, tmp1, n
-	b.last	1f
-	ld1b	z0.b, p0/z, [src, #0, mul vl]
-	ld1b	z1.b, p1/z, [src, #1, mul vl]
+	sub	n, n, tmp1
+	whilelo p2.b, xzr, n
+	whilelo p3.b, vector_length, n
 	ld1b	z2.b, p2/z, [src, #2, mul vl]
 	ld1b	z3.b, p3/z, [src, #3, mul vl]
-	st1b	z0.b, p0, [dest, #0, mul vl]
-	st1b	z1.b, p1, [dest, #1, mul vl]
-	st1b	z2.b, p2, [dest, #2, mul vl]
-	st1b	z3.b, p3, [dest, #3, mul vl]
-	ret
-1:	// if rest <= vector_length * 8
-	lsl	tmp1, vector_length, 2  // vector_length * 4
-	whilelo	p4.b, tmp1, n
-	incb	tmp1
-	whilelo	p5.b, tmp1, n
 	b.last	1f
-	ld1b	z0.b, p0/z, [src, #0, mul vl]
-	ld1b	z1.b, p1/z, [src, #1, mul vl]
-	ld1b	z2.b, p2/z, [src, #2, mul vl]
-	ld1b	z3.b, p3/z, [src, #3, mul vl]
-	ld1b	z4.b, p4/z, [src, #4, mul vl]
-	ld1b	z5.b, p5/z, [src, #5, mul vl]
 	st1b	z0.b, p0, [dest, #0, mul vl]
-	st1b	z1.b, p1, [dest, #1, mul vl]
+	st1b	z1.b, p0, [dest, #1, mul vl]
 	st1b	z2.b, p2, [dest, #2, mul vl]
 	st1b	z3.b, p3, [dest, #3, mul vl]
-	st1b	z4.b, p4, [dest, #4, mul vl]
-	st1b	z5.b, p5, [dest, #5, mul vl]
 	ret
-1:	lsl	tmp1, vector_length, 2	// vector_length * 4
-	incb	tmp1			// vector_length * 5
-	incb	tmp1			// vector_length * 6
-	whilelo	p6.b, tmp1, n
-	incb	tmp1
-	whilelo	p7.b, tmp1, n
-	ld1b	z0.b, p0/z, [src, #0, mul vl]
-	ld1b	z1.b, p1/z, [src, #1, mul vl]
-	ld1b	z2.b, p2/z, [src, #2, mul vl]
-	ld1b	z3.b, p3/z, [src, #3, mul vl]
+
+1:	// if rest <= vector_length * 8
+	sub	n, n, tmp1
+	add	tmp2, tmp1, vector_length
+	whilelo p4.b, xzr, n
+	whilelo p5.b, vector_length, n
+	whilelo p6.b, tmp1, n
+	whilelo p7.b, tmp2, n
+
 	ld1b	z4.b, p4/z, [src, #4, mul vl]
 	ld1b	z5.b, p5/z, [src, #5, mul vl]
 	ld1b	z6.b, p6/z, [src, #6, mul vl]
 	ld1b	z7.b, p7/z, [src, #7, mul vl]
 	st1b	z0.b, p0, [dest, #0, mul vl]
-	st1b	z1.b, p1, [dest, #1, mul vl]
-	st1b	z2.b, p2, [dest, #2, mul vl]
-	st1b	z3.b, p3, [dest, #3, mul vl]
+	st1b	z1.b, p0, [dest, #1, mul vl]
+	st1b	z2.b, p0, [dest, #2, mul vl]
+	st1b	z3.b, p0, [dest, #3, mul vl]
 	st1b	z4.b, p4, [dest, #4, mul vl]
 	st1b	z5.b, p5, [dest, #5, mul vl]
 	st1b	z6.b, p6, [dest, #6, mul vl]
@@ -182,8 +153,8 @@ ENTRY (MEMCPY)
 	PTR_ARG (1)
 	SIZE_ARG (2)
 
-L(memcpy):
 	cntb	vector_length
+L(memmove_small):
 	// shortcut for less than vector_length * 8
 	// gives a free ptrue to p0.b for n >= vector_length
 	shortcut_for_small_size L(vl_agnostic)
@@ -201,135 +172,107 @@ L(vl_agnostic): // VL Agnostic
 
 L(unroll8): // unrolling and software pipeline
 	lsl	tmp1, vector_length, 3	// vector_length * 8
-	.p2align 3
-	cmp	 rest, tmp1
-	b.cc	L(last)
+	sub	rest, rest, tmp1
 	ld1b_unroll8
 	add	src_ptr, src_ptr, tmp1
-	sub	rest, rest, tmp1
-	cmp	rest, tmp1
+	subs	rest, rest, tmp1
 	b.cc	2f
-	.p2align 3
+	.p2align 4
 1:	stld1b_unroll8
 	add	dest_ptr, dest_ptr, tmp1
 	add	src_ptr, src_ptr, tmp1
-	sub	rest, rest, tmp1
-	cmp	rest, tmp1
-	b.ge	1b
+	subs	rest, rest, tmp1
+	b.hs	1b
 2:	st1b_unroll8
 	add	dest_ptr, dest_ptr, tmp1
+	add	rest, rest, tmp1
 
 	.p2align 3
 L(last):
-	whilelo	p0.b, xzr, rest
+	whilelo p0.b, xzr, rest
 	whilelo	p1.b, vector_length, rest
-	b.last	1f
-	ld1b	z0.b, p0/z, [src_ptr, #0, mul vl]
-	ld1b	z1.b, p1/z, [src_ptr, #1, mul vl]
-	st1b	z0.b, p0, [dest_ptr, #0, mul vl]
-	st1b	z1.b, p1, [dest_ptr, #1, mul vl]
-	ret
-1:	lsl	tmp1, vector_length, 1	// vector_length * 2
-	whilelo	p2.b, tmp1, rest
-	incb	tmp1
-	whilelo	p3.b, tmp1, rest
-	b.last	1f
-	ld1b	z0.b, p0/z, [src_ptr, #0, mul vl]
-	ld1b	z1.b, p1/z, [src_ptr, #1, mul vl]
-	ld1b	z2.b, p2/z, [src_ptr, #2, mul vl]
-	ld1b	z3.b, p3/z, [src_ptr, #3, mul vl]
-	st1b	z0.b, p0, [dest_ptr, #0, mul vl]
-	st1b	z1.b, p1, [dest_ptr, #1, mul vl]
-	st1b	z2.b, p2, [dest_ptr, #2, mul vl]
-	st1b	z3.b, p3, [dest_ptr, #3, mul vl]
-	ret
-1:	lsl	tmp1, vector_length, 2	// vector_length * 4
-	whilelo	p4.b, tmp1, rest
-	incb	tmp1
-	whilelo	p5.b, tmp1, rest
-	incb	tmp1
-	whilelo	p6.b, tmp1, rest
-	incb	tmp1
-	whilelo	p7.b, tmp1, rest
 	ld1b	z0.b, p0/z, [src_ptr, #0, mul vl]
 	ld1b	z1.b, p1/z, [src_ptr, #1, mul vl]
+	b.nlast	1f
+
+	lsl	tmp1, vector_length, 1  // vector_length * 2
+	sub	rest, rest, tmp1
+	whilelo p2.b, xzr, rest
+	whilelo p3.b, vector_length, rest
 	ld1b	z2.b, p2/z, [src_ptr, #2, mul vl]
 	ld1b	z3.b, p3/z, [src_ptr, #3, mul vl]
+	b.nlast	2f
+
+	sub	rest, rest, tmp1
+	add	tmp2, tmp1, vector_length // vector_length * 3
+	whilelo p4.b, xzr, rest
+	whilelo p5.b, vector_length, rest
+	whilelo p6.b, tmp1, rest
+	whilelo p7.b, tmp2, rest
+
 	ld1b	z4.b, p4/z, [src_ptr, #4, mul vl]
 	ld1b	z5.b, p5/z, [src_ptr, #5, mul vl]
 	ld1b	z6.b, p6/z, [src_ptr, #6, mul vl]
 	ld1b	z7.b, p7/z, [src_ptr, #7, mul vl]
-	st1b	z0.b, p0, [dest_ptr, #0, mul vl]
-	st1b	z1.b, p1, [dest_ptr, #1, mul vl]
-	st1b	z2.b, p2, [dest_ptr, #2, mul vl]
-	st1b	z3.b, p3, [dest_ptr, #3, mul vl]
 	st1b	z4.b, p4, [dest_ptr, #4, mul vl]
 	st1b	z5.b, p5, [dest_ptr, #5, mul vl]
 	st1b	z6.b, p6, [dest_ptr, #6, mul vl]
 	st1b	z7.b, p7, [dest_ptr, #7, mul vl]
+2:	st1b	z2.b, p2, [dest_ptr, #2, mul vl]
+	st1b	z3.b, p3, [dest_ptr, #3, mul vl]
+1:	st1b	z0.b, p0, [dest_ptr, #0, mul vl]
+	st1b	z1.b, p1, [dest_ptr, #1, mul vl]
 	ret
 
 L(L2):
 	// align dest address at CACHE_LINE_SIZE byte boundary
-	mov	tmp1, CACHE_LINE_SIZE
-	ands	tmp2, dest_ptr, CACHE_LINE_SIZE - 1
-	// if cl_remainder == 0
-	b.eq	L(L2_dc_zva)
-	sub	cl_remainder, tmp1, tmp2
-	// process remainder until the first CACHE_LINE_SIZE boundary
-	whilelo	p1.b, xzr, cl_remainder	// keep p0.b all true
-	whilelo	p2.b, vector_length, cl_remainder
-	b.last	1f
-	ld1b	z1.b, p1/z, [src_ptr, #0, mul vl]
-	ld1b	z2.b, p2/z, [src_ptr, #1, mul vl]
-	st1b	z1.b, p1, [dest_ptr, #0, mul vl]
-	st1b	z2.b, p2, [dest_ptr, #1, mul vl]
-	b	2f
-1:	lsl	tmp1, vector_length, 1	// vector_length * 2
-	whilelo	p3.b, tmp1, cl_remainder
-	incb	tmp1
-	whilelo	p4.b, tmp1, cl_remainder
-	ld1b	z1.b, p1/z, [src_ptr, #0, mul vl]
-	ld1b	z2.b, p2/z, [src_ptr, #1, mul vl]
-	ld1b	z3.b, p3/z, [src_ptr, #2, mul vl]
-	ld1b	z4.b, p4/z, [src_ptr, #3, mul vl]
-	st1b	z1.b, p1, [dest_ptr, #0, mul vl]
-	st1b	z2.b, p2, [dest_ptr, #1, mul vl]
-	st1b	z3.b, p3, [dest_ptr, #2, mul vl]
-	st1b	z4.b, p4, [dest_ptr, #3, mul vl]
-2:	add	dest_ptr, dest_ptr, cl_remainder
-	add	src_ptr, src_ptr, cl_remainder
-	sub	rest, rest, cl_remainder
+	and	tmp1, dest_ptr, CACHE_LINE_SIZE - 1
+	sub	tmp1, tmp1, CACHE_LINE_SIZE
+	ld1b	z1.b, p0/z, [src_ptr, #0, mul vl]
+	ld1b	z2.b, p0/z, [src_ptr, #1, mul vl]
+	ld1b	z3.b, p0/z, [src_ptr, #2, mul vl]
+	ld1b	z4.b, p0/z, [src_ptr, #3, mul vl]
+	st1b	z1.b, p0, [dest_ptr, #0, mul vl]
+	st1b	z2.b, p0, [dest_ptr, #1, mul vl]
+	st1b	z3.b, p0, [dest_ptr, #2, mul vl]
+	st1b	z4.b, p0, [dest_ptr, #3, mul vl]
+	sub	dest_ptr, dest_ptr, tmp1
+	sub	src_ptr, src_ptr, tmp1
+	add	rest, rest, tmp1
 
 L(L2_dc_zva):
-	// zero fill
-	and	tmp1, dest, 0xffffffffffffff
-	and	tmp2, src, 0xffffffffffffff
-	subs	tmp1, tmp1, tmp2	// diff
-	b.ge	1f
-	neg	tmp1, tmp1
-1:	mov	tmp3, ZF_DIST + CACHE_LINE_SIZE * 2
-	cmp	tmp1, tmp3
+	// check for overlap
+	sub	tmp1, src_ptr, dest_ptr
+	and	tmp1, tmp1, 0xffffffffffffff	// clear tag bits
+	mov	tmp2, ZF_DIST
+	cmp	tmp1, tmp2
 	b.lo	L(unroll8)
+
+	// zero fill loop
 	mov	tmp1, dest_ptr
-	dc_zva	(ZF_DIST / CACHE_LINE_SIZE) - 1
+	mov	tmp3, ZF_DIST / CACHE_LINE_SIZE
+1:	dc	zva, tmp1
+	add	tmp1, tmp1, CACHE_LINE_SIZE
+	subs	tmp3, tmp3, 1
+	b.ne	1b
+
+	mov	tmp3, ZF_DIST + CACHE_LINE_SIZE * 2
 	// unroll
-	ld1b_unroll8	// this line has to be after "b.lo L(unroll8)"
-	add	 src_ptr, src_ptr, CACHE_LINE_SIZE * 2
-	sub	 rest, rest, CACHE_LINE_SIZE * 2
-	mov	 tmp1, ZF_DIST
-	.p2align 3
-1:	stld1b_unroll4a
-	add	tmp2, dest_ptr, tmp1	// dest_ptr + ZF_DIST
-	dc	zva, tmp2
+	ld1b_unroll8
+	add	src_ptr, src_ptr, CACHE_LINE_SIZE * 2
+	sub	rest, rest, CACHE_LINE_SIZE * 2
+	.p2align 4
+2:	stld1b_unroll4a
+	add	tmp1, dest_ptr, tmp2	// dest_ptr + ZF_DIST
+	dc	zva, tmp1
 	stld1b_unroll4b
-	add	tmp2, tmp2, CACHE_LINE_SIZE
-	dc	zva, tmp2
+	add	tmp1, tmp1, CACHE_LINE_SIZE
+	dc	zva, tmp1
 	add	dest_ptr, dest_ptr, CACHE_LINE_SIZE * 2
 	add	src_ptr, src_ptr, CACHE_LINE_SIZE * 2
 	sub	rest, rest, CACHE_LINE_SIZE * 2
-	cmp	rest, tmp3	// ZF_DIST + CACHE_LINE_SIZE * 2
-	b.ge	1b
+	cmp	rest, tmp3
+	b.hs	2b
 	st1b_unroll8
 	add	dest_ptr, dest_ptr, CACHE_LINE_SIZE * 2
 	b	L(unroll8)
@@ -338,68 +281,50 @@ END (MEMCPY)
 libc_hidden_builtin_def (MEMCPY)
 
 
-ENTRY (MEMMOVE)
+ENTRY_ALIGN (MEMMOVE, 4)
 
 	PTR_ARG (0)
 	PTR_ARG (1)
 	SIZE_ARG (2)
 
-	// remove tag address
-	// dest has to be immutable because it is the return value
-	// src has to be immutable because it is used in L(bwd_last)
-	and	tmp2, dest, 0xffffffffffffff	// save dest_notag into tmp2
-	and	tmp3, src, 0xffffffffffffff	// save src_notag intp tmp3
-	cmp	n, 0
-	ccmp	tmp2, tmp3, 4, ne
-	b.ne	1f
-	ret
-1:	cntb	vector_length
-	// shortcut for less than vector_length * 8
-	// gives a free ptrue to p0.b for n >= vector_length
-	// tmp2 and tmp3 should not be used in this macro to keep
-	// notag addresses
-	shortcut_for_small_size L(dispatch)
-	// end of shortcut
-
-L(dispatch):
-	// tmp2 = dest_notag, tmp3 = src_notag
-	// diff = dest_notag - src_notag
-	sub	tmp1, tmp2, tmp3
-	// if diff <= 0 || diff >= n then memcpy
-	cmp	tmp1, 0
-	ccmp	tmp1, n, 2, gt
-	b.cs	L(vl_agnostic)
-
-L(bwd_start):
-	mov	rest, n
-	add	dest_ptr, dest, n	// dest_end
-	add	src_ptr, src, n		// src_end
+	cntb	vector_length
+	// diff = dest - src
+	sub	tmp1, dest, src
+	ands	tmp1, tmp1, 0xffffffffffffff    // clear tag bits
+	b.eq	L(full_overlap)
 
-L(bwd_unroll8): // unrolling and software pipeline
-	lsl	tmp1, vector_length, 3	// vector_length * 8
-	.p2align 3
-	cmp	rest, tmp1
-	b.cc	L(bwd_last)
-	sub	src_ptr, src_ptr, tmp1
+	cmp	n, vector_length, lsl 3 // vector_length * 8
+	b.ls	L(memmove_small)
+
+	ptrue	p0.b
+	// if diff < 0 || diff >= n then memcpy
+	cmp	tmp1, n
+	b.hs	L(vl_agnostic)
+
+	// unrolling and software pipeline
+	lsl	tmp1, vector_length, 3  // vector_length * 8
+	add	dest_ptr, dest, n       // dest_end
+	sub	rest, n, tmp1
+	add	src_ptr, src, rest	// src_end
 	ld1b_unroll8
-	sub	rest, rest, tmp1
-	cmp	rest, tmp1
+	subs	rest, rest, tmp1
 	b.cc	2f
-	.p2align 3
+	.p2align 4
 1:	sub	src_ptr, src_ptr, tmp1
 	sub	dest_ptr, dest_ptr, tmp1
 	stld1b_unroll8
-	sub	rest, rest, tmp1
-	cmp	rest, tmp1
-	b.ge	1b
+	subs	rest, rest, tmp1
+	b.hs	1b
 2:	sub	dest_ptr, dest_ptr, tmp1
 	st1b_unroll8
-
-L(bwd_last):
+	add	rest, rest, tmp1
 	mov	dest_ptr, dest
 	mov	src_ptr, src
 	b	L(last)
 
+L(full_overlap):
+	ret
+
 END (MEMMOVE)
 libc_hidden_builtin_def (MEMMOVE)
 # endif /* IS_IN (libc) */