From 19c062d3d4cd46ac9095f8ef8133c0e3c01a9d4f Mon Sep 17 00:00:00 2001
From: Peter Jung <admin@ptr1337.dev>
Date: Thu, 18 Dec 2025 16:42:00 +0100
Subject: [PATCH 07/11] crypto

Signed-off-by: Peter Jung <admin@ptr1337.dev>
---
 arch/x86/crypto/Makefile                      |    5 +-
 arch/x86/crypto/aes-gcm-aesni-x86_64.S        |   12 +-
 arch/x86/crypto/aes-gcm-vaes-avx2.S           | 1150 +++++++++++++++++
 ...m-avx10-x86_64.S => aes-gcm-vaes-avx512.S} |  722 +++++------
 arch/x86/crypto/aesni-intel_glue.c            |  264 ++--
 drivers/md/Kconfig                            |    1 +
 drivers/md/dm-verity-fec.c                    |   21 +-
 drivers/md/dm-verity-fec.h                    |    5 +-
 drivers/md/dm-verity-target.c                 |  203 ++-
 drivers/md/dm-verity.h                        |   52 +-
 include/linux/rhashtable.h                    |   70 +-
 11 files changed, 1921 insertions(+), 584 deletions(-)
 create mode 100644 arch/x86/crypto/aes-gcm-vaes-avx2.S
 rename arch/x86/crypto/{aes-gcm-avx10-x86_64.S => aes-gcm-vaes-avx512.S} (69%)

diff --git a/arch/x86/crypto/Makefile b/arch/x86/crypto/Makefile
index 2d30d5d36145..6409e3009524 100644
--- a/arch/x86/crypto/Makefile
+++ b/arch/x86/crypto/Makefile
@@ -46,8 +46,9 @@ obj-$(CONFIG_CRYPTO_AES_NI_INTEL) += aesni-intel.o
 aesni-intel-y := aesni-intel_asm.o aesni-intel_glue.o
 aesni-intel-$(CONFIG_64BIT) += aes-ctr-avx-x86_64.o \
 			       aes-gcm-aesni-x86_64.o \
-			       aes-xts-avx-x86_64.o \
-			       aes-gcm-avx10-x86_64.o
+			       aes-gcm-vaes-avx2.o \
+			       aes-gcm-vaes-avx512.o \
+			       aes-xts-avx-x86_64.o
 
 obj-$(CONFIG_CRYPTO_GHASH_CLMUL_NI_INTEL) += ghash-clmulni-intel.o
 ghash-clmulni-intel-y := ghash-clmulni-intel_asm.o ghash-clmulni-intel_glue.o
diff --git a/arch/x86/crypto/aes-gcm-aesni-x86_64.S b/arch/x86/crypto/aes-gcm-aesni-x86_64.S
index 45940e2883a0..7c8a8a32bd3c 100644
--- a/arch/x86/crypto/aes-gcm-aesni-x86_64.S
+++ b/arch/x86/crypto/aes-gcm-aesni-x86_64.S
@@ -61,15 +61,15 @@
 // for the *_aesni functions or AVX for the *_aesni_avx ones.  (But it seems
 // there are no CPUs that support AES-NI without also PCLMULQDQ and SSE4.1.)
 //
-// The design generally follows that of aes-gcm-avx10-x86_64.S, and that file is
+// The design generally follows that of aes-gcm-vaes-avx512.S, and that file is
 // more thoroughly commented.  This file has the following notable changes:
 //
 //    - The vector length is fixed at 128-bit, i.e. xmm registers.  This means
 //      there is only one AES block (and GHASH block) per register.
 //
-//    - Without AVX512 / AVX10, only 16 SIMD registers are available instead of
-//      32.  We work around this by being much more careful about using
-//      registers, relying heavily on loads to load values as they are needed.
+//    - Without AVX512, only 16 SIMD registers are available instead of 32.  We
+//      work around this by being much more careful about using registers,
+//      relying heavily on loads to load values as they are needed.
 //
 //    - Masking is not available either.  We work around this by implementing
 //      partial block loads and stores using overlapping scalar loads and stores
@@ -90,8 +90,8 @@
 //      multiplication instead of schoolbook multiplication.  This saves one
 //      pclmulqdq instruction per block, at the cost of one 64-bit load, one
 //      pshufd, and 0.25 pxors per block.  (This is without the three-argument
-//      XOR support that would be provided by AVX512 / AVX10, which would be
-//      more beneficial to schoolbook than Karatsuba.)
+//      XOR support that would be provided by AVX512, which would be more
+//      beneficial to schoolbook than Karatsuba.)
 //
 //      As a rough approximation, we can assume that Karatsuba multiplication is
 //      faster than schoolbook multiplication in this context if one pshufd and
diff --git a/arch/x86/crypto/aes-gcm-vaes-avx2.S b/arch/x86/crypto/aes-gcm-vaes-avx2.S
new file mode 100644
index 000000000000..5ccbd85383cd
--- /dev/null
+++ b/arch/x86/crypto/aes-gcm-vaes-avx2.S
@@ -0,0 +1,1150 @@
+/* SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause */
+//
+// AES-GCM implementation for x86_64 CPUs that support the following CPU
+// features: VAES && VPCLMULQDQ && AVX2
+//
+// Copyright 2025 Google LLC
+//
+// Author: Eric Biggers <ebiggers@google.com>
+//
+//------------------------------------------------------------------------------
+//
+// This file is dual-licensed, meaning that you can use it under your choice of
+// either of the following two licenses:
+//
+// Licensed under the Apache License 2.0 (the "License").  You may obtain a copy
+// of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// or
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice,
+//    this list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// -----------------------------------------------------------------------------
+//
+// This is similar to aes-gcm-vaes-avx512.S, but it uses AVX2 instead of AVX512.
+// This means it can only use 16 vector registers instead of 32, the maximum
+// vector length is 32 bytes, and some instructions such as vpternlogd and
+// masked loads/stores are unavailable.  However, it is able to run on CPUs that
+// have VAES without AVX512, namely AMD Zen 3 (including "Milan" server CPUs),
+// various Intel client CPUs such as Alder Lake, and Intel Sierra Forest.
+//
+// This implementation also uses Karatsuba multiplication instead of schoolbook
+// multiplication for GHASH in its main loop.  This does not help much on Intel,
+// but it improves performance by ~5% on AMD Zen 3.  Other factors weighing
+// slightly in favor of Karatsuba multiplication in this implementation are the
+// lower maximum vector length (which means there are fewer key powers, so we
+// can cache the halves of each key power XOR'd together and still use less
+// memory than the AVX512 implementation), and the unavailability of the
+// vpternlogd instruction (which helped schoolbook a bit more than Karatsuba).
+
+#include <linux/linkage.h>
+
+.section .rodata
+.p2align 4
+
+	// The below three 16-byte values must be in the order that they are, as
+	// they are really two 32-byte tables and a 16-byte value that overlap:
+	//
+	// - The first 32-byte table begins at .Lselect_high_bytes_table.
+	//   For 0 <= len <= 16, the 16-byte value at
+	//   '.Lselect_high_bytes_table + len' selects the high 'len' bytes of
+	//   another 16-byte value when AND'ed with it.
+	//
+	// - The second 32-byte table begins at .Lrshift_and_bswap_table.
+	//   For 0 <= len <= 16, the 16-byte value at
+	//   '.Lrshift_and_bswap_table + len' is a vpshufb mask that does the
+	//   following operation: right-shift by '16 - len' bytes (shifting in
+	//   zeroes), then reflect all 16 bytes.
+	//
+	// - The 16-byte value at .Lbswap_mask is a vpshufb mask that reflects
+	//   all 16 bytes.
+.Lselect_high_bytes_table:
+	.octa	0
+.Lrshift_and_bswap_table:
+	.octa	0xffffffffffffffffffffffffffffffff
+.Lbswap_mask:
+	.octa	0x000102030405060708090a0b0c0d0e0f
+
+	// Sixteen 0x0f bytes.  By XOR'ing an entry of .Lrshift_and_bswap_table
+	// with this, we get a mask that left-shifts by '16 - len' bytes.
+.Lfifteens:
+	.octa	0x0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f
+
+	// This is the GHASH reducing polynomial without its constant term, i.e.
+	// x^128 + x^7 + x^2 + x, represented using the backwards mapping
+	// between bits and polynomial coefficients.
+	//
+	// Alternatively, it can be interpreted as the naturally-ordered
+	// representation of the polynomial x^127 + x^126 + x^121 + 1, i.e. the
+	// "reversed" GHASH reducing polynomial without its x^128 term.
+.Lgfpoly:
+	.octa	0xc2000000000000000000000000000001
+
+	// Same as above, but with the (1 << 64) bit set.
+.Lgfpoly_and_internal_carrybit:
+	.octa	0xc2000000000000010000000000000001
+
+	// Values needed to prepare the initial vector of counter blocks.
+.Lctr_pattern:
+	.octa	0
+	.octa	1
+
+	// The number of AES blocks per vector, as a 128-bit value.
+.Linc_2blocks:
+	.octa	2
+
+// Offsets in struct aes_gcm_key_vaes_avx2
+#define OFFSETOF_AESKEYLEN	480
+#define OFFSETOF_H_POWERS	512
+#define NUM_H_POWERS		8
+#define OFFSETOFEND_H_POWERS    (OFFSETOF_H_POWERS + (NUM_H_POWERS * 16))
+#define OFFSETOF_H_POWERS_XORED	OFFSETOFEND_H_POWERS
+
+.text
+
+// Do one step of GHASH-multiplying the 128-bit lanes of \a by the 128-bit lanes
+// of \b and storing the reduced products in \dst.  Uses schoolbook
+// multiplication.
+.macro	_ghash_mul_step	i, a, b, dst, gfpoly, t0, t1, t2
+.if \i == 0
+	vpclmulqdq	$0x00, \a, \b, \t0	  // LO = a_L * b_L
+	vpclmulqdq	$0x01, \a, \b, \t1	  // MI_0 = a_L * b_H
+.elseif \i == 1
+	vpclmulqdq	$0x10, \a, \b, \t2	  // MI_1 = a_H * b_L
+.elseif \i == 2
+	vpxor		\t2, \t1, \t1		  // MI = MI_0 + MI_1
+.elseif \i == 3
+	vpclmulqdq	$0x01, \t0, \gfpoly, \t2  // LO_L*(x^63 + x^62 + x^57)
+.elseif \i == 4
+	vpshufd		$0x4e, \t0, \t0		  // Swap halves of LO
+.elseif \i == 5
+	vpxor		\t0, \t1, \t1		  // Fold LO into MI (part 1)
+	vpxor		\t2, \t1, \t1		  // Fold LO into MI (part 2)
+.elseif \i == 6
+	vpclmulqdq	$0x11, \a, \b, \dst	  // HI = a_H * b_H
+.elseif \i == 7
+	vpclmulqdq	$0x01, \t1, \gfpoly, \t0  // MI_L*(x^63 + x^62 + x^57)
+.elseif \i == 8
+	vpshufd		$0x4e, \t1, \t1		  // Swap halves of MI
+.elseif \i == 9
+	vpxor		\t1, \dst, \dst		  // Fold MI into HI (part 1)
+	vpxor		\t0, \dst, \dst		  // Fold MI into HI (part 2)
+.endif
+.endm
+
+// GHASH-multiply the 128-bit lanes of \a by the 128-bit lanes of \b and store
+// the reduced products in \dst.  See _ghash_mul_step for full explanation.
+.macro	_ghash_mul	a, b, dst, gfpoly, t0, t1, t2
+.irp i, 0,1,2,3,4,5,6,7,8,9
+	_ghash_mul_step	\i, \a, \b, \dst, \gfpoly, \t0, \t1, \t2
+.endr
+.endm
+
+// GHASH-multiply the 128-bit lanes of \a by the 128-bit lanes of \b and add the
+// *unreduced* products to \lo, \mi, and \hi.
+.macro	_ghash_mul_noreduce	a, b, lo, mi, hi, t0
+	vpclmulqdq	$0x00, \a, \b, \t0	// a_L * b_L
+	vpxor		\t0, \lo, \lo
+	vpclmulqdq	$0x01, \a, \b, \t0	// a_L * b_H
+	vpxor		\t0, \mi, \mi
+	vpclmulqdq	$0x10, \a, \b, \t0	// a_H * b_L
+	vpxor		\t0, \mi, \mi
+	vpclmulqdq	$0x11, \a, \b, \t0	// a_H * b_H
+	vpxor		\t0, \hi, \hi
+.endm
+
+// Reduce the unreduced products from \lo, \mi, and \hi and store the 128-bit
+// reduced products in \hi.  See _ghash_mul_step for explanation of reduction.
+.macro	_ghash_reduce	lo, mi, hi, gfpoly, t0
+	vpclmulqdq	$0x01, \lo, \gfpoly, \t0
+	vpshufd		$0x4e, \lo, \lo
+	vpxor		\lo, \mi, \mi
+	vpxor		\t0, \mi, \mi
+	vpclmulqdq	$0x01, \mi, \gfpoly, \t0
+	vpshufd		$0x4e, \mi, \mi
+	vpxor		\mi, \hi, \hi
+	vpxor		\t0, \hi, \hi
+.endm
+
+// This is a specialized version of _ghash_mul that computes \a * \a, i.e. it
+// squares \a.  It skips computing MI = (a_L * a_H) + (a_H * a_L) = 0.
+.macro	_ghash_square	a, dst, gfpoly, t0, t1
+	vpclmulqdq	$0x00, \a, \a, \t0	  // LO = a_L * a_L
+	vpclmulqdq	$0x11, \a, \a, \dst	  // HI = a_H * a_H
+	vpclmulqdq	$0x01, \t0, \gfpoly, \t1  // LO_L*(x^63 + x^62 + x^57)
+	vpshufd		$0x4e, \t0, \t0		  // Swap halves of LO
+	vpxor		\t0, \t1, \t1		  // Fold LO into MI
+	vpclmulqdq	$0x01, \t1, \gfpoly, \t0  // MI_L*(x^63 + x^62 + x^57)
+	vpshufd		$0x4e, \t1, \t1		  // Swap halves of MI
+	vpxor		\t1, \dst, \dst		  // Fold MI into HI (part 1)
+	vpxor		\t0, \dst, \dst		  // Fold MI into HI (part 2)
+.endm
+
+// void aes_gcm_precompute_vaes_avx2(struct aes_gcm_key_vaes_avx2 *key);
+//
+// Given the expanded AES key |key->base.aes_key|, derive the GHASH subkey and
+// initialize |key->h_powers| and |key->h_powers_xored|.
+//
+// We use h_powers[0..7] to store H^8 through H^1, and h_powers_xored[0..7] to
+// store the 64-bit halves of the key powers XOR'd together (for Karatsuba
+// multiplication) in the order 8,6,7,5,4,2,3,1.
+SYM_FUNC_START(aes_gcm_precompute_vaes_avx2)
+
+	// Function arguments
+	.set	KEY,		%rdi
+
+	// Additional local variables
+	.set	POWERS_PTR,	%rsi
+	.set	RNDKEYLAST_PTR,	%rdx
+	.set	TMP0,		%ymm0
+	.set	TMP0_XMM,	%xmm0
+	.set	TMP1,		%ymm1
+	.set	TMP1_XMM,	%xmm1
+	.set	TMP2,		%ymm2
+	.set	TMP2_XMM,	%xmm2
+	.set	H_CUR,		%ymm3
+	.set	H_CUR_XMM,	%xmm3
+	.set	H_CUR2,		%ymm4
+	.set	H_CUR2_XMM,	%xmm4
+	.set	H_INC,		%ymm5
+	.set	H_INC_XMM,	%xmm5
+	.set	GFPOLY,		%ymm6
+	.set	GFPOLY_XMM,	%xmm6
+
+	// Encrypt an all-zeroes block to get the raw hash subkey.
+	movl		OFFSETOF_AESKEYLEN(KEY), %eax
+	lea		6*16(KEY,%rax,4), RNDKEYLAST_PTR
+	vmovdqu		(KEY), H_CUR_XMM  // Zero-th round key XOR all-zeroes block
+	lea		16(KEY), %rax
+1:
+	vaesenc		(%rax), H_CUR_XMM, H_CUR_XMM
+	add		$16, %rax
+	cmp		%rax, RNDKEYLAST_PTR
+	jne		1b
+	vaesenclast	(RNDKEYLAST_PTR), H_CUR_XMM, H_CUR_XMM
+
+	// Reflect the bytes of the raw hash subkey.
+	vpshufb		.Lbswap_mask(%rip), H_CUR_XMM, H_CUR_XMM
+
+	// Finish preprocessing the byte-reflected hash subkey by multiplying it
+	// by x^-1 ("standard" interpretation of polynomial coefficients) or
+	// equivalently x^1 (natural interpretation).  This gets the key into a
+	// format that avoids having to bit-reflect the data blocks later.
+	vpshufd		$0xd3, H_CUR_XMM, TMP0_XMM
+	vpsrad		$31, TMP0_XMM, TMP0_XMM
+	vpaddq		H_CUR_XMM, H_CUR_XMM, H_CUR_XMM
+	vpand		.Lgfpoly_and_internal_carrybit(%rip), TMP0_XMM, TMP0_XMM
+	vpxor		TMP0_XMM, H_CUR_XMM, H_CUR_XMM
+
+	// Load the gfpoly constant.
+	vbroadcasti128	.Lgfpoly(%rip), GFPOLY
+
+	// Square H^1 to get H^2.
+	_ghash_square	H_CUR_XMM, H_INC_XMM, GFPOLY_XMM, TMP0_XMM, TMP1_XMM
+
+	// Create H_CUR = [H^2, H^1] and H_INC = [H^2, H^2].
+	vinserti128	$1, H_CUR_XMM, H_INC, H_CUR
+	vinserti128	$1, H_INC_XMM, H_INC, H_INC
+
+	// Compute H_CUR2 = [H^4, H^3].
+	_ghash_mul	H_INC, H_CUR, H_CUR2, GFPOLY, TMP0, TMP1, TMP2
+
+	// Store [H^2, H^1] and [H^4, H^3].
+	vmovdqu		H_CUR, OFFSETOF_H_POWERS+3*32(KEY)
+	vmovdqu		H_CUR2, OFFSETOF_H_POWERS+2*32(KEY)
+
+	// For Karatsuba multiplication: compute and store the two 64-bit halves
+	// of each key power XOR'd together.  Order is 4,2,3,1.
+	vpunpcklqdq	H_CUR, H_CUR2, TMP0
+	vpunpckhqdq	H_CUR, H_CUR2, TMP1
+	vpxor		TMP1, TMP0, TMP0
+	vmovdqu		TMP0, OFFSETOF_H_POWERS_XORED+32(KEY)
+
+	// Compute and store H_CUR = [H^6, H^5] and H_CUR2 = [H^8, H^7].
+	_ghash_mul	H_INC, H_CUR2, H_CUR, GFPOLY, TMP0, TMP1, TMP2
+	_ghash_mul	H_INC, H_CUR, H_CUR2, GFPOLY, TMP0, TMP1, TMP2
+	vmovdqu		H_CUR, OFFSETOF_H_POWERS+1*32(KEY)
+	vmovdqu		H_CUR2, OFFSETOF_H_POWERS+0*32(KEY)
+
+	// Again, compute and store the two 64-bit halves of each key power
+	// XOR'd together.  Order is 8,6,7,5.
+	vpunpcklqdq	H_CUR, H_CUR2, TMP0
+	vpunpckhqdq	H_CUR, H_CUR2, TMP1
+	vpxor		TMP1, TMP0, TMP0
+	vmovdqu		TMP0, OFFSETOF_H_POWERS_XORED(KEY)
+
+	vzeroupper
+	RET
+SYM_FUNC_END(aes_gcm_precompute_vaes_avx2)
+
+// Do one step of the GHASH update of four vectors of data blocks.
+//   \i: the step to do, 0 through 9
+//   \ghashdata_ptr: pointer to the data blocks (ciphertext or AAD)
+//   KEY: pointer to struct aes_gcm_key_vaes_avx2
+//   BSWAP_MASK: mask for reflecting the bytes of blocks
+//   H_POW[2-1]_XORED: cached values from KEY->h_powers_xored
+//   TMP[0-2]: temporary registers.  TMP[1-2] must be preserved across steps.
+//   LO, MI: working state for this macro that must be preserved across steps
+//   GHASH_ACC: the GHASH accumulator (input/output)
+.macro	_ghash_step_4x	i, ghashdata_ptr
+	.set		HI, GHASH_ACC # alias
+	.set		HI_XMM, GHASH_ACC_XMM
+.if \i == 0
+	// First vector
+	vmovdqu		0*32(\ghashdata_ptr), TMP1
+	vpshufb		BSWAP_MASK, TMP1, TMP1
+	vmovdqu		OFFSETOF_H_POWERS+0*32(KEY), TMP2
+	vpxor		GHASH_ACC, TMP1, TMP1
+	vpclmulqdq	$0x00, TMP2, TMP1, LO
+	vpclmulqdq	$0x11, TMP2, TMP1, HI
+	vpunpckhqdq	TMP1, TMP1, TMP0
+	vpxor		TMP1, TMP0, TMP0
+	vpclmulqdq	$0x00, H_POW2_XORED, TMP0, MI
+.elseif \i == 1
+.elseif \i == 2
+	// Second vector
+	vmovdqu		1*32(\ghashdata_ptr), TMP1
+	vpshufb		BSWAP_MASK, TMP1, TMP1
+	vmovdqu		OFFSETOF_H_POWERS+1*32(KEY), TMP2
+	vpclmulqdq	$0x00, TMP2, TMP1, TMP0
+	vpxor		TMP0, LO, LO
+	vpclmulqdq	$0x11, TMP2, TMP1, TMP0
+	vpxor		TMP0, HI, HI
+	vpunpckhqdq	TMP1, TMP1, TMP0
+	vpxor		TMP1, TMP0, TMP0
+	vpclmulqdq	$0x10, H_POW2_XORED, TMP0, TMP0
+	vpxor		TMP0, MI, MI
+.elseif \i == 3
+	// Third vector
+	vmovdqu		2*32(\ghashdata_ptr), TMP1
+	vpshufb		BSWAP_MASK, TMP1, TMP1
+	vmovdqu		OFFSETOF_H_POWERS+2*32(KEY), TMP2
+.elseif \i == 4
+	vpclmulqdq	$0x00, TMP2, TMP1, TMP0
+	vpxor		TMP0, LO, LO
+	vpclmulqdq	$0x11, TMP2, TMP1, TMP0
+	vpxor		TMP0, HI, HI
+.elseif \i == 5
+	vpunpckhqdq	TMP1, TMP1, TMP0
+	vpxor		TMP1, TMP0, TMP0
+	vpclmulqdq	$0x00, H_POW1_XORED, TMP0, TMP0
+	vpxor		TMP0, MI, MI
+
+	// Fourth vector
+	vmovdqu		3*32(\ghashdata_ptr), TMP1
+	vpshufb		BSWAP_MASK, TMP1, TMP1
+.elseif \i == 6
+	vmovdqu		OFFSETOF_H_POWERS+3*32(KEY), TMP2
+	vpclmulqdq	$0x00, TMP2, TMP1, TMP0
+	vpxor		TMP0, LO, LO
+	vpclmulqdq	$0x11, TMP2, TMP1, TMP0
+	vpxor		TMP0, HI, HI
+	vpunpckhqdq	TMP1, TMP1, TMP0
+	vpxor		TMP1, TMP0, TMP0
+	vpclmulqdq	$0x10, H_POW1_XORED, TMP0, TMP0
+	vpxor		TMP0, MI, MI
+.elseif \i == 7
+	// Finalize 'mi' following Karatsuba multiplication.
+	vpxor		LO, MI, MI
+	vpxor		HI, MI, MI
+
+	// Fold lo into mi.
+	vbroadcasti128	.Lgfpoly(%rip), TMP2
+	vpclmulqdq	$0x01, LO, TMP2, TMP0
+	vpshufd		$0x4e, LO, LO
+	vpxor		LO, MI, MI
+	vpxor		TMP0, MI, MI
+.elseif \i == 8
+	// Fold mi into hi.
+	vpclmulqdq	$0x01, MI, TMP2, TMP0
+	vpshufd		$0x4e, MI, MI
+	vpxor		MI, HI, HI
+	vpxor		TMP0, HI, HI
+.elseif \i == 9
+	vextracti128	$1, HI, TMP0_XMM
+	vpxor		TMP0_XMM, HI_XMM, GHASH_ACC_XMM
+.endif
+.endm
+
+// Update GHASH with four vectors of data blocks.  See _ghash_step_4x for full
+// explanation.
+.macro	_ghash_4x	ghashdata_ptr
+.irp i, 0,1,2,3,4,5,6,7,8,9
+	_ghash_step_4x	\i, \ghashdata_ptr
+.endr
+.endm
+
+// Load 1 <= %ecx <= 16 bytes from the pointer \src into the xmm register \dst
+// and zeroize any remaining bytes.  Clobbers %rax, %rcx, and \tmp{64,32}.
+.macro	_load_partial_block	src, dst, tmp64, tmp32
+	sub		$8, %ecx		// LEN - 8
+	jle		.Lle8\@
+
+	// Load 9 <= LEN <= 16 bytes.
+	vmovq		(\src), \dst		// Load first 8 bytes
+	mov		(\src, %rcx), %rax	// Load last 8 bytes
+	neg		%ecx
+	shl		$3, %ecx
+	shr		%cl, %rax		// Discard overlapping bytes
+	vpinsrq		$1, %rax, \dst, \dst
+	jmp		.Ldone\@
+
+.Lle8\@:
+	add		$4, %ecx		// LEN - 4
+	jl		.Llt4\@
+
+	// Load 4 <= LEN <= 8 bytes.
+	mov		(\src), %eax		// Load first 4 bytes
+	mov		(\src, %rcx), \tmp32	// Load last 4 bytes
+	jmp		.Lcombine\@
+
+.Llt4\@:
+	// Load 1 <= LEN <= 3 bytes.
+	add		$2, %ecx		// LEN - 2
+	movzbl		(\src), %eax		// Load first byte
+	jl		.Lmovq\@
+	movzwl		(\src, %rcx), \tmp32	// Load last 2 bytes
+.Lcombine\@:
+	shl		$3, %ecx
+	shl		%cl, \tmp64
+	or		\tmp64, %rax		// Combine the two parts
+.Lmovq\@:
+	vmovq		%rax, \dst
+.Ldone\@:
+.endm
+
+// Store 1 <= %ecx <= 16 bytes from the xmm register \src to the pointer \dst.
+// Clobbers %rax, %rcx, and \tmp{64,32}.
+.macro	_store_partial_block	src, dst, tmp64, tmp32
+	sub		$8, %ecx		// LEN - 8
+	jl		.Llt8\@
+
+	// Store 8 <= LEN <= 16 bytes.
+	vpextrq		$1, \src, %rax
+	mov		%ecx, \tmp32
+	shl		$3, %ecx
+	ror		%cl, %rax
+	mov		%rax, (\dst, \tmp64)	// Store last LEN - 8 bytes
+	vmovq		\src, (\dst)		// Store first 8 bytes
+	jmp		.Ldone\@
+
+.Llt8\@:
+	add		$4, %ecx		// LEN - 4
+	jl		.Llt4\@
+
+	// Store 4 <= LEN <= 7 bytes.
+	vpextrd		$1, \src, %eax
+	mov		%ecx, \tmp32
+	shl		$3, %ecx
+	ror		%cl, %eax
+	mov		%eax, (\dst, \tmp64)	// Store last LEN - 4 bytes
+	vmovd		\src, (\dst)		// Store first 4 bytes
+	jmp		.Ldone\@
+
+.Llt4\@:
+	// Store 1 <= LEN <= 3 bytes.
+	vpextrb		$0, \src, 0(\dst)
+	cmp		$-2, %ecx		// LEN - 4 == -2, i.e. LEN == 2?
+	jl		.Ldone\@
+	vpextrb		$1, \src, 1(\dst)
+	je		.Ldone\@
+	vpextrb		$2, \src, 2(\dst)
+.Ldone\@:
+.endm
+
+// void aes_gcm_aad_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+//				     u8 ghash_acc[16],
+//				     const u8 *aad, int aadlen);
+//
+// This function processes the AAD (Additional Authenticated Data) in GCM.
+// Using the key |key|, it updates the GHASH accumulator |ghash_acc| with the
+// data given by |aad| and |aadlen|.  On the first call, |ghash_acc| must be all
+// zeroes.  |aadlen| must be a multiple of 16, except on the last call where it
+// can be any length.  The caller must do any buffering needed to ensure this.
+//
+// This handles large amounts of AAD efficiently, while also keeping overhead
+// low for small amounts which is the common case.  TLS and IPsec use less than
+// one block of AAD, but (uncommonly) other use cases may use much more.
+SYM_FUNC_START(aes_gcm_aad_update_vaes_avx2)
+
+	// Function arguments
+	.set	KEY,		%rdi
+	.set	GHASH_ACC_PTR,	%rsi
+	.set	AAD,		%rdx
+	.set	AADLEN,		%ecx	// Must be %ecx for _load_partial_block
+	.set	AADLEN64,	%rcx	// Zero-extend AADLEN before using!
+
+	// Additional local variables.
+	// %rax and %r8 are used as temporary registers.
+	.set	TMP0,		%ymm0
+	.set	TMP0_XMM,	%xmm0
+	.set	TMP1,		%ymm1
+	.set	TMP1_XMM,	%xmm1
+	.set	TMP2,		%ymm2
+	.set	TMP2_XMM,	%xmm2
+	.set	LO,		%ymm3
+	.set	LO_XMM,		%xmm3
+	.set	MI,		%ymm4
+	.set	MI_XMM,		%xmm4
+	.set	GHASH_ACC,	%ymm5
+	.set	GHASH_ACC_XMM,	%xmm5
+	.set	BSWAP_MASK,	%ymm6
+	.set	BSWAP_MASK_XMM,	%xmm6
+	.set	GFPOLY,		%ymm7
+	.set	GFPOLY_XMM,	%xmm7
+	.set	H_POW2_XORED,	%ymm8
+	.set	H_POW1_XORED,	%ymm9
+
+	// Load the bswap_mask and gfpoly constants.  Since AADLEN is usually
+	// small, usually only 128-bit vectors will be used.  So as an
+	// optimization, don't broadcast these constants to both 128-bit lanes
+	// quite yet.
+	vmovdqu		.Lbswap_mask(%rip), BSWAP_MASK_XMM
+	vmovdqu		.Lgfpoly(%rip), GFPOLY_XMM
+
+	// Load the GHASH accumulator.
+	vmovdqu		(GHASH_ACC_PTR), GHASH_ACC_XMM
+
+	// Check for the common case of AADLEN <= 16, as well as AADLEN == 0.
+	test		AADLEN, AADLEN
+	jz		.Laad_done
+	cmp		$16, AADLEN
+	jle		.Laad_lastblock
+
+	// AADLEN > 16, so we'll operate on full vectors.  Broadcast bswap_mask
+	// and gfpoly to both 128-bit lanes.
+	vinserti128	$1, BSWAP_MASK_XMM, BSWAP_MASK, BSWAP_MASK
+	vinserti128	$1, GFPOLY_XMM, GFPOLY, GFPOLY
+
+	// If AADLEN >= 128, update GHASH with 128 bytes of AAD at a time.
+	add		$-128, AADLEN	// 128 is 4 bytes, -128 is 1 byte
+	jl		.Laad_loop_4x_done
+	vmovdqu		OFFSETOF_H_POWERS_XORED(KEY), H_POW2_XORED
+	vmovdqu		OFFSETOF_H_POWERS_XORED+32(KEY), H_POW1_XORED
+.Laad_loop_4x:
+	_ghash_4x	AAD
+	sub		$-128, AAD
+	add		$-128, AADLEN
+	jge		.Laad_loop_4x
+.Laad_loop_4x_done:
+
+	// If AADLEN >= 32, update GHASH with 32 bytes of AAD at a time.
+	add		$96, AADLEN
+	jl		.Laad_loop_1x_done
+.Laad_loop_1x:
+	vmovdqu		(AAD), TMP0
+	vpshufb		BSWAP_MASK, TMP0, TMP0
+	vpxor		TMP0, GHASH_ACC, GHASH_ACC
+	vmovdqu		OFFSETOFEND_H_POWERS-32(KEY), TMP0
+	_ghash_mul	TMP0, GHASH_ACC, GHASH_ACC, GFPOLY, TMP1, TMP2, LO
+	vextracti128	$1, GHASH_ACC, TMP0_XMM
+	vpxor		TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
+	add		$32, AAD
+	sub		$32, AADLEN
+	jge		.Laad_loop_1x
+.Laad_loop_1x_done:
+	add		$32, AADLEN
+	// Now 0 <= AADLEN < 32.
+
+	jz		.Laad_done
+	cmp		$16, AADLEN
+	jle		.Laad_lastblock
+
+.Laad_last2blocks:
+	// Update GHASH with the remaining 17 <= AADLEN <= 31 bytes of AAD.
+	mov		AADLEN, AADLEN	// Zero-extend AADLEN to AADLEN64.
+	vmovdqu		(AAD), TMP0_XMM
+	vmovdqu		-16(AAD, AADLEN64), TMP1_XMM
+	vpshufb		BSWAP_MASK_XMM, TMP0_XMM, TMP0_XMM
+	vpxor		TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
+	lea		.Lrshift_and_bswap_table(%rip), %rax
+	vpshufb		-16(%rax, AADLEN64), TMP1_XMM, TMP1_XMM
+	vinserti128	$1, TMP1_XMM, GHASH_ACC, GHASH_ACC
+	vmovdqu		OFFSETOFEND_H_POWERS-32(KEY), TMP0
+	_ghash_mul	TMP0, GHASH_ACC, GHASH_ACC, GFPOLY, TMP1, TMP2, LO
+	vextracti128	$1, GHASH_ACC, TMP0_XMM
+	vpxor		TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
+	jmp		.Laad_done
+
+.Laad_lastblock:
+	// Update GHASH with the remaining 1 <= AADLEN <= 16 bytes of AAD.
+	_load_partial_block	AAD, TMP0_XMM, %r8, %r8d
+	vpshufb		BSWAP_MASK_XMM, TMP0_XMM, TMP0_XMM
+	vpxor		TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
+	vmovdqu		OFFSETOFEND_H_POWERS-16(KEY), TMP0_XMM
+	_ghash_mul	TMP0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM, GFPOLY_XMM, \
+			TMP1_XMM, TMP2_XMM, LO_XMM
+
+.Laad_done:
+	// Store the updated GHASH accumulator back to memory.
+	vmovdqu		GHASH_ACC_XMM, (GHASH_ACC_PTR)
+
+	vzeroupper
+	RET
+SYM_FUNC_END(aes_gcm_aad_update_vaes_avx2)
+
+// Do one non-last round of AES encryption on the blocks in the given AESDATA
+// vectors using the round key that has been broadcast to all 128-bit lanes of
+// \round_key.
+.macro	_vaesenc	round_key, vecs:vararg
+.irp i, \vecs
+	vaesenc		\round_key, AESDATA\i, AESDATA\i
+.endr
+.endm
+
+// Generate counter blocks in the given AESDATA vectors, then do the zero-th AES
+// round on them.  Clobbers TMP0.
+.macro	_ctr_begin	vecs:vararg
+	vbroadcasti128	.Linc_2blocks(%rip), TMP0
+.irp i, \vecs
+	vpshufb		BSWAP_MASK, LE_CTR, AESDATA\i
+	vpaddd		TMP0, LE_CTR, LE_CTR
+.endr
+.irp i, \vecs
+	vpxor		RNDKEY0, AESDATA\i, AESDATA\i
+.endr
+.endm
+
+// Generate and encrypt counter blocks in the given AESDATA vectors, excluding
+// the last AES round.  Clobbers TMP0.
+.macro	_aesenc_loop	vecs:vararg
+	_ctr_begin	\vecs
+	lea		16(KEY), %rax
+.Laesenc_loop\@:
+	vbroadcasti128	(%rax), TMP0
+	_vaesenc	TMP0, \vecs
+	add		$16, %rax
+	cmp		%rax, RNDKEYLAST_PTR
+	jne		.Laesenc_loop\@
+.endm
+
+// Finalize the keystream blocks in the given AESDATA vectors by doing the last
+// AES round, then XOR those keystream blocks with the corresponding data.
+// Reduce latency by doing the XOR before the vaesenclast, utilizing the
+// property vaesenclast(key, a) ^ b == vaesenclast(key ^ b, a).  Clobbers TMP0.
+.macro	_aesenclast_and_xor	vecs:vararg
+.irp i, \vecs
+	vpxor		\i*32(SRC), RNDKEYLAST, TMP0
+	vaesenclast	TMP0, AESDATA\i, AESDATA\i
+.endr
+.irp i, \vecs
+	vmovdqu		AESDATA\i, \i*32(DST)
+.endr
+.endm
+
+// void aes_gcm_{enc,dec}_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+//					   const u32 le_ctr[4], u8 ghash_acc[16],
+//					   const u8 *src, u8 *dst, int datalen);
+//
+// This macro generates a GCM encryption or decryption update function with the
+// above prototype (with \enc selecting which one).  The function computes the
+// next portion of the CTR keystream, XOR's it with |datalen| bytes from |src|,
+// and writes the resulting encrypted or decrypted data to |dst|.  It also
+// updates the GHASH accumulator |ghash_acc| using the next |datalen| ciphertext
+// bytes.
+//
+// |datalen| must be a multiple of 16, except on the last call where it can be
+// any length.  The caller must do any buffering needed to ensure this.  Both
+// in-place and out-of-place en/decryption are supported.
+//
+// |le_ctr| must give the current counter in little-endian format.  This
+// function loads the counter from |le_ctr| and increments the loaded counter as
+// needed, but it does *not* store the updated counter back to |le_ctr|.  The
+// caller must update |le_ctr| if any more data segments follow.  Internally,
+// only the low 32-bit word of the counter is incremented, following the GCM
+// standard.
+.macro	_aes_gcm_update	enc
+
+	// Function arguments
+	.set	KEY,		%rdi
+	.set	LE_CTR_PTR,	%rsi
+	.set	LE_CTR_PTR32,	%esi
+	.set	GHASH_ACC_PTR,	%rdx
+	.set	SRC,		%rcx	// Assumed to be %rcx.
+					// See .Ltail_xor_and_ghash_partial_vec
+	.set	DST,		%r8
+	.set	DATALEN,	%r9d
+	.set	DATALEN64,	%r9	// Zero-extend DATALEN before using!
+
+	// Additional local variables
+
+	// %rax is used as a temporary register.  LE_CTR_PTR is also available
+	// as a temporary register after the counter is loaded.
+
+	// AES key length in bytes
+	.set	AESKEYLEN,	%r10d
+	.set	AESKEYLEN64,	%r10
+
+	// Pointer to the last AES round key for the chosen AES variant
+	.set	RNDKEYLAST_PTR,	%r11
+
+	// BSWAP_MASK is the shuffle mask for byte-reflecting 128-bit values
+	// using vpshufb, copied to all 128-bit lanes.
+	.set	BSWAP_MASK,	%ymm0
+	.set	BSWAP_MASK_XMM,	%xmm0
+
+	// GHASH_ACC is the accumulator variable for GHASH.  When fully reduced,
+	// only the lowest 128-bit lane can be nonzero.  When not fully reduced,
+	// more than one lane may be used, and they need to be XOR'd together.
+	.set	GHASH_ACC,	%ymm1
+	.set	GHASH_ACC_XMM,	%xmm1
+
+	// TMP[0-2] are temporary registers.
+	.set	TMP0,		%ymm2
+	.set	TMP0_XMM,	%xmm2
+	.set	TMP1,		%ymm3
+	.set	TMP1_XMM,	%xmm3
+	.set	TMP2,		%ymm4
+	.set	TMP2_XMM,	%xmm4
+
+	// LO and MI are used to accumulate unreduced GHASH products.
+	.set	LO,		%ymm5
+	.set	LO_XMM,		%xmm5
+	.set	MI,		%ymm6
+	.set	MI_XMM,		%xmm6
+
+	// H_POW[2-1]_XORED contain cached values from KEY->h_powers_xored.  The
+	// descending numbering reflects the order of the key powers.
+	.set	H_POW2_XORED,	%ymm7
+	.set	H_POW2_XORED_XMM, %xmm7
+	.set	H_POW1_XORED,	%ymm8
+	.set	H_POW1_XORED_XMM, %xmm8
+
+	// RNDKEY0 caches the zero-th round key, and RNDKEYLAST the last one.
+	.set	RNDKEY0,	%ymm9
+	.set	RNDKEYLAST,	%ymm10
+
+	// LE_CTR contains the next set of little-endian counter blocks.
+	.set	LE_CTR,		%ymm11
+
+	// AESDATA[0-3] hold the counter blocks that are being encrypted by AES.
+	.set	AESDATA0,	%ymm12
+	.set	AESDATA0_XMM,	%xmm12
+	.set	AESDATA1,	%ymm13
+	.set	AESDATA1_XMM,	%xmm13
+	.set	AESDATA2,	%ymm14
+	.set	AESDATA2_XMM,	%xmm14
+	.set	AESDATA3,	%ymm15
+	.set	AESDATA3_XMM,	%xmm15
+
+.if \enc
+	.set	GHASHDATA_PTR,	DST
+.else
+	.set	GHASHDATA_PTR,	SRC
+.endif
+
+	vbroadcasti128	.Lbswap_mask(%rip), BSWAP_MASK
+
+	// Load the GHASH accumulator and the starting counter.
+	vmovdqu		(GHASH_ACC_PTR), GHASH_ACC_XMM
+	vbroadcasti128	(LE_CTR_PTR), LE_CTR
+
+	// Load the AES key length in bytes.
+	movl		OFFSETOF_AESKEYLEN(KEY), AESKEYLEN
+
+	// Make RNDKEYLAST_PTR point to the last AES round key.  This is the
+	// round key with index 10, 12, or 14 for AES-128, AES-192, or AES-256
+	// respectively.  Then load the zero-th and last round keys.
+	lea		6*16(KEY,AESKEYLEN64,4), RNDKEYLAST_PTR
+	vbroadcasti128	(KEY), RNDKEY0
+	vbroadcasti128	(RNDKEYLAST_PTR), RNDKEYLAST
+
+	// Finish initializing LE_CTR by adding 1 to the second block.
+	vpaddd		.Lctr_pattern(%rip), LE_CTR, LE_CTR
+
+	// If there are at least 128 bytes of data, then continue into the loop
+	// that processes 128 bytes of data at a time.  Otherwise skip it.
+	add		$-128, DATALEN	// 128 is 4 bytes, -128 is 1 byte
+	jl		.Lcrypt_loop_4x_done\@
+
+	vmovdqu		OFFSETOF_H_POWERS_XORED(KEY), H_POW2_XORED
+	vmovdqu		OFFSETOF_H_POWERS_XORED+32(KEY), H_POW1_XORED
+
+	// Main loop: en/decrypt and hash 4 vectors (128 bytes) at a time.
+
+.if \enc
+	// Encrypt the first 4 vectors of plaintext blocks.
+	_aesenc_loop	0,1,2,3
+	_aesenclast_and_xor	0,1,2,3
+	sub		$-128, SRC	// 128 is 4 bytes, -128 is 1 byte
+	add		$-128, DATALEN
+	jl		.Lghash_last_ciphertext_4x\@
+.endif
+
+.align 16
+.Lcrypt_loop_4x\@:
+
+	// Start the AES encryption of the counter blocks.
+	_ctr_begin	0,1,2,3
+	cmp		$24, AESKEYLEN
+	jl		128f	// AES-128?
+	je		192f	// AES-192?
+	// AES-256
+	vbroadcasti128	-13*16(RNDKEYLAST_PTR), TMP0
+	_vaesenc	TMP0, 0,1,2,3
+	vbroadcasti128	-12*16(RNDKEYLAST_PTR), TMP0
+	_vaesenc	TMP0, 0,1,2,3
+192:
+	vbroadcasti128	-11*16(RNDKEYLAST_PTR), TMP0
+	_vaesenc	TMP0, 0,1,2,3
+	vbroadcasti128	-10*16(RNDKEYLAST_PTR), TMP0
+	_vaesenc	TMP0, 0,1,2,3
+128:
+
+	// Finish the AES encryption of the counter blocks in AESDATA[0-3],
+	// interleaved with the GHASH update of the ciphertext blocks.
+.irp i, 9,8,7,6,5,4,3,2,1
+	_ghash_step_4x  (9 - \i), GHASHDATA_PTR
+	vbroadcasti128	-\i*16(RNDKEYLAST_PTR), TMP0
+	_vaesenc	TMP0, 0,1,2,3
+.endr
+	_ghash_step_4x	9, GHASHDATA_PTR
+.if \enc
+	sub		$-128, DST	// 128 is 4 bytes, -128 is 1 byte
+.endif
+	_aesenclast_and_xor	0,1,2,3
+	sub		$-128, SRC
+.if !\enc
+	sub		$-128, DST
+.endif
+	add		$-128, DATALEN
+	jge		.Lcrypt_loop_4x\@
+
+.if \enc
+.Lghash_last_ciphertext_4x\@:
+	// Update GHASH with the last set of ciphertext blocks.
+	_ghash_4x	DST
+	sub		$-128, DST
+.endif
+
+.Lcrypt_loop_4x_done\@:
+
+	// Undo the extra subtraction by 128 and check whether data remains.
+	sub		$-128, DATALEN	// 128 is 4 bytes, -128 is 1 byte
+	jz		.Ldone\@
+
+	// The data length isn't a multiple of 128 bytes.  Process the remaining
+	// data of length 1 <= DATALEN < 128.
+	//
+	// Since there are enough key powers available for all remaining data,
+	// there is no need to do a GHASH reduction after each iteration.
+	// Instead, multiply each remaining block by its own key power, and only
+	// do a GHASH reduction at the very end.
+
+	// Make POWERS_PTR point to the key powers [H^N, H^(N-1), ...] where N
+	// is the number of blocks that remain.
+	.set		POWERS_PTR, LE_CTR_PTR	// LE_CTR_PTR is free to be reused.
+	.set		POWERS_PTR32, LE_CTR_PTR32
+	mov		DATALEN, %eax
+	neg		%rax
+	and		$~15, %rax  // -round_up(DATALEN, 16)
+	lea		OFFSETOFEND_H_POWERS(KEY,%rax), POWERS_PTR
+
+	// Start collecting the unreduced GHASH intermediate value LO, MI, HI.
+	.set		HI, H_POW2_XORED	// H_POW2_XORED is free to be reused.
+	.set		HI_XMM, H_POW2_XORED_XMM
+	vpxor		LO_XMM, LO_XMM, LO_XMM
+	vpxor		MI_XMM, MI_XMM, MI_XMM
+	vpxor		HI_XMM, HI_XMM, HI_XMM
+
+	// 1 <= DATALEN < 128.  Generate 2 or 4 more vectors of keystream blocks
+	// excluding the last AES round, depending on the remaining DATALEN.
+	cmp		$64, DATALEN
+	jg		.Ltail_gen_4_keystream_vecs\@
+	_aesenc_loop	0,1
+	cmp		$32, DATALEN
+	jge		.Ltail_xor_and_ghash_full_vec_loop\@
+	jmp		.Ltail_xor_and_ghash_partial_vec\@
+.Ltail_gen_4_keystream_vecs\@:
+	_aesenc_loop	0,1,2,3
+
+	// XOR the remaining data and accumulate the unreduced GHASH products
+	// for DATALEN >= 32, starting with one full 32-byte vector at a time.
+.Ltail_xor_and_ghash_full_vec_loop\@:
+.if \enc
+	_aesenclast_and_xor	0
+	vpshufb		BSWAP_MASK, AESDATA0, AESDATA0
+.else
+	vmovdqu		(SRC), TMP1
+	vpxor		TMP1, RNDKEYLAST, TMP0
+	vaesenclast	TMP0, AESDATA0, AESDATA0
+	vmovdqu		AESDATA0, (DST)
+	vpshufb		BSWAP_MASK, TMP1, AESDATA0
+.endif
+	// The ciphertext blocks (i.e. GHASH input data) are now in AESDATA0.
+	vpxor		GHASH_ACC, AESDATA0, AESDATA0
+	vmovdqu		(POWERS_PTR), TMP2
+	_ghash_mul_noreduce	TMP2, AESDATA0, LO, MI, HI, TMP0
+	vmovdqa		AESDATA1, AESDATA0
+	vmovdqa		AESDATA2, AESDATA1
+	vmovdqa		AESDATA3, AESDATA2
+	vpxor		GHASH_ACC_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
+	add		$32, SRC
+	add		$32, DST
+	add		$32, POWERS_PTR
+	sub		$32, DATALEN
+	cmp		$32, DATALEN
+	jge		.Ltail_xor_and_ghash_full_vec_loop\@
+	test		DATALEN, DATALEN
+	jz		.Ltail_ghash_reduce\@
+
+.Ltail_xor_and_ghash_partial_vec\@:
+	// XOR the remaining data and accumulate the unreduced GHASH products,
+	// for 1 <= DATALEN < 32.
+	vaesenclast	RNDKEYLAST, AESDATA0, AESDATA0
+	cmp		$16, DATALEN
+	jle		.Ltail_xor_and_ghash_1to16bytes\@
+
+	// Handle 17 <= DATALEN < 32.
+
+	// Load a vpshufb mask that will right-shift by '32 - DATALEN' bytes
+	// (shifting in zeroes), then reflect all 16 bytes.
+	lea		.Lrshift_and_bswap_table(%rip), %rax
+	vmovdqu		-16(%rax, DATALEN64), TMP2_XMM
+
+	// Move the second keystream block to its own register and left-align it
+	vextracti128	$1, AESDATA0, AESDATA1_XMM
+	vpxor		.Lfifteens(%rip), TMP2_XMM, TMP0_XMM
+	vpshufb		TMP0_XMM, AESDATA1_XMM, AESDATA1_XMM
+
+	// Using overlapping loads and stores, XOR the source data with the
+	// keystream and write the destination data.  Then prepare the GHASH
+	// input data: the full ciphertext block and the zero-padded partial
+	// ciphertext block, both byte-reflected, in AESDATA0.
+.if \enc
+	vpxor		-16(SRC, DATALEN64), AESDATA1_XMM, AESDATA1_XMM
+	vpxor		(SRC), AESDATA0_XMM, AESDATA0_XMM
+	vmovdqu		AESDATA1_XMM, -16(DST, DATALEN64)
+	vmovdqu		AESDATA0_XMM, (DST)
+	vpshufb		TMP2_XMM, AESDATA1_XMM, AESDATA1_XMM
+	vpshufb		BSWAP_MASK_XMM, AESDATA0_XMM, AESDATA0_XMM
+.else
+	vmovdqu		-16(SRC, DATALEN64), TMP1_XMM
+	vmovdqu		(SRC), TMP0_XMM
+	vpxor		TMP1_XMM, AESDATA1_XMM, AESDATA1_XMM
+	vpxor		TMP0_XMM, AESDATA0_XMM, AESDATA0_XMM
+	vmovdqu		AESDATA1_XMM, -16(DST, DATALEN64)
+	vmovdqu		AESDATA0_XMM, (DST)
+	vpshufb		TMP2_XMM, TMP1_XMM, AESDATA1_XMM
+	vpshufb		BSWAP_MASK_XMM, TMP0_XMM, AESDATA0_XMM
+.endif
+	vpxor		GHASH_ACC_XMM, AESDATA0_XMM, AESDATA0_XMM
+	vinserti128	$1, AESDATA1_XMM, AESDATA0, AESDATA0
+	vmovdqu		(POWERS_PTR), TMP2
+	jmp		.Ltail_ghash_last_vec\@
+
+.Ltail_xor_and_ghash_1to16bytes\@:
+	// Handle 1 <= DATALEN <= 16.  Carefully load and store the
+	// possibly-partial block, which we mustn't access out of bounds.
+	vmovdqu		(POWERS_PTR), TMP2_XMM
+	mov		SRC, KEY	// Free up %rcx, assuming SRC == %rcx
+	mov		DATALEN, %ecx
+	_load_partial_block	KEY, TMP0_XMM, POWERS_PTR, POWERS_PTR32
+	vpxor		TMP0_XMM, AESDATA0_XMM, AESDATA0_XMM
+	mov		DATALEN, %ecx
+	_store_partial_block	AESDATA0_XMM, DST, POWERS_PTR, POWERS_PTR32
+.if \enc
+	lea		.Lselect_high_bytes_table(%rip), %rax
+	vpshufb		BSWAP_MASK_XMM, AESDATA0_XMM, AESDATA0_XMM
+	vpand		(%rax, DATALEN64), AESDATA0_XMM, AESDATA0_XMM
+.else
+	vpshufb		BSWAP_MASK_XMM, TMP0_XMM, AESDATA0_XMM
+.endif
+	vpxor		GHASH_ACC_XMM, AESDATA0_XMM, AESDATA0_XMM
+
+.Ltail_ghash_last_vec\@:
+	// Accumulate the unreduced GHASH products for the last 1-2 blocks.  The
+	// GHASH input data is in AESDATA0.  If only one block remains, then the
+	// second block in AESDATA0 is zero and does not affect the result.
+	_ghash_mul_noreduce	TMP2, AESDATA0, LO, MI, HI, TMP0
+
+.Ltail_ghash_reduce\@:
+	// Finally, do the GHASH reduction.
+	vbroadcasti128	.Lgfpoly(%rip), TMP0
+	_ghash_reduce	LO, MI, HI, TMP0, TMP1
+	vextracti128	$1, HI, GHASH_ACC_XMM
+	vpxor		HI_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
+
+.Ldone\@:
+	// Store the updated GHASH accumulator back to memory.
+	vmovdqu		GHASH_ACC_XMM, (GHASH_ACC_PTR)
+
+	vzeroupper
+	RET
+.endm
+
+// void aes_gcm_enc_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+//				    const u32 le_ctr[4], u8 ghash_acc[16],
+//				    u64 total_aadlen, u64 total_datalen);
+// bool aes_gcm_dec_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+//				    const u32 le_ctr[4], const u8 ghash_acc[16],
+//				    u64 total_aadlen, u64 total_datalen,
+//				    const u8 tag[16], int taglen);
+//
+// This macro generates one of the above two functions (with \enc selecting
+// which one).  Both functions finish computing the GCM authentication tag by
+// updating GHASH with the lengths block and encrypting the GHASH accumulator.
+// |total_aadlen| and |total_datalen| must be the total length of the additional
+// authenticated data and the en/decrypted data in bytes, respectively.
+//
+// The encryption function then stores the full-length (16-byte) computed
+// authentication tag to |ghash_acc|.  The decryption function instead loads the
+// expected authentication tag (the one that was transmitted) from the 16-byte
+// buffer |tag|, compares the first 4 <= |taglen| <= 16 bytes of it to the
+// computed tag in constant time, and returns true if and only if they match.
+.macro	_aes_gcm_final	enc
+
+	// Function arguments
+	.set	KEY,		%rdi
+	.set	LE_CTR_PTR,	%rsi
+	.set	GHASH_ACC_PTR,	%rdx
+	.set	TOTAL_AADLEN,	%rcx
+	.set	TOTAL_DATALEN,	%r8
+	.set	TAG,		%r9
+	.set	TAGLEN,		%r10d	// Originally at 8(%rsp)
+	.set	TAGLEN64,	%r10
+
+	// Additional local variables.
+	// %rax and %xmm0-%xmm3 are used as temporary registers.
+	.set	AESKEYLEN,	%r11d
+	.set	AESKEYLEN64,	%r11
+	.set	GFPOLY,		%xmm4
+	.set	BSWAP_MASK,	%xmm5
+	.set	LE_CTR,		%xmm6
+	.set	GHASH_ACC,	%xmm7
+	.set	H_POW1,		%xmm8
+
+	// Load some constants.
+	vmovdqa		.Lgfpoly(%rip), GFPOLY
+	vmovdqa		.Lbswap_mask(%rip), BSWAP_MASK
+
+	// Load the AES key length in bytes.
+	movl		OFFSETOF_AESKEYLEN(KEY), AESKEYLEN
+
+	// Set up a counter block with 1 in the low 32-bit word.  This is the
+	// counter that produces the ciphertext needed to encrypt the auth tag.
+	// GFPOLY has 1 in the low word, so grab the 1 from there using a blend.
+	vpblendd	$0xe, (LE_CTR_PTR), GFPOLY, LE_CTR
+
+	// Build the lengths block and XOR it with the GHASH accumulator.
+	// Although the lengths block is defined as the AAD length followed by
+	// the en/decrypted data length, both in big-endian byte order, a byte
+	// reflection of the full block is needed because of the way we compute
+	// GHASH (see _ghash_mul_step).  By using little-endian values in the
+	// opposite order, we avoid having to reflect any bytes here.
+	vmovq		TOTAL_DATALEN, %xmm0
+	vpinsrq		$1, TOTAL_AADLEN, %xmm0, %xmm0
+	vpsllq		$3, %xmm0, %xmm0	// Bytes to bits
+	vpxor		(GHASH_ACC_PTR), %xmm0, GHASH_ACC
+
+	// Load the first hash key power (H^1), which is stored last.
+	vmovdqu		OFFSETOFEND_H_POWERS-16(KEY), H_POW1
+
+	// Load TAGLEN if decrypting.
+.if !\enc
+	movl		8(%rsp), TAGLEN
+.endif
+
+	// Make %rax point to the last AES round key for the chosen AES variant.
+	lea		6*16(KEY,AESKEYLEN64,4), %rax
+
+	// Start the AES encryption of the counter block by swapping the counter
+	// block to big-endian and XOR-ing it with the zero-th AES round key.
+	vpshufb		BSWAP_MASK, LE_CTR, %xmm0
+	vpxor		(KEY), %xmm0, %xmm0
+
+	// Complete the AES encryption and multiply GHASH_ACC by H^1.
+	// Interleave the AES and GHASH instructions to improve performance.
+	cmp		$24, AESKEYLEN
+	jl		128f	// AES-128?
+	je		192f	// AES-192?
+	// AES-256
+	vaesenc		-13*16(%rax), %xmm0, %xmm0
+	vaesenc		-12*16(%rax), %xmm0, %xmm0
+192:
+	vaesenc		-11*16(%rax), %xmm0, %xmm0
+	vaesenc		-10*16(%rax), %xmm0, %xmm0
+128:
+.irp i, 0,1,2,3,4,5,6,7,8
+	_ghash_mul_step	\i, H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \
+			%xmm1, %xmm2, %xmm3
+	vaesenc		(\i-9)*16(%rax), %xmm0, %xmm0
+.endr
+	_ghash_mul_step	9, H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \
+			%xmm1, %xmm2, %xmm3
+
+	// Undo the byte reflection of the GHASH accumulator.
+	vpshufb		BSWAP_MASK, GHASH_ACC, GHASH_ACC
+
+	// Do the last AES round and XOR the resulting keystream block with the
+	// GHASH accumulator to produce the full computed authentication tag.
+	//
+	// Reduce latency by taking advantage of the property vaesenclast(key,
+	// a) ^ b == vaesenclast(key ^ b, a).  I.e., XOR GHASH_ACC into the last
+	// round key, instead of XOR'ing the final AES output with GHASH_ACC.
+	//
+	// enc_final then returns the computed auth tag, while dec_final
+	// compares it with the transmitted one and returns a bool.  To compare
+	// the tags, dec_final XORs them together and uses vptest to check
+	// whether the result is all-zeroes.  This should be constant-time.
+	// dec_final applies the vaesenclast optimization to this additional
+	// value XOR'd too.
+.if \enc
+	vpxor		(%rax), GHASH_ACC, %xmm1
+	vaesenclast	%xmm1, %xmm0, GHASH_ACC
+	vmovdqu		GHASH_ACC, (GHASH_ACC_PTR)
+.else
+	vpxor		(TAG), GHASH_ACC, GHASH_ACC
+	vpxor		(%rax), GHASH_ACC, GHASH_ACC
+	vaesenclast	GHASH_ACC, %xmm0, %xmm0
+	lea		.Lselect_high_bytes_table(%rip), %rax
+	vmovdqu		(%rax, TAGLEN64), %xmm1
+	vpshufb		BSWAP_MASK, %xmm1, %xmm1 // select low bytes, not high
+	vptest		%xmm1, %xmm0
+	sete		%al
+.endif
+	// No need for vzeroupper here, since only used xmm registers were used.
+	RET
+.endm
+
+SYM_FUNC_START(aes_gcm_enc_update_vaes_avx2)
+	_aes_gcm_update	1
+SYM_FUNC_END(aes_gcm_enc_update_vaes_avx2)
+SYM_FUNC_START(aes_gcm_dec_update_vaes_avx2)
+	_aes_gcm_update	0
+SYM_FUNC_END(aes_gcm_dec_update_vaes_avx2)
+
+SYM_FUNC_START(aes_gcm_enc_final_vaes_avx2)
+	_aes_gcm_final	1
+SYM_FUNC_END(aes_gcm_enc_final_vaes_avx2)
+SYM_FUNC_START(aes_gcm_dec_final_vaes_avx2)
+	_aes_gcm_final	0
+SYM_FUNC_END(aes_gcm_dec_final_vaes_avx2)
diff --git a/arch/x86/crypto/aes-gcm-avx10-x86_64.S b/arch/x86/crypto/aes-gcm-vaes-avx512.S
similarity index 69%
rename from arch/x86/crypto/aes-gcm-avx10-x86_64.S
rename to arch/x86/crypto/aes-gcm-vaes-avx512.S
index 02ee11083d4f..06b71314d65c 100644
--- a/arch/x86/crypto/aes-gcm-avx10-x86_64.S
+++ b/arch/x86/crypto/aes-gcm-vaes-avx512.S
@@ -1,6 +1,7 @@
 /* SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause */
 //
-// VAES and VPCLMULQDQ optimized AES-GCM for x86_64
+// AES-GCM implementation for x86_64 CPUs that support the following CPU
+// features: VAES && VPCLMULQDQ && AVX512BW && AVX512VL && BMI2
 //
 // Copyright 2024 Google LLC
 //
@@ -45,41 +46,6 @@
 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 // POSSIBILITY OF SUCH DAMAGE.
-//
-//------------------------------------------------------------------------------
-//
-// This file implements AES-GCM (Galois/Counter Mode) for x86_64 CPUs that
-// support VAES (vector AES), VPCLMULQDQ (vector carryless multiplication), and
-// either AVX512 or AVX10.  Some of the functions, notably the encryption and
-// decryption update functions which are the most performance-critical, are
-// provided in two variants generated from a macro: one using 256-bit vectors
-// (suffix: vaes_avx10_256) and one using 512-bit vectors (vaes_avx10_512).  The
-// other, "shared" functions (vaes_avx10) use at most 256-bit vectors.
-//
-// The functions that use 512-bit vectors are intended for CPUs that support
-// 512-bit vectors *and* where using them doesn't cause significant
-// downclocking.  They require the following CPU features:
-//
-//	VAES && VPCLMULQDQ && BMI2 && ((AVX512BW && AVX512VL) || AVX10/512)
-//
-// The other functions require the following CPU features:
-//
-//	VAES && VPCLMULQDQ && BMI2 && ((AVX512BW && AVX512VL) || AVX10/256)
-//
-// All functions use the "System V" ABI.  The Windows ABI is not supported.
-//
-// Note that we use "avx10" in the names of the functions as a shorthand to
-// really mean "AVX10 or a certain set of AVX512 features".  Due to Intel's
-// introduction of AVX512 and then its replacement by AVX10, there doesn't seem
-// to be a simple way to name things that makes sense on all CPUs.
-//
-// Note that the macros that support both 256-bit and 512-bit vectors could
-// fairly easily be changed to support 128-bit too.  However, this would *not*
-// be sufficient to allow the code to run on CPUs without AVX512 or AVX10,
-// because the code heavily uses several features of these extensions other than
-// the vector length: the increase in the number of SIMD registers from 16 to
-// 32, masking support, and new instructions such as vpternlogd (which can do a
-// three-argument XOR).  These features are very useful for AES-GCM.
 
 #include <linux/linkage.h>
 
@@ -104,16 +70,14 @@
 .Lgfpoly_and_internal_carrybit:
 	.octa	0xc2000000000000010000000000000001
 
-	// The below constants are used for incrementing the counter blocks.
-	// ctr_pattern points to the four 128-bit values [0, 1, 2, 3].
-	// inc_2blocks and inc_4blocks point to the single 128-bit values 2 and
-	// 4.  Note that the same '2' is reused in ctr_pattern and inc_2blocks.
+	// Values needed to prepare the initial vector of counter blocks.
 .Lctr_pattern:
 	.octa	0
 	.octa	1
-.Linc_2blocks:
 	.octa	2
 	.octa	3
+
+	// The number of AES blocks per vector, as a 128-bit value.
 .Linc_4blocks:
 	.octa	4
 
@@ -130,29 +94,13 @@
 // Offset to end of hash key powers array in the key struct.
 //
 // This is immediately followed by three zeroized padding blocks, which are
-// included so that partial vectors can be handled more easily.  E.g. if VL=64
-// and two blocks remain, we load the 4 values [H^2, H^1, 0, 0].  The most
-// padding blocks needed is 3, which occurs if [H^1, 0, 0, 0] is loaded.
+// included so that partial vectors can be handled more easily.  E.g. if two
+// blocks remain, we load the 4 values [H^2, H^1, 0, 0].  The most padding
+// blocks needed is 3, which occurs if [H^1, 0, 0, 0] is loaded.
 #define OFFSETOFEND_H_POWERS	(OFFSETOF_H_POWERS + (NUM_H_POWERS * 16))
 
 .text
 
-// Set the vector length in bytes.  This sets the VL variable and defines
-// register aliases V0-V31 that map to the ymm or zmm registers.
-.macro	_set_veclen	vl
-	.set	VL,	\vl
-.irp i, 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15, \
-	16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31
-.if VL == 32
-	.set	V\i,	%ymm\i
-.elseif VL == 64
-	.set	V\i,	%zmm\i
-.else
-	.error "Unsupported vector length"
-.endif
-.endr
-.endm
-
 // The _ghash_mul_step macro does one step of GHASH multiplication of the
 // 128-bit lanes of \a by the corresponding 128-bit lanes of \b and storing the
 // reduced products in \dst.  \t0, \t1, and \t2 are temporary registers of the
@@ -312,39 +260,44 @@
 	vpternlogd	$0x96, \t0, \mi, \hi
 .endm
 
-// void aes_gcm_precompute_##suffix(struct aes_gcm_key_avx10 *key);
-//
-// Given the expanded AES key |key->aes_key|, this function derives the GHASH
-// subkey and initializes |key->ghash_key_powers| with powers of it.
-//
-// The number of key powers initialized is NUM_H_POWERS, and they are stored in
-// the order H^NUM_H_POWERS to H^1.  The zeroized padding blocks after the key
-// powers themselves are also initialized.
+// This is a specialized version of _ghash_mul that computes \a * \a, i.e. it
+// squares \a.  It skips computing MI = (a_L * a_H) + (a_H * a_L) = 0.
+.macro	_ghash_square	a, dst, gfpoly, t0, t1
+	vpclmulqdq	$0x00, \a, \a, \t0	  // LO = a_L * a_L
+	vpclmulqdq	$0x11, \a, \a, \dst	  // HI = a_H * a_H
+	vpclmulqdq	$0x01, \t0, \gfpoly, \t1  // LO_L*(x^63 + x^62 + x^57)
+	vpshufd		$0x4e, \t0, \t0		  // Swap halves of LO
+	vpxord		\t0, \t1, \t1		  // Fold LO into MI
+	vpclmulqdq	$0x01, \t1, \gfpoly, \t0  // MI_L*(x^63 + x^62 + x^57)
+	vpshufd		$0x4e, \t1, \t1		  // Swap halves of MI
+	vpternlogd	$0x96, \t0, \t1, \dst	  // Fold MI into HI
+.endm
+
+// void aes_gcm_precompute_vaes_avx512(struct aes_gcm_key_vaes_avx512 *key);
 //
-// This macro supports both VL=32 and VL=64.  _set_veclen must have been invoked
-// with the desired length.  In the VL=32 case, the function computes twice as
-// many key powers than are actually used by the VL=32 GCM update functions.
-// This is done to keep the key format the same regardless of vector length.
-.macro	_aes_gcm_precompute
+// Given the expanded AES key |key->base.aes_key|, derive the GHASH subkey and
+// initialize |key->h_powers| and |key->padding|.
+SYM_FUNC_START(aes_gcm_precompute_vaes_avx512)
 
 	// Function arguments
 	.set	KEY,		%rdi
 
-	// Additional local variables.  V0-V2 and %rax are used as temporaries.
+	// Additional local variables.
+	// %zmm[0-2] and %rax are used as temporaries.
 	.set	POWERS_PTR,	%rsi
 	.set	RNDKEYLAST_PTR,	%rdx
-	.set	H_CUR,		V3
+	.set	H_CUR,		%zmm3
 	.set	H_CUR_YMM,	%ymm3
 	.set	H_CUR_XMM,	%xmm3
-	.set	H_INC,		V4
+	.set	H_INC,		%zmm4
 	.set	H_INC_YMM,	%ymm4
 	.set	H_INC_XMM,	%xmm4
-	.set	GFPOLY,		V5
+	.set	GFPOLY,		%zmm5
 	.set	GFPOLY_YMM,	%ymm5
 	.set	GFPOLY_XMM,	%xmm5
 
 	// Get pointer to lowest set of key powers (located at end of array).
-	lea		OFFSETOFEND_H_POWERS-VL(KEY), POWERS_PTR
+	lea		OFFSETOFEND_H_POWERS-64(KEY), POWERS_PTR
 
 	// Encrypt an all-zeroes block to get the raw hash subkey.
 	movl		OFFSETOF_AESKEYLEN(KEY), %eax
@@ -363,8 +316,8 @@
 
 	// Zeroize the padding blocks.
 	vpxor		%xmm0, %xmm0, %xmm0
-	vmovdqu		%ymm0, VL(POWERS_PTR)
-	vmovdqu		%xmm0, VL+2*16(POWERS_PTR)
+	vmovdqu		%ymm0, 64(POWERS_PTR)
+	vmovdqu		%xmm0, 64+2*16(POWERS_PTR)
 
 	// Finish preprocessing the first key power, H^1.  Since this GHASH
 	// implementation operates directly on values with the backwards bit
@@ -397,54 +350,44 @@
 	// special needs to be done to make this happen, though: H^1 * H^1 would
 	// end up with two factors of x^-1, but the multiplication consumes one.
 	// So the product H^2 ends up with the desired one factor of x^-1.
-	_ghash_mul	H_CUR_XMM, H_CUR_XMM, H_INC_XMM, GFPOLY_XMM, \
-			%xmm0, %xmm1, %xmm2
+	_ghash_square	H_CUR_XMM, H_INC_XMM, GFPOLY_XMM, %xmm0, %xmm1
 
 	// Create H_CUR_YMM = [H^2, H^1] and H_INC_YMM = [H^2, H^2].
 	vinserti128	$1, H_CUR_XMM, H_INC_YMM, H_CUR_YMM
 	vinserti128	$1, H_INC_XMM, H_INC_YMM, H_INC_YMM
 
-.if VL == 64
 	// Create H_CUR = [H^4, H^3, H^2, H^1] and H_INC = [H^4, H^4, H^4, H^4].
 	_ghash_mul	H_INC_YMM, H_CUR_YMM, H_INC_YMM, GFPOLY_YMM, \
 			%ymm0, %ymm1, %ymm2
 	vinserti64x4	$1, H_CUR_YMM, H_INC, H_CUR
 	vshufi64x2	$0, H_INC, H_INC, H_INC
-.endif
 
 	// Store the lowest set of key powers.
 	vmovdqu8	H_CUR, (POWERS_PTR)
 
-	// Compute and store the remaining key powers.  With VL=32, repeatedly
-	// multiply [H^(i+1), H^i] by [H^2, H^2] to get [H^(i+3), H^(i+2)].
-	// With VL=64, repeatedly multiply [H^(i+3), H^(i+2), H^(i+1), H^i] by
+	// Compute and store the remaining key powers.
+	// Repeatedly multiply [H^(i+3), H^(i+2), H^(i+1), H^i] by
 	// [H^4, H^4, H^4, H^4] to get [H^(i+7), H^(i+6), H^(i+5), H^(i+4)].
-	mov		$(NUM_H_POWERS*16/VL) - 1, %eax
-.Lprecompute_next\@:
-	sub		$VL, POWERS_PTR
-	_ghash_mul	H_INC, H_CUR, H_CUR, GFPOLY, V0, V1, V2
+	mov		$3, %eax
+.Lprecompute_next:
+	sub		$64, POWERS_PTR
+	_ghash_mul	H_INC, H_CUR, H_CUR, GFPOLY, %zmm0, %zmm1, %zmm2
 	vmovdqu8	H_CUR, (POWERS_PTR)
 	dec		%eax
-	jnz		.Lprecompute_next\@
+	jnz		.Lprecompute_next
 
 	vzeroupper	// This is needed after using ymm or zmm registers.
 	RET
-.endm
+SYM_FUNC_END(aes_gcm_precompute_vaes_avx512)
 
 // XOR together the 128-bit lanes of \src (whose low lane is \src_xmm) and store
 // the result in \dst_xmm.  This implicitly zeroizes the other lanes of dst.
 .macro	_horizontal_xor	src, src_xmm, dst_xmm, t0_xmm, t1_xmm, t2_xmm
 	vextracti32x4	$1, \src, \t0_xmm
-.if VL == 32
-	vpxord		\t0_xmm, \src_xmm, \dst_xmm
-.elseif VL == 64
 	vextracti32x4	$2, \src, \t1_xmm
 	vextracti32x4	$3, \src, \t2_xmm
 	vpxord		\t0_xmm, \src_xmm, \dst_xmm
 	vpternlogd	$0x96, \t1_xmm, \t2_xmm, \dst_xmm
-.else
-	.error "Unsupported vector length"
-.endif
 .endm
 
 // Do one step of the GHASH update of the data blocks given in the vector
@@ -458,25 +401,21 @@
 //
 // The GHASH update does: GHASH_ACC = H_POW4*(GHASHDATA0 + GHASH_ACC) +
 // H_POW3*GHASHDATA1 + H_POW2*GHASHDATA2 + H_POW1*GHASHDATA3, where the
-// operations are vectorized operations on vectors of 16-byte blocks.  E.g.,
-// with VL=32 there are 2 blocks per vector and the vectorized terms correspond
-// to the following non-vectorized terms:
-//
-//	H_POW4*(GHASHDATA0 + GHASH_ACC) => H^8*(blk0 + GHASH_ACC_XMM) and H^7*(blk1 + 0)
-//	H_POW3*GHASHDATA1 => H^6*blk2 and H^5*blk3
-//	H_POW2*GHASHDATA2 => H^4*blk4 and H^3*blk5
-//	H_POW1*GHASHDATA3 => H^2*blk6 and H^1*blk7
+// operations are vectorized operations on 512-bit vectors of 128-bit blocks.
+// The vectorized terms correspond to the following non-vectorized terms:
 //
-// With VL=64, we use 4 blocks/vector, H^16 through H^1, and blk0 through blk15.
+//       H_POW4*(GHASHDATA0 + GHASH_ACC) => H^16*(blk0 + GHASH_ACC_XMM),
+//              H^15*(blk1 + 0), H^14*(blk2 + 0), and H^13*(blk3 + 0)
+//       H_POW3*GHASHDATA1 => H^12*blk4, H^11*blk5, H^10*blk6, and H^9*blk7
+//       H_POW2*GHASHDATA2 => H^8*blk8,  H^7*blk9,  H^6*blk10, and H^5*blk11
+//       H_POW1*GHASHDATA3 => H^4*blk12, H^3*blk13, H^2*blk14, and H^1*blk15
 //
 // More concretely, this code does:
 //   - Do vectorized "schoolbook" multiplications to compute the intermediate
 //     256-bit product of each block and its corresponding hash key power.
-//     There are 4*VL/16 of these intermediate products.
-//   - Sum (XOR) the intermediate 256-bit products across vectors.  This leaves
-//     VL/16 256-bit intermediate values.
+//   - Sum (XOR) the intermediate 256-bit products across vectors.
 //   - Do a vectorized reduction of these 256-bit intermediate values to
-//     128-bits each.  This leaves VL/16 128-bit intermediate values.
+//     128-bits each.
 //   - Sum (XOR) these values and store the 128-bit result in GHASH_ACC_XMM.
 //
 // See _ghash_mul_step for the full explanation of the operations performed for
@@ -532,85 +471,224 @@
 .endif
 .endm
 
-// Do one non-last round of AES encryption on the counter blocks in V0-V3 using
-// the round key that has been broadcast to all 128-bit lanes of \round_key.
+// Update GHASH with four vectors of data blocks.  See _ghash_step_4x for full
+// explanation.
+.macro	_ghash_4x
+.irp i, 0,1,2,3,4,5,6,7,8,9
+	_ghash_step_4x	\i
+.endr
+.endm
+
+// void aes_gcm_aad_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+//				       u8 ghash_acc[16],
+//				       const u8 *aad, int aadlen);
+//
+// This function processes the AAD (Additional Authenticated Data) in GCM.
+// Using the key |key|, it updates the GHASH accumulator |ghash_acc| with the
+// data given by |aad| and |aadlen|.  On the first call, |ghash_acc| must be all
+// zeroes.  |aadlen| must be a multiple of 16, except on the last call where it
+// can be any length.  The caller must do any buffering needed to ensure this.
+//
+// This handles large amounts of AAD efficiently, while also keeping overhead
+// low for small amounts which is the common case.  TLS and IPsec use less than
+// one block of AAD, but (uncommonly) other use cases may use much more.
+SYM_FUNC_START(aes_gcm_aad_update_vaes_avx512)
+
+	// Function arguments
+	.set	KEY,		%rdi
+	.set	GHASH_ACC_PTR,	%rsi
+	.set	AAD,		%rdx
+	.set	AADLEN,		%ecx
+	.set	AADLEN64,	%rcx	// Zero-extend AADLEN before using!
+
+	// Additional local variables.
+	// %rax and %k1 are used as temporary registers.
+	.set	GHASHDATA0,	%zmm0
+	.set	GHASHDATA0_XMM,	%xmm0
+	.set	GHASHDATA1,	%zmm1
+	.set	GHASHDATA1_XMM,	%xmm1
+	.set	GHASHDATA2,	%zmm2
+	.set	GHASHDATA2_XMM,	%xmm2
+	.set	GHASHDATA3,	%zmm3
+	.set	BSWAP_MASK,	%zmm4
+	.set	BSWAP_MASK_XMM,	%xmm4
+	.set	GHASH_ACC,	%zmm5
+	.set	GHASH_ACC_XMM,	%xmm5
+	.set	H_POW4,		%zmm6
+	.set	H_POW3,		%zmm7
+	.set	H_POW2,		%zmm8
+	.set	H_POW1,		%zmm9
+	.set	H_POW1_XMM,	%xmm9
+	.set	GFPOLY,		%zmm10
+	.set	GFPOLY_XMM,	%xmm10
+	.set	GHASHTMP0,	%zmm11
+	.set	GHASHTMP1,	%zmm12
+	.set	GHASHTMP2,	%zmm13
+
+	// Load the GHASH accumulator.
+	vmovdqu		(GHASH_ACC_PTR), GHASH_ACC_XMM
+
+	// Check for the common case of AADLEN <= 16, as well as AADLEN == 0.
+	cmp		$16, AADLEN
+	jg		.Laad_more_than_16bytes
+	test		AADLEN, AADLEN
+	jz		.Laad_done
+
+	// Fast path: update GHASH with 1 <= AADLEN <= 16 bytes of AAD.
+	vmovdqu		.Lbswap_mask(%rip), BSWAP_MASK_XMM
+	vmovdqu		.Lgfpoly(%rip), GFPOLY_XMM
+	mov		$-1, %eax
+	bzhi		AADLEN, %eax, %eax
+	kmovd		%eax, %k1
+	vmovdqu8	(AAD), GHASHDATA0_XMM{%k1}{z}
+	vmovdqu		OFFSETOFEND_H_POWERS-16(KEY), H_POW1_XMM
+	vpshufb		BSWAP_MASK_XMM, GHASHDATA0_XMM, GHASHDATA0_XMM
+	vpxor		GHASHDATA0_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
+	_ghash_mul	H_POW1_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM, GFPOLY_XMM, \
+			GHASHDATA0_XMM, GHASHDATA1_XMM, GHASHDATA2_XMM
+	jmp		.Laad_done
+
+.Laad_more_than_16bytes:
+	vbroadcasti32x4	.Lbswap_mask(%rip), BSWAP_MASK
+	vbroadcasti32x4	.Lgfpoly(%rip), GFPOLY
+
+	// If AADLEN >= 256, update GHASH with 256 bytes of AAD at a time.
+	sub		$256, AADLEN
+	jl		.Laad_loop_4x_done
+	vmovdqu8	OFFSETOFEND_H_POWERS-4*64(KEY), H_POW4
+	vmovdqu8	OFFSETOFEND_H_POWERS-3*64(KEY), H_POW3
+	vmovdqu8	OFFSETOFEND_H_POWERS-2*64(KEY), H_POW2
+	vmovdqu8	OFFSETOFEND_H_POWERS-1*64(KEY), H_POW1
+.Laad_loop_4x:
+	vmovdqu8	0*64(AAD), GHASHDATA0
+	vmovdqu8	1*64(AAD), GHASHDATA1
+	vmovdqu8	2*64(AAD), GHASHDATA2
+	vmovdqu8	3*64(AAD), GHASHDATA3
+	_ghash_4x
+	add		$256, AAD
+	sub		$256, AADLEN
+	jge		.Laad_loop_4x
+.Laad_loop_4x_done:
+
+	// If AADLEN >= 64, update GHASH with 64 bytes of AAD at a time.
+	add		$192, AADLEN
+	jl		.Laad_loop_1x_done
+	vmovdqu8	OFFSETOFEND_H_POWERS-1*64(KEY), H_POW1
+.Laad_loop_1x:
+	vmovdqu8	(AAD), GHASHDATA0
+	vpshufb		BSWAP_MASK, GHASHDATA0, GHASHDATA0
+	vpxord		GHASHDATA0, GHASH_ACC, GHASH_ACC
+	_ghash_mul	H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \
+			GHASHDATA0, GHASHDATA1, GHASHDATA2
+	_horizontal_xor GHASH_ACC, GHASH_ACC_XMM, GHASH_ACC_XMM, \
+			GHASHDATA0_XMM, GHASHDATA1_XMM, GHASHDATA2_XMM
+	add		$64, AAD
+	sub		$64, AADLEN
+	jge		.Laad_loop_1x
+.Laad_loop_1x_done:
+
+	// Update GHASH with the remaining 0 <= AADLEN < 64 bytes of AAD.
+	add		$64, AADLEN
+	jz		.Laad_done
+	mov		$-1, %rax
+	bzhi		AADLEN64, %rax, %rax
+	kmovq		%rax, %k1
+	vmovdqu8	(AAD), GHASHDATA0{%k1}{z}
+	neg		AADLEN64
+	and		$~15, AADLEN64  // -round_up(AADLEN, 16)
+	vmovdqu8	OFFSETOFEND_H_POWERS(KEY,AADLEN64), H_POW1
+	vpshufb		BSWAP_MASK, GHASHDATA0, GHASHDATA0
+	vpxord		GHASHDATA0, GHASH_ACC, GHASH_ACC
+	_ghash_mul	H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \
+			GHASHDATA0, GHASHDATA1, GHASHDATA2
+	_horizontal_xor GHASH_ACC, GHASH_ACC_XMM, GHASH_ACC_XMM, \
+			GHASHDATA0_XMM, GHASHDATA1_XMM, GHASHDATA2_XMM
+
+.Laad_done:
+	// Store the updated GHASH accumulator back to memory.
+	vmovdqu		GHASH_ACC_XMM, (GHASH_ACC_PTR)
+
+	vzeroupper	// This is needed after using ymm or zmm registers.
+	RET
+SYM_FUNC_END(aes_gcm_aad_update_vaes_avx512)
+
+// Do one non-last round of AES encryption on the blocks in %zmm[0-3] using the
+// round key that has been broadcast to all 128-bit lanes of \round_key.
 .macro	_vaesenc_4x	round_key
-	vaesenc		\round_key, V0, V0
-	vaesenc		\round_key, V1, V1
-	vaesenc		\round_key, V2, V2
-	vaesenc		\round_key, V3, V3
+	vaesenc		\round_key, %zmm0, %zmm0
+	vaesenc		\round_key, %zmm1, %zmm1
+	vaesenc		\round_key, %zmm2, %zmm2
+	vaesenc		\round_key, %zmm3, %zmm3
 .endm
 
 // Start the AES encryption of four vectors of counter blocks.
 .macro	_ctr_begin_4x
 
 	// Increment LE_CTR four times to generate four vectors of little-endian
-	// counter blocks, swap each to big-endian, and store them in V0-V3.
-	vpshufb		BSWAP_MASK, LE_CTR, V0
+	// counter blocks, swap each to big-endian, and store them in %zmm[0-3].
+	vpshufb		BSWAP_MASK, LE_CTR, %zmm0
 	vpaddd		LE_CTR_INC, LE_CTR, LE_CTR
-	vpshufb		BSWAP_MASK, LE_CTR, V1
+	vpshufb		BSWAP_MASK, LE_CTR, %zmm1
 	vpaddd		LE_CTR_INC, LE_CTR, LE_CTR
-	vpshufb		BSWAP_MASK, LE_CTR, V2
+	vpshufb		BSWAP_MASK, LE_CTR, %zmm2
 	vpaddd		LE_CTR_INC, LE_CTR, LE_CTR
-	vpshufb		BSWAP_MASK, LE_CTR, V3
+	vpshufb		BSWAP_MASK, LE_CTR, %zmm3
 	vpaddd		LE_CTR_INC, LE_CTR, LE_CTR
 
 	// AES "round zero": XOR in the zero-th round key.
-	vpxord		RNDKEY0, V0, V0
-	vpxord		RNDKEY0, V1, V1
-	vpxord		RNDKEY0, V2, V2
-	vpxord		RNDKEY0, V3, V3
+	vpxord		RNDKEY0, %zmm0, %zmm0
+	vpxord		RNDKEY0, %zmm1, %zmm1
+	vpxord		RNDKEY0, %zmm2, %zmm2
+	vpxord		RNDKEY0, %zmm3, %zmm3
 .endm
 
-// Do the last AES round for four vectors of counter blocks V0-V3, XOR source
-// data with the resulting keystream, and write the result to DST and
+// Do the last AES round for four vectors of counter blocks %zmm[0-3], XOR
+// source data with the resulting keystream, and write the result to DST and
 // GHASHDATA[0-3].  (Implementation differs slightly, but has the same effect.)
 .macro	_aesenclast_and_xor_4x
 	// XOR the source data with the last round key, saving the result in
 	// GHASHDATA[0-3].  This reduces latency by taking advantage of the
 	// property vaesenclast(key, a) ^ b == vaesenclast(key ^ b, a).
-	vpxord		0*VL(SRC), RNDKEYLAST, GHASHDATA0
-	vpxord		1*VL(SRC), RNDKEYLAST, GHASHDATA1
-	vpxord		2*VL(SRC), RNDKEYLAST, GHASHDATA2
-	vpxord		3*VL(SRC), RNDKEYLAST, GHASHDATA3
+	vpxord		0*64(SRC), RNDKEYLAST, GHASHDATA0
+	vpxord		1*64(SRC), RNDKEYLAST, GHASHDATA1
+	vpxord		2*64(SRC), RNDKEYLAST, GHASHDATA2
+	vpxord		3*64(SRC), RNDKEYLAST, GHASHDATA3
 
 	// Do the last AES round.  This handles the XOR with the source data
 	// too, as per the optimization described above.
-	vaesenclast	GHASHDATA0, V0, GHASHDATA0
-	vaesenclast	GHASHDATA1, V1, GHASHDATA1
-	vaesenclast	GHASHDATA2, V2, GHASHDATA2
-	vaesenclast	GHASHDATA3, V3, GHASHDATA3
+	vaesenclast	GHASHDATA0, %zmm0, GHASHDATA0
+	vaesenclast	GHASHDATA1, %zmm1, GHASHDATA1
+	vaesenclast	GHASHDATA2, %zmm2, GHASHDATA2
+	vaesenclast	GHASHDATA3, %zmm3, GHASHDATA3
 
 	// Store the en/decrypted data to DST.
-	vmovdqu8	GHASHDATA0, 0*VL(DST)
-	vmovdqu8	GHASHDATA1, 1*VL(DST)
-	vmovdqu8	GHASHDATA2, 2*VL(DST)
-	vmovdqu8	GHASHDATA3, 3*VL(DST)
+	vmovdqu8	GHASHDATA0, 0*64(DST)
+	vmovdqu8	GHASHDATA1, 1*64(DST)
+	vmovdqu8	GHASHDATA2, 2*64(DST)
+	vmovdqu8	GHASHDATA3, 3*64(DST)
 .endm
 
-// void aes_gcm_{enc,dec}_update_##suffix(const struct aes_gcm_key_avx10 *key,
-//					  const u32 le_ctr[4], u8 ghash_acc[16],
-//					  const u8 *src, u8 *dst, int datalen);
+// void aes_gcm_{enc,dec}_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+//					     const u32 le_ctr[4], u8 ghash_acc[16],
+//					     const u8 *src, u8 *dst, int datalen);
 //
 // This macro generates a GCM encryption or decryption update function with the
-// above prototype (with \enc selecting which one).  This macro supports both
-// VL=32 and VL=64.  _set_veclen must have been invoked with the desired length.
-//
-// This function computes the next portion of the CTR keystream, XOR's it with
-// |datalen| bytes from |src|, and writes the resulting encrypted or decrypted
-// data to |dst|.  It also updates the GHASH accumulator |ghash_acc| using the
-// next |datalen| ciphertext bytes.
+// above prototype (with \enc selecting which one).  The function computes the
+// next portion of the CTR keystream, XOR's it with |datalen| bytes from |src|,
+// and writes the resulting encrypted or decrypted data to |dst|.  It also
+// updates the GHASH accumulator |ghash_acc| using the next |datalen| ciphertext
+// bytes.
 //
 // |datalen| must be a multiple of 16, except on the last call where it can be
 // any length.  The caller must do any buffering needed to ensure this.  Both
 // in-place and out-of-place en/decryption are supported.
 //
-// |le_ctr| must give the current counter in little-endian format.  For a new
-// message, the low word of the counter must be 2.  This function loads the
-// counter from |le_ctr| and increments the loaded counter as needed, but it
-// does *not* store the updated counter back to |le_ctr|.  The caller must
-// update |le_ctr| if any more data segments follow.  Internally, only the low
-// 32-bit word of the counter is incremented, following the GCM standard.
+// |le_ctr| must give the current counter in little-endian format.  This
+// function loads the counter from |le_ctr| and increments the loaded counter as
+// needed, but it does *not* store the updated counter back to |le_ctr|.  The
+// caller must update |le_ctr| if any more data segments follow.  Internally,
+// only the low 32-bit word of the counter is incremented, following the GCM
+// standard.
 .macro	_aes_gcm_update	enc
 
 	// Function arguments
@@ -634,69 +712,69 @@
 	// Pointer to the last AES round key for the chosen AES variant
 	.set	RNDKEYLAST_PTR,	%r11
 
-	// In the main loop, V0-V3 are used as AES input and output.  Elsewhere
-	// they are used as temporary registers.
+	// In the main loop, %zmm[0-3] are used as AES input and output.
+	// Elsewhere they are used as temporary registers.
 
 	// GHASHDATA[0-3] hold the ciphertext blocks and GHASH input data.
-	.set	GHASHDATA0,	V4
+	.set	GHASHDATA0,	%zmm4
 	.set	GHASHDATA0_XMM,	%xmm4
-	.set	GHASHDATA1,	V5
+	.set	GHASHDATA1,	%zmm5
 	.set	GHASHDATA1_XMM,	%xmm5
-	.set	GHASHDATA2,	V6
+	.set	GHASHDATA2,	%zmm6
 	.set	GHASHDATA2_XMM,	%xmm6
-	.set	GHASHDATA3,	V7
+	.set	GHASHDATA3,	%zmm7
 
 	// BSWAP_MASK is the shuffle mask for byte-reflecting 128-bit values
 	// using vpshufb, copied to all 128-bit lanes.
-	.set	BSWAP_MASK,	V8
+	.set	BSWAP_MASK,	%zmm8
 
 	// RNDKEY temporarily holds the next AES round key.
-	.set	RNDKEY,		V9
+	.set	RNDKEY,		%zmm9
 
 	// GHASH_ACC is the accumulator variable for GHASH.  When fully reduced,
 	// only the lowest 128-bit lane can be nonzero.  When not fully reduced,
 	// more than one lane may be used, and they need to be XOR'd together.
-	.set	GHASH_ACC,	V10
+	.set	GHASH_ACC,	%zmm10
 	.set	GHASH_ACC_XMM,	%xmm10
 
 	// LE_CTR_INC is the vector of 32-bit words that need to be added to a
 	// vector of little-endian counter blocks to advance it forwards.
-	.set	LE_CTR_INC,	V11
+	.set	LE_CTR_INC,	%zmm11
 
 	// LE_CTR contains the next set of little-endian counter blocks.
-	.set	LE_CTR,		V12
+	.set	LE_CTR,		%zmm12
 
 	// RNDKEY0, RNDKEYLAST, and RNDKEY_M[9-1] contain cached AES round keys,
 	// copied to all 128-bit lanes.  RNDKEY0 is the zero-th round key,
 	// RNDKEYLAST the last, and RNDKEY_M\i the one \i-th from the last.
-	.set	RNDKEY0,	V13
-	.set	RNDKEYLAST,	V14
-	.set	RNDKEY_M9,	V15
-	.set	RNDKEY_M8,	V16
-	.set	RNDKEY_M7,	V17
-	.set	RNDKEY_M6,	V18
-	.set	RNDKEY_M5,	V19
-	.set	RNDKEY_M4,	V20
-	.set	RNDKEY_M3,	V21
-	.set	RNDKEY_M2,	V22
-	.set	RNDKEY_M1,	V23
+	.set	RNDKEY0,	%zmm13
+	.set	RNDKEYLAST,	%zmm14
+	.set	RNDKEY_M9,	%zmm15
+	.set	RNDKEY_M8,	%zmm16
+	.set	RNDKEY_M7,	%zmm17
+	.set	RNDKEY_M6,	%zmm18
+	.set	RNDKEY_M5,	%zmm19
+	.set	RNDKEY_M4,	%zmm20
+	.set	RNDKEY_M3,	%zmm21
+	.set	RNDKEY_M2,	%zmm22
+	.set	RNDKEY_M1,	%zmm23
 
 	// GHASHTMP[0-2] are temporary variables used by _ghash_step_4x.  These
 	// cannot coincide with anything used for AES encryption, since for
 	// performance reasons GHASH and AES encryption are interleaved.
-	.set	GHASHTMP0,	V24
-	.set	GHASHTMP1,	V25
-	.set	GHASHTMP2,	V26
+	.set	GHASHTMP0,	%zmm24
+	.set	GHASHTMP1,	%zmm25
+	.set	GHASHTMP2,	%zmm26
 
-	// H_POW[4-1] contain the powers of the hash key H^(4*VL/16)...H^1.  The
+	// H_POW[4-1] contain the powers of the hash key H^16...H^1.  The
 	// descending numbering reflects the order of the key powers.
-	.set	H_POW4,		V27
-	.set	H_POW3,		V28
-	.set	H_POW2,		V29
-	.set	H_POW1,		V30
+	.set	H_POW4,		%zmm27
+	.set	H_POW3,		%zmm28
+	.set	H_POW2,		%zmm29
+	.set	H_POW1,		%zmm30
 
 	// GFPOLY contains the .Lgfpoly constant, copied to all 128-bit lanes.
-	.set	GFPOLY,		V31
+	.set	GFPOLY,		%zmm31
 
 	// Load some constants.
 	vbroadcasti32x4	.Lbswap_mask(%rip), BSWAP_MASK
@@ -719,29 +797,23 @@
 	// Finish initializing LE_CTR by adding [0, 1, ...] to its low words.
 	vpaddd		.Lctr_pattern(%rip), LE_CTR, LE_CTR
 
-	// Initialize LE_CTR_INC to contain VL/16 in all 128-bit lanes.
-.if VL == 32
-	vbroadcasti32x4	.Linc_2blocks(%rip), LE_CTR_INC
-.elseif VL == 64
+	// Load 4 into all 128-bit lanes of LE_CTR_INC.
 	vbroadcasti32x4	.Linc_4blocks(%rip), LE_CTR_INC
-.else
-	.error "Unsupported vector length"
-.endif
 
-	// If there are at least 4*VL bytes of data, then continue into the loop
-	// that processes 4*VL bytes of data at a time.  Otherwise skip it.
+	// If there are at least 256 bytes of data, then continue into the loop
+	// that processes 256 bytes of data at a time.  Otherwise skip it.
 	//
-	// Pre-subtracting 4*VL from DATALEN saves an instruction from the main
+	// Pre-subtracting 256 from DATALEN saves an instruction from the main
 	// loop and also ensures that at least one write always occurs to
 	// DATALEN, zero-extending it and allowing DATALEN64 to be used later.
-	add		$-4*VL, DATALEN  // shorter than 'sub 4*VL' when VL=32
+	sub		$256, DATALEN
 	jl		.Lcrypt_loop_4x_done\@
 
 	// Load powers of the hash key.
-	vmovdqu8	OFFSETOFEND_H_POWERS-4*VL(KEY), H_POW4
-	vmovdqu8	OFFSETOFEND_H_POWERS-3*VL(KEY), H_POW3
-	vmovdqu8	OFFSETOFEND_H_POWERS-2*VL(KEY), H_POW2
-	vmovdqu8	OFFSETOFEND_H_POWERS-1*VL(KEY), H_POW1
+	vmovdqu8	OFFSETOFEND_H_POWERS-4*64(KEY), H_POW4
+	vmovdqu8	OFFSETOFEND_H_POWERS-3*64(KEY), H_POW3
+	vmovdqu8	OFFSETOFEND_H_POWERS-2*64(KEY), H_POW2
+	vmovdqu8	OFFSETOFEND_H_POWERS-1*64(KEY), H_POW1
 
 	// Main loop: en/decrypt and hash 4 vectors at a time.
 	//
@@ -770,9 +842,9 @@
 	cmp		%rax, RNDKEYLAST_PTR
 	jne		1b
 	_aesenclast_and_xor_4x
-	sub		$-4*VL, SRC  // shorter than 'add 4*VL' when VL=32
-	sub		$-4*VL, DST
-	add		$-4*VL, DATALEN
+	add		$256, SRC
+	add		$256, DST
+	sub		$256, DATALEN
 	jl		.Lghash_last_ciphertext_4x\@
 .endif
 
@@ -786,10 +858,10 @@
 	// If decrypting, load more ciphertext blocks into GHASHDATA[0-3].  If
 	// encrypting, GHASHDATA[0-3] already contain the previous ciphertext.
 .if !\enc
-	vmovdqu8	0*VL(SRC), GHASHDATA0
-	vmovdqu8	1*VL(SRC), GHASHDATA1
-	vmovdqu8	2*VL(SRC), GHASHDATA2
-	vmovdqu8	3*VL(SRC), GHASHDATA3
+	vmovdqu8	0*64(SRC), GHASHDATA0
+	vmovdqu8	1*64(SRC), GHASHDATA1
+	vmovdqu8	2*64(SRC), GHASHDATA2
+	vmovdqu8	3*64(SRC), GHASHDATA3
 .endif
 
 	// Start the AES encryption of the counter blocks.
@@ -809,44 +881,44 @@
 	_vaesenc_4x	RNDKEY
 128:
 
-	// Finish the AES encryption of the counter blocks in V0-V3, interleaved
-	// with the GHASH update of the ciphertext blocks in GHASHDATA[0-3].
+	// Finish the AES encryption of the counter blocks in %zmm[0-3],
+	// interleaved with the GHASH update of the ciphertext blocks in
+	// GHASHDATA[0-3].
 .irp i, 9,8,7,6,5,4,3,2,1
 	_ghash_step_4x  (9 - \i)
 	_vaesenc_4x	RNDKEY_M\i
 .endr
 	_ghash_step_4x	9
 	_aesenclast_and_xor_4x
-	sub		$-4*VL, SRC  // shorter than 'add 4*VL' when VL=32
-	sub		$-4*VL, DST
-	add		$-4*VL, DATALEN
+	add		$256, SRC
+	add		$256, DST
+	sub		$256, DATALEN
 	jge		.Lcrypt_loop_4x\@
 
 .if \enc
 .Lghash_last_ciphertext_4x\@:
 	// Update GHASH with the last set of ciphertext blocks.
-.irp i, 0,1,2,3,4,5,6,7,8,9
-	_ghash_step_4x	\i
-.endr
+	_ghash_4x
 .endif
 
 .Lcrypt_loop_4x_done\@:
 
-	// Undo the extra subtraction by 4*VL and check whether data remains.
-	sub		$-4*VL, DATALEN  // shorter than 'add 4*VL' when VL=32
+	// Undo the extra subtraction by 256 and check whether data remains.
+	add		$256, DATALEN
 	jz		.Ldone\@
 
-	// The data length isn't a multiple of 4*VL.  Process the remaining data
-	// of length 1 <= DATALEN < 4*VL, up to one vector (VL bytes) at a time.
-	// Going one vector at a time may seem inefficient compared to having
-	// separate code paths for each possible number of vectors remaining.
-	// However, using a loop keeps the code size down, and it performs
-	// surprising well; modern CPUs will start executing the next iteration
-	// before the previous one finishes and also predict the number of loop
-	// iterations.  For a similar reason, we roll up the AES rounds.
+	// The data length isn't a multiple of 256 bytes.  Process the remaining
+	// data of length 1 <= DATALEN < 256, up to one 64-byte vector at a
+	// time.  Going one vector at a time may seem inefficient compared to
+	// having separate code paths for each possible number of vectors
+	// remaining.  However, using a loop keeps the code size down, and it
+	// performs surprising well; modern CPUs will start executing the next
+	// iteration before the previous one finishes and also predict the
+	// number of loop iterations.  For a similar reason, we roll up the AES
+	// rounds.
 	//
-	// On the last iteration, the remaining length may be less than VL.
-	// Handle this using masking.
+	// On the last iteration, the remaining length may be less than 64
+	// bytes.  Handle this using masking.
 	//
 	// Since there are enough key powers available for all remaining data,
 	// there is no need to do a GHASH reduction after each iteration.
@@ -875,65 +947,60 @@
 .Lcrypt_loop_1x\@:
 
 	// Select the appropriate mask for this iteration: all 1's if
-	// DATALEN >= VL, otherwise DATALEN 1's.  Do this branchlessly using the
+	// DATALEN >= 64, otherwise DATALEN 1's.  Do this branchlessly using the
 	// bzhi instruction from BMI2.  (This relies on DATALEN <= 255.)
-.if VL < 64
-	mov		$-1, %eax
-	bzhi		DATALEN, %eax, %eax
-	kmovd		%eax, %k1
-.else
 	mov		$-1, %rax
 	bzhi		DATALEN64, %rax, %rax
 	kmovq		%rax, %k1
-.endif
 
 	// Encrypt a vector of counter blocks.  This does not need to be masked.
-	vpshufb		BSWAP_MASK, LE_CTR, V0
+	vpshufb		BSWAP_MASK, LE_CTR, %zmm0
 	vpaddd		LE_CTR_INC, LE_CTR, LE_CTR
-	vpxord		RNDKEY0, V0, V0
+	vpxord		RNDKEY0, %zmm0, %zmm0
 	lea		16(KEY), %rax
 1:
 	vbroadcasti32x4	(%rax), RNDKEY
-	vaesenc		RNDKEY, V0, V0
+	vaesenc		RNDKEY, %zmm0, %zmm0
 	add		$16, %rax
 	cmp		%rax, RNDKEYLAST_PTR
 	jne		1b
-	vaesenclast	RNDKEYLAST, V0, V0
+	vaesenclast	RNDKEYLAST, %zmm0, %zmm0
 
 	// XOR the data with the appropriate number of keystream bytes.
-	vmovdqu8	(SRC), V1{%k1}{z}
-	vpxord		V1, V0, V0
-	vmovdqu8	V0, (DST){%k1}
+	vmovdqu8	(SRC), %zmm1{%k1}{z}
+	vpxord		%zmm1, %zmm0, %zmm0
+	vmovdqu8	%zmm0, (DST){%k1}
 
 	// Update GHASH with the ciphertext block(s), without reducing.
 	//
-	// In the case of DATALEN < VL, the ciphertext is zero-padded to VL.
-	// (If decrypting, it's done by the above masked load.  If encrypting,
-	// it's done by the below masked register-to-register move.)  Note that
-	// if DATALEN <= VL - 16, there will be additional padding beyond the
-	// padding of the last block specified by GHASH itself; i.e., there may
-	// be whole block(s) that get processed by the GHASH multiplication and
-	// reduction instructions but should not actually be included in the
+	// In the case of DATALEN < 64, the ciphertext is zero-padded to 64
+	// bytes.  (If decrypting, it's done by the above masked load.  If
+	// encrypting, it's done by the below masked register-to-register move.)
+	// Note that if DATALEN <= 48, there will be additional padding beyond
+	// the padding of the last block specified by GHASH itself; i.e., there
+	// may be whole block(s) that get processed by the GHASH multiplication
+	// and reduction instructions but should not actually be included in the
 	// GHASH.  However, any such blocks are all-zeroes, and the values that
 	// they're multiplied with are also all-zeroes.  Therefore they just add
 	// 0 * 0 = 0 to the final GHASH result, which makes no difference.
 	vmovdqu8	(POWERS_PTR), H_POW1
 .if \enc
-	vmovdqu8	V0, V1{%k1}{z}
+	vmovdqu8	%zmm0, %zmm1{%k1}{z}
 .endif
-	vpshufb		BSWAP_MASK, V1, V0
-	vpxord		GHASH_ACC, V0, V0
-	_ghash_mul_noreduce	H_POW1, V0, LO, MI, HI, GHASHDATA3, V1, V2, V3
+	vpshufb		BSWAP_MASK, %zmm1, %zmm0
+	vpxord		GHASH_ACC, %zmm0, %zmm0
+	_ghash_mul_noreduce	H_POW1, %zmm0, LO, MI, HI, \
+				GHASHDATA3, %zmm1, %zmm2, %zmm3
 	vpxor		GHASH_ACC_XMM, GHASH_ACC_XMM, GHASH_ACC_XMM
 
-	add		$VL, POWERS_PTR
-	add		$VL, SRC
-	add		$VL, DST
-	sub		$VL, DATALEN
+	add		$64, POWERS_PTR
+	add		$64, SRC
+	add		$64, DST
+	sub		$64, DATALEN
 	jg		.Lcrypt_loop_1x\@
 
 	// Finally, do the GHASH reduction.
-	_ghash_reduce	LO, MI, HI, GFPOLY, V0
+	_ghash_reduce	LO, MI, HI, GFPOLY, %zmm0
 	_horizontal_xor	HI, HI_XMM, GHASH_ACC_XMM, %xmm0, %xmm1, %xmm2
 
 .Ldone\@:
@@ -944,14 +1011,14 @@
 	RET
 .endm
 
-// void aes_gcm_enc_final_vaes_avx10(const struct aes_gcm_key_avx10 *key,
-//				     const u32 le_ctr[4], u8 ghash_acc[16],
-//				     u64 total_aadlen, u64 total_datalen);
-// bool aes_gcm_dec_final_vaes_avx10(const struct aes_gcm_key_avx10 *key,
-//				     const u32 le_ctr[4],
-//				     const u8 ghash_acc[16],
-//				     u64 total_aadlen, u64 total_datalen,
-//				     const u8 tag[16], int taglen);
+// void aes_gcm_enc_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+//				      const u32 le_ctr[4], u8 ghash_acc[16],
+//				      u64 total_aadlen, u64 total_datalen);
+// bool aes_gcm_dec_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+//				      const u32 le_ctr[4],
+//				      const u8 ghash_acc[16],
+//				      u64 total_aadlen, u64 total_datalen,
+//				      const u8 tag[16], int taglen);
 //
 // This macro generates one of the above two functions (with \enc selecting
 // which one).  Both functions finish computing the GCM authentication tag by
@@ -1081,119 +1148,16 @@
 	RET
 .endm
 
-_set_veclen 32
-SYM_FUNC_START(aes_gcm_precompute_vaes_avx10_256)
-	_aes_gcm_precompute
-SYM_FUNC_END(aes_gcm_precompute_vaes_avx10_256)
-SYM_FUNC_START(aes_gcm_enc_update_vaes_avx10_256)
-	_aes_gcm_update	1
-SYM_FUNC_END(aes_gcm_enc_update_vaes_avx10_256)
-SYM_FUNC_START(aes_gcm_dec_update_vaes_avx10_256)
-	_aes_gcm_update	0
-SYM_FUNC_END(aes_gcm_dec_update_vaes_avx10_256)
-
-_set_veclen 64
-SYM_FUNC_START(aes_gcm_precompute_vaes_avx10_512)
-	_aes_gcm_precompute
-SYM_FUNC_END(aes_gcm_precompute_vaes_avx10_512)
-SYM_FUNC_START(aes_gcm_enc_update_vaes_avx10_512)
+SYM_FUNC_START(aes_gcm_enc_update_vaes_avx512)
 	_aes_gcm_update	1
-SYM_FUNC_END(aes_gcm_enc_update_vaes_avx10_512)
-SYM_FUNC_START(aes_gcm_dec_update_vaes_avx10_512)
+SYM_FUNC_END(aes_gcm_enc_update_vaes_avx512)
+SYM_FUNC_START(aes_gcm_dec_update_vaes_avx512)
 	_aes_gcm_update	0
-SYM_FUNC_END(aes_gcm_dec_update_vaes_avx10_512)
-
-// void aes_gcm_aad_update_vaes_avx10(const struct aes_gcm_key_avx10 *key,
-//				      u8 ghash_acc[16],
-//				      const u8 *aad, int aadlen);
-//
-// This function processes the AAD (Additional Authenticated Data) in GCM.
-// Using the key |key|, it updates the GHASH accumulator |ghash_acc| with the
-// data given by |aad| and |aadlen|.  |key->ghash_key_powers| must have been
-// initialized.  On the first call, |ghash_acc| must be all zeroes.  |aadlen|
-// must be a multiple of 16, except on the last call where it can be any length.
-// The caller must do any buffering needed to ensure this.
-//
-// AES-GCM is almost always used with small amounts of AAD, less than 32 bytes.
-// Therefore, for AAD processing we currently only provide this implementation
-// which uses 256-bit vectors (ymm registers) and only has a 1x-wide loop.  This
-// keeps the code size down, and it enables some micro-optimizations, e.g. using
-// VEX-coded instructions instead of EVEX-coded to save some instruction bytes.
-// To optimize for large amounts of AAD, we could implement a 4x-wide loop and
-// provide a version using 512-bit vectors, but that doesn't seem to be useful.
-SYM_FUNC_START(aes_gcm_aad_update_vaes_avx10)
-
-	// Function arguments
-	.set	KEY,		%rdi
-	.set	GHASH_ACC_PTR,	%rsi
-	.set	AAD,		%rdx
-	.set	AADLEN,		%ecx
-	.set	AADLEN64,	%rcx	// Zero-extend AADLEN before using!
-
-	// Additional local variables.
-	// %rax, %ymm0-%ymm3, and %k1 are used as temporary registers.
-	.set	BSWAP_MASK,	%ymm4
-	.set	GFPOLY,		%ymm5
-	.set	GHASH_ACC,	%ymm6
-	.set	GHASH_ACC_XMM,	%xmm6
-	.set	H_POW1,		%ymm7
-
-	// Load some constants.
-	vbroadcasti128	.Lbswap_mask(%rip), BSWAP_MASK
-	vbroadcasti128	.Lgfpoly(%rip), GFPOLY
-
-	// Load the GHASH accumulator.
-	vmovdqu		(GHASH_ACC_PTR), GHASH_ACC_XMM
-
-	// Update GHASH with 32 bytes of AAD at a time.
-	//
-	// Pre-subtracting 32 from AADLEN saves an instruction from the loop and
-	// also ensures that at least one write always occurs to AADLEN,
-	// zero-extending it and allowing AADLEN64 to be used later.
-	sub		$32, AADLEN
-	jl		.Laad_loop_1x_done
-	vmovdqu8	OFFSETOFEND_H_POWERS-32(KEY), H_POW1	// [H^2, H^1]
-.Laad_loop_1x:
-	vmovdqu		(AAD), %ymm0
-	vpshufb		BSWAP_MASK, %ymm0, %ymm0
-	vpxor		%ymm0, GHASH_ACC, GHASH_ACC
-	_ghash_mul	H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \
-			%ymm0, %ymm1, %ymm2
-	vextracti128	$1, GHASH_ACC, %xmm0
-	vpxor		%xmm0, GHASH_ACC_XMM, GHASH_ACC_XMM
-	add		$32, AAD
-	sub		$32, AADLEN
-	jge		.Laad_loop_1x
-.Laad_loop_1x_done:
-	add		$32, AADLEN
-	jz		.Laad_done
-
-	// Update GHASH with the remaining 1 <= AADLEN < 32 bytes of AAD.
-	mov		$-1, %eax
-	bzhi		AADLEN, %eax, %eax
-	kmovd		%eax, %k1
-	vmovdqu8	(AAD), %ymm0{%k1}{z}
-	neg		AADLEN64
-	and		$~15, AADLEN64  // -round_up(AADLEN, 16)
-	vmovdqu8	OFFSETOFEND_H_POWERS(KEY,AADLEN64), H_POW1
-	vpshufb		BSWAP_MASK, %ymm0, %ymm0
-	vpxor		%ymm0, GHASH_ACC, GHASH_ACC
-	_ghash_mul	H_POW1, GHASH_ACC, GHASH_ACC, GFPOLY, \
-			%ymm0, %ymm1, %ymm2
-	vextracti128	$1, GHASH_ACC, %xmm0
-	vpxor		%xmm0, GHASH_ACC_XMM, GHASH_ACC_XMM
-
-.Laad_done:
-	// Store the updated GHASH accumulator back to memory.
-	vmovdqu		GHASH_ACC_XMM, (GHASH_ACC_PTR)
-
-	vzeroupper	// This is needed after using ymm or zmm registers.
-	RET
-SYM_FUNC_END(aes_gcm_aad_update_vaes_avx10)
+SYM_FUNC_END(aes_gcm_dec_update_vaes_avx512)
 
-SYM_FUNC_START(aes_gcm_enc_final_vaes_avx10)
+SYM_FUNC_START(aes_gcm_enc_final_vaes_avx512)
 	_aes_gcm_final	1
-SYM_FUNC_END(aes_gcm_enc_final_vaes_avx10)
-SYM_FUNC_START(aes_gcm_dec_final_vaes_avx10)
+SYM_FUNC_END(aes_gcm_enc_final_vaes_avx512)
+SYM_FUNC_START(aes_gcm_dec_final_vaes_avx512)
 	_aes_gcm_final	0
-SYM_FUNC_END(aes_gcm_dec_final_vaes_avx10)
+SYM_FUNC_END(aes_gcm_dec_final_vaes_avx512)
diff --git a/arch/x86/crypto/aesni-intel_glue.c b/arch/x86/crypto/aesni-intel_glue.c
index d953ac470aae..bb6e2c47ffc6 100644
--- a/arch/x86/crypto/aesni-intel_glue.c
+++ b/arch/x86/crypto/aesni-intel_glue.c
@@ -874,8 +874,38 @@ struct aes_gcm_key_aesni {
 #define AES_GCM_KEY_AESNI_SIZE	\
 	(sizeof(struct aes_gcm_key_aesni) + (15 & ~(CRYPTO_MINALIGN - 1)))
 
-/* Key struct used by the VAES + AVX10 implementations of AES-GCM */
-struct aes_gcm_key_avx10 {
+/* Key struct used by the VAES + AVX2 implementation of AES-GCM */
+struct aes_gcm_key_vaes_avx2 {
+	/*
+	 * Common part of the key.  The assembly code prefers 16-byte alignment
+	 * for the round keys; we get this by them being located at the start of
+	 * the struct and the whole struct being 32-byte aligned.
+	 */
+	struct aes_gcm_key base;
+
+	/*
+	 * Powers of the hash key H^8 through H^1.  These are 128-bit values.
+	 * They all have an extra factor of x^-1 and are byte-reversed.
+	 * The assembly code prefers 32-byte alignment for this.
+	 */
+	u64 h_powers[8][2] __aligned(32);
+
+	/*
+	 * Each entry in this array contains the two halves of an entry of
+	 * h_powers XOR'd together, in the following order:
+	 * H^8,H^6,H^7,H^5,H^4,H^2,H^3,H^1 i.e. indices 0,2,1,3,4,6,5,7.
+	 * This is used for Karatsuba multiplication.
+	 */
+	u64 h_powers_xored[8];
+};
+
+#define AES_GCM_KEY_VAES_AVX2(key) \
+	container_of((key), struct aes_gcm_key_vaes_avx2, base)
+#define AES_GCM_KEY_VAES_AVX2_SIZE \
+	(sizeof(struct aes_gcm_key_vaes_avx2) + (31 & ~(CRYPTO_MINALIGN - 1)))
+
+/* Key struct used by the VAES + AVX512 implementation of AES-GCM */
+struct aes_gcm_key_vaes_avx512 {
 	/*
 	 * Common part of the key.  The assembly code prefers 16-byte alignment
 	 * for the round keys; we get this by them being located at the start of
@@ -895,10 +925,10 @@ struct aes_gcm_key_avx10 {
 	/* Three padding blocks required by the assembly code */
 	u64 padding[3][2];
 };
-#define AES_GCM_KEY_AVX10(key)	\
-	container_of((key), struct aes_gcm_key_avx10, base)
-#define AES_GCM_KEY_AVX10_SIZE	\
-	(sizeof(struct aes_gcm_key_avx10) + (63 & ~(CRYPTO_MINALIGN - 1)))
+#define AES_GCM_KEY_VAES_AVX512(key) \
+	container_of((key), struct aes_gcm_key_vaes_avx512, base)
+#define AES_GCM_KEY_VAES_AVX512_SIZE \
+	(sizeof(struct aes_gcm_key_vaes_avx512) + (63 & ~(CRYPTO_MINALIGN - 1)))
 
 /*
  * These flags are passed to the AES-GCM helper functions to specify the
@@ -910,14 +940,16 @@ struct aes_gcm_key_avx10 {
 #define FLAG_RFC4106	BIT(0)
 #define FLAG_ENC	BIT(1)
 #define FLAG_AVX	BIT(2)
-#define FLAG_AVX10_256	BIT(3)
-#define FLAG_AVX10_512	BIT(4)
+#define FLAG_VAES_AVX2	BIT(3)
+#define FLAG_VAES_AVX512 BIT(4)
 
 static inline struct aes_gcm_key *
 aes_gcm_key_get(struct crypto_aead *tfm, int flags)
 {
-	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
+	if (flags & FLAG_VAES_AVX512)
 		return PTR_ALIGN(crypto_aead_ctx(tfm), 64);
+	else if (flags & FLAG_VAES_AVX2)
+		return PTR_ALIGN(crypto_aead_ctx(tfm), 32);
 	else
 		return PTR_ALIGN(crypto_aead_ctx(tfm), 16);
 }
@@ -927,26 +959,16 @@ aes_gcm_precompute_aesni(struct aes_gcm_key_aesni *key);
 asmlinkage void
 aes_gcm_precompute_aesni_avx(struct aes_gcm_key_aesni *key);
 asmlinkage void
-aes_gcm_precompute_vaes_avx10_256(struct aes_gcm_key_avx10 *key);
+aes_gcm_precompute_vaes_avx2(struct aes_gcm_key_vaes_avx2 *key);
 asmlinkage void
-aes_gcm_precompute_vaes_avx10_512(struct aes_gcm_key_avx10 *key);
+aes_gcm_precompute_vaes_avx512(struct aes_gcm_key_vaes_avx512 *key);
 
 static void aes_gcm_precompute(struct aes_gcm_key *key, int flags)
 {
-	/*
-	 * To make things a bit easier on the assembly side, the AVX10
-	 * implementations use the same key format.  Therefore, a single
-	 * function using 256-bit vectors would suffice here.  However, it's
-	 * straightforward to provide a 512-bit one because of how the assembly
-	 * code is structured, and it works nicely because the total size of the
-	 * key powers is a multiple of 512 bits.  So we take advantage of that.
-	 *
-	 * A similar situation applies to the AES-NI implementations.
-	 */
-	if (flags & FLAG_AVX10_512)
-		aes_gcm_precompute_vaes_avx10_512(AES_GCM_KEY_AVX10(key));
-	else if (flags & FLAG_AVX10_256)
-		aes_gcm_precompute_vaes_avx10_256(AES_GCM_KEY_AVX10(key));
+	if (flags & FLAG_VAES_AVX512)
+		aes_gcm_precompute_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key));
+	else if (flags & FLAG_VAES_AVX2)
+		aes_gcm_precompute_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key));
 	else if (flags & FLAG_AVX)
 		aes_gcm_precompute_aesni_avx(AES_GCM_KEY_AESNI(key));
 	else
@@ -960,15 +982,21 @@ asmlinkage void
 aes_gcm_aad_update_aesni_avx(const struct aes_gcm_key_aesni *key,
 			     u8 ghash_acc[16], const u8 *aad, int aadlen);
 asmlinkage void
-aes_gcm_aad_update_vaes_avx10(const struct aes_gcm_key_avx10 *key,
-			      u8 ghash_acc[16], const u8 *aad, int aadlen);
+aes_gcm_aad_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+			     u8 ghash_acc[16], const u8 *aad, int aadlen);
+asmlinkage void
+aes_gcm_aad_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+			       u8 ghash_acc[16], const u8 *aad, int aadlen);
 
 static void aes_gcm_aad_update(const struct aes_gcm_key *key, u8 ghash_acc[16],
 			       const u8 *aad, int aadlen, int flags)
 {
-	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
-		aes_gcm_aad_update_vaes_avx10(AES_GCM_KEY_AVX10(key), ghash_acc,
-					      aad, aadlen);
+	if (flags & FLAG_VAES_AVX512)
+		aes_gcm_aad_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
+					       ghash_acc, aad, aadlen);
+	else if (flags & FLAG_VAES_AVX2)
+		aes_gcm_aad_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
+					     ghash_acc, aad, aadlen);
 	else if (flags & FLAG_AVX)
 		aes_gcm_aad_update_aesni_avx(AES_GCM_KEY_AESNI(key), ghash_acc,
 					     aad, aadlen);
@@ -986,13 +1014,13 @@ aes_gcm_enc_update_aesni_avx(const struct aes_gcm_key_aesni *key,
 			     const u32 le_ctr[4], u8 ghash_acc[16],
 			     const u8 *src, u8 *dst, int datalen);
 asmlinkage void
-aes_gcm_enc_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key,
-				  const u32 le_ctr[4], u8 ghash_acc[16],
-				  const u8 *src, u8 *dst, int datalen);
+aes_gcm_enc_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+			     const u32 le_ctr[4], u8 ghash_acc[16],
+			     const u8 *src, u8 *dst, int datalen);
 asmlinkage void
-aes_gcm_enc_update_vaes_avx10_512(const struct aes_gcm_key_avx10 *key,
-				  const u32 le_ctr[4], u8 ghash_acc[16],
-				  const u8 *src, u8 *dst, int datalen);
+aes_gcm_enc_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+			       const u32 le_ctr[4], u8 ghash_acc[16],
+			       const u8 *src, u8 *dst, int datalen);
 
 asmlinkage void
 aes_gcm_dec_update_aesni(const struct aes_gcm_key_aesni *key,
@@ -1003,13 +1031,13 @@ aes_gcm_dec_update_aesni_avx(const struct aes_gcm_key_aesni *key,
 			     const u32 le_ctr[4], u8 ghash_acc[16],
 			     const u8 *src, u8 *dst, int datalen);
 asmlinkage void
-aes_gcm_dec_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key,
-				  const u32 le_ctr[4], u8 ghash_acc[16],
-				  const u8 *src, u8 *dst, int datalen);
+aes_gcm_dec_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+			     const u32 le_ctr[4], u8 ghash_acc[16],
+			     const u8 *src, u8 *dst, int datalen);
 asmlinkage void
-aes_gcm_dec_update_vaes_avx10_512(const struct aes_gcm_key_avx10 *key,
-				  const u32 le_ctr[4], u8 ghash_acc[16],
-				  const u8 *src, u8 *dst, int datalen);
+aes_gcm_dec_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+			       const u32 le_ctr[4], u8 ghash_acc[16],
+			       const u8 *src, u8 *dst, int datalen);
 
 /* __always_inline to optimize out the branches based on @flags */
 static __always_inline void
@@ -1018,14 +1046,14 @@ aes_gcm_update(const struct aes_gcm_key *key,
 	       const u8 *src, u8 *dst, int datalen, int flags)
 {
 	if (flags & FLAG_ENC) {
-		if (flags & FLAG_AVX10_512)
-			aes_gcm_enc_update_vaes_avx10_512(AES_GCM_KEY_AVX10(key),
-							  le_ctr, ghash_acc,
-							  src, dst, datalen);
-		else if (flags & FLAG_AVX10_256)
-			aes_gcm_enc_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key),
-							  le_ctr, ghash_acc,
-							  src, dst, datalen);
+		if (flags & FLAG_VAES_AVX512)
+			aes_gcm_enc_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
+						       le_ctr, ghash_acc,
+						       src, dst, datalen);
+		else if (flags & FLAG_VAES_AVX2)
+			aes_gcm_enc_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
+						     le_ctr, ghash_acc,
+						     src, dst, datalen);
 		else if (flags & FLAG_AVX)
 			aes_gcm_enc_update_aesni_avx(AES_GCM_KEY_AESNI(key),
 						     le_ctr, ghash_acc,
@@ -1034,14 +1062,14 @@ aes_gcm_update(const struct aes_gcm_key *key,
 			aes_gcm_enc_update_aesni(AES_GCM_KEY_AESNI(key), le_ctr,
 						 ghash_acc, src, dst, datalen);
 	} else {
-		if (flags & FLAG_AVX10_512)
-			aes_gcm_dec_update_vaes_avx10_512(AES_GCM_KEY_AVX10(key),
-							  le_ctr, ghash_acc,
-							  src, dst, datalen);
-		else if (flags & FLAG_AVX10_256)
-			aes_gcm_dec_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key),
-							  le_ctr, ghash_acc,
-							  src, dst, datalen);
+		if (flags & FLAG_VAES_AVX512)
+			aes_gcm_dec_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
+						       le_ctr, ghash_acc,
+						       src, dst, datalen);
+		else if (flags & FLAG_VAES_AVX2)
+			aes_gcm_dec_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
+						     le_ctr, ghash_acc,
+						     src, dst, datalen);
 		else if (flags & FLAG_AVX)
 			aes_gcm_dec_update_aesni_avx(AES_GCM_KEY_AESNI(key),
 						     le_ctr, ghash_acc,
@@ -1062,9 +1090,13 @@ aes_gcm_enc_final_aesni_avx(const struct aes_gcm_key_aesni *key,
 			    const u32 le_ctr[4], u8 ghash_acc[16],
 			    u64 total_aadlen, u64 total_datalen);
 asmlinkage void
-aes_gcm_enc_final_vaes_avx10(const struct aes_gcm_key_avx10 *key,
-			     const u32 le_ctr[4], u8 ghash_acc[16],
-			     u64 total_aadlen, u64 total_datalen);
+aes_gcm_enc_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+			    const u32 le_ctr[4], u8 ghash_acc[16],
+			    u64 total_aadlen, u64 total_datalen);
+asmlinkage void
+aes_gcm_enc_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+			      const u32 le_ctr[4], u8 ghash_acc[16],
+			      u64 total_aadlen, u64 total_datalen);
 
 /* __always_inline to optimize out the branches based on @flags */
 static __always_inline void
@@ -1072,10 +1104,14 @@ aes_gcm_enc_final(const struct aes_gcm_key *key,
 		  const u32 le_ctr[4], u8 ghash_acc[16],
 		  u64 total_aadlen, u64 total_datalen, int flags)
 {
-	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
-		aes_gcm_enc_final_vaes_avx10(AES_GCM_KEY_AVX10(key),
-					     le_ctr, ghash_acc,
-					     total_aadlen, total_datalen);
+	if (flags & FLAG_VAES_AVX512)
+		aes_gcm_enc_final_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
+					      le_ctr, ghash_acc,
+					      total_aadlen, total_datalen);
+	else if (flags & FLAG_VAES_AVX2)
+		aes_gcm_enc_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
+					    le_ctr, ghash_acc,
+					    total_aadlen, total_datalen);
 	else if (flags & FLAG_AVX)
 		aes_gcm_enc_final_aesni_avx(AES_GCM_KEY_AESNI(key),
 					    le_ctr, ghash_acc,
@@ -1097,10 +1133,15 @@ aes_gcm_dec_final_aesni_avx(const struct aes_gcm_key_aesni *key,
 			    u64 total_aadlen, u64 total_datalen,
 			    const u8 tag[16], int taglen);
 asmlinkage bool __must_check
-aes_gcm_dec_final_vaes_avx10(const struct aes_gcm_key_avx10 *key,
-			     const u32 le_ctr[4], const u8 ghash_acc[16],
-			     u64 total_aadlen, u64 total_datalen,
-			     const u8 tag[16], int taglen);
+aes_gcm_dec_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
+			    const u32 le_ctr[4], const u8 ghash_acc[16],
+			    u64 total_aadlen, u64 total_datalen,
+			    const u8 tag[16], int taglen);
+asmlinkage bool __must_check
+aes_gcm_dec_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
+			      const u32 le_ctr[4], const u8 ghash_acc[16],
+			      u64 total_aadlen, u64 total_datalen,
+			      const u8 tag[16], int taglen);
 
 /* __always_inline to optimize out the branches based on @flags */
 static __always_inline bool __must_check
@@ -1108,11 +1149,16 @@ aes_gcm_dec_final(const struct aes_gcm_key *key, const u32 le_ctr[4],
 		  u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen,
 		  u8 tag[16], int taglen, int flags)
 {
-	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
-		return aes_gcm_dec_final_vaes_avx10(AES_GCM_KEY_AVX10(key),
-						    le_ctr, ghash_acc,
-						    total_aadlen, total_datalen,
-						    tag, taglen);
+	if (flags & FLAG_VAES_AVX512)
+		return aes_gcm_dec_final_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
+						     le_ctr, ghash_acc,
+						     total_aadlen, total_datalen,
+						     tag, taglen);
+	else if (flags & FLAG_VAES_AVX2)
+		return aes_gcm_dec_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
+						   le_ctr, ghash_acc,
+						   total_aadlen, total_datalen,
+						   tag, taglen);
 	else if (flags & FLAG_AVX)
 		return aes_gcm_dec_final_aesni_avx(AES_GCM_KEY_AESNI(key),
 						   le_ctr, ghash_acc,
@@ -1195,10 +1241,14 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key,
 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers) != 496);
 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers_xored) != 624);
 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_times_x64) != 688);
-	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_enc) != 0);
-	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_length) != 480);
-	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, h_powers) != 512);
-	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, padding) != 768);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.key_enc) != 0);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.key_length) != 480);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, h_powers) != 512);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, h_powers_xored) != 640);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, base.aes_key.key_enc) != 0);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, base.aes_key.key_length) != 480);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, h_powers) != 512);
+	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, padding) != 768);
 
 	if (likely(crypto_simd_usable())) {
 		err = aes_check_keylen(keylen);
@@ -1231,8 +1281,9 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key,
 		gf128mul_lle(&h, (const be128 *)x_to_the_minus1);
 
 		/* Compute the needed key powers */
-		if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) {
-			struct aes_gcm_key_avx10 *k = AES_GCM_KEY_AVX10(key);
+		if (flags & FLAG_VAES_AVX512) {
+			struct aes_gcm_key_vaes_avx512 *k =
+				AES_GCM_KEY_VAES_AVX512(key);
 
 			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
 				k->h_powers[i][0] = be64_to_cpu(h.b);
@@ -1240,6 +1291,22 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key,
 				gf128mul_lle(&h, &h1);
 			}
 			memset(k->padding, 0, sizeof(k->padding));
+		} else if (flags & FLAG_VAES_AVX2) {
+			struct aes_gcm_key_vaes_avx2 *k =
+				AES_GCM_KEY_VAES_AVX2(key);
+			static const u8 indices[8] = { 0, 2, 1, 3, 4, 6, 5, 7 };
+
+			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
+				k->h_powers[i][0] = be64_to_cpu(h.b);
+				k->h_powers[i][1] = be64_to_cpu(h.a);
+				gf128mul_lle(&h, &h1);
+			}
+			for (i = 0; i < ARRAY_SIZE(k->h_powers_xored); i++) {
+				int j = indices[i];
+
+				k->h_powers_xored[i] = k->h_powers[j][0] ^
+						       k->h_powers[j][1];
+			}
 		} else {
 			struct aes_gcm_key_aesni *k = AES_GCM_KEY_AESNI(key);
 
@@ -1508,15 +1575,15 @@ DEFINE_GCM_ALGS(aesni_avx, FLAG_AVX,
 		"generic-gcm-aesni-avx", "rfc4106-gcm-aesni-avx",
 		AES_GCM_KEY_AESNI_SIZE, 500);
 
-/* aes_gcm_algs_vaes_avx10_256 */
-DEFINE_GCM_ALGS(vaes_avx10_256, FLAG_AVX10_256,
-		"generic-gcm-vaes-avx10_256", "rfc4106-gcm-vaes-avx10_256",
-		AES_GCM_KEY_AVX10_SIZE, 700);
+/* aes_gcm_algs_vaes_avx2 */
+DEFINE_GCM_ALGS(vaes_avx2, FLAG_VAES_AVX2,
+		"generic-gcm-vaes-avx2", "rfc4106-gcm-vaes-avx2",
+		AES_GCM_KEY_VAES_AVX2_SIZE, 600);
 
-/* aes_gcm_algs_vaes_avx10_512 */
-DEFINE_GCM_ALGS(vaes_avx10_512, FLAG_AVX10_512,
-		"generic-gcm-vaes-avx10_512", "rfc4106-gcm-vaes-avx10_512",
-		AES_GCM_KEY_AVX10_SIZE, 800);
+/* aes_gcm_algs_vaes_avx512 */
+DEFINE_GCM_ALGS(vaes_avx512, FLAG_VAES_AVX512,
+		"generic-gcm-vaes-avx512", "rfc4106-gcm-vaes-avx512",
+		AES_GCM_KEY_VAES_AVX512_SIZE, 800);
 
 static int __init register_avx_algs(void)
 {
@@ -1548,6 +1615,10 @@ static int __init register_avx_algs(void)
 					ARRAY_SIZE(skcipher_algs_vaes_avx2));
 	if (err)
 		return err;
+	err = crypto_register_aeads(aes_gcm_algs_vaes_avx2,
+				    ARRAY_SIZE(aes_gcm_algs_vaes_avx2));
+	if (err)
+		return err;
 
 	if (!boot_cpu_has(X86_FEATURE_AVX512BW) ||
 	    !boot_cpu_has(X86_FEATURE_AVX512VL) ||
@@ -1556,26 +1627,21 @@ static int __init register_avx_algs(void)
 			       XFEATURE_MASK_AVX512, NULL))
 		return 0;
 
-	err = crypto_register_aeads(aes_gcm_algs_vaes_avx10_256,
-				    ARRAY_SIZE(aes_gcm_algs_vaes_avx10_256));
-	if (err)
-		return err;
-
 	if (boot_cpu_has(X86_FEATURE_PREFER_YMM)) {
 		int i;
 
 		for (i = 0; i < ARRAY_SIZE(skcipher_algs_vaes_avx512); i++)
 			skcipher_algs_vaes_avx512[i].base.cra_priority = 1;
-		for (i = 0; i < ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512); i++)
-			aes_gcm_algs_vaes_avx10_512[i].base.cra_priority = 1;
+		for (i = 0; i < ARRAY_SIZE(aes_gcm_algs_vaes_avx512); i++)
+			aes_gcm_algs_vaes_avx512[i].base.cra_priority = 1;
 	}
 
 	err = crypto_register_skciphers(skcipher_algs_vaes_avx512,
 					ARRAY_SIZE(skcipher_algs_vaes_avx512));
 	if (err)
 		return err;
-	err = crypto_register_aeads(aes_gcm_algs_vaes_avx10_512,
-				    ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512));
+	err = crypto_register_aeads(aes_gcm_algs_vaes_avx512,
+				    ARRAY_SIZE(aes_gcm_algs_vaes_avx512));
 	if (err)
 		return err;
 
@@ -1595,8 +1661,8 @@ static void unregister_avx_algs(void)
 	unregister_aeads(aes_gcm_algs_aesni_avx);
 	unregister_skciphers(skcipher_algs_vaes_avx2);
 	unregister_skciphers(skcipher_algs_vaes_avx512);
-	unregister_aeads(aes_gcm_algs_vaes_avx10_256);
-	unregister_aeads(aes_gcm_algs_vaes_avx10_512);
+	unregister_aeads(aes_gcm_algs_vaes_avx2);
+	unregister_aeads(aes_gcm_algs_vaes_avx512);
 }
 #else /* CONFIG_X86_64 */
 static struct aead_alg aes_gcm_algs_aesni[0];
diff --git a/drivers/md/Kconfig b/drivers/md/Kconfig
index 104aa5355090..cac4926fc340 100644
--- a/drivers/md/Kconfig
+++ b/drivers/md/Kconfig
@@ -546,6 +546,7 @@ config DM_VERITY
 	depends on BLK_DEV_DM
 	select CRYPTO
 	select CRYPTO_HASH
+	select CRYPTO_LIB_SHA256
 	select DM_BUFIO
 	help
 	  This device-mapper target creates a read-only device that
diff --git a/drivers/md/dm-verity-fec.c b/drivers/md/dm-verity-fec.c
index 72047b47a7a0..0c858b9ee06b 100644
--- a/drivers/md/dm-verity-fec.c
+++ b/drivers/md/dm-verity-fec.c
@@ -188,14 +188,13 @@ static int fec_decode_bufs(struct dm_verity *v, struct dm_verity_io *io,
  * Locate data block erasures using verity hashes.
  */
 static int fec_is_erasure(struct dm_verity *v, struct dm_verity_io *io,
-			  u8 *want_digest, u8 *data)
+			  const u8 *want_digest, const u8 *data)
 {
 	if (unlikely(verity_hash(v, io, data, 1 << v->data_dev_block_bits,
-				 verity_io_real_digest(v, io))))
+				 io->tmp_digest)))
 		return 0;
 
-	return memcmp(verity_io_real_digest(v, io), want_digest,
-		      v->digest_size) != 0;
+	return memcmp(io->tmp_digest, want_digest, v->digest_size) != 0;
 }
 
 /*
@@ -362,7 +361,7 @@ static void fec_init_bufs(struct dm_verity *v, struct dm_verity_fec_io *fio)
  */
 static int fec_decode_rsb(struct dm_verity *v, struct dm_verity_io *io,
 			  struct dm_verity_fec_io *fio, u64 rsb, u64 offset,
-			  bool use_erasures)
+			  const u8 *want_digest, bool use_erasures)
 {
 	int r, neras = 0;
 	unsigned int pos;
@@ -388,12 +387,11 @@ static int fec_decode_rsb(struct dm_verity *v, struct dm_verity_io *io,
 
 	/* Always re-validate the corrected block against the expected hash */
 	r = verity_hash(v, io, fio->output, 1 << v->data_dev_block_bits,
-			verity_io_real_digest(v, io));
+			io->tmp_digest);
 	if (unlikely(r < 0))
 		return r;
 
-	if (memcmp(verity_io_real_digest(v, io), verity_io_want_digest(v, io),
-		   v->digest_size)) {
+	if (memcmp(io->tmp_digest, want_digest, v->digest_size)) {
 		DMERR_LIMIT("%s: FEC %llu: failed to correct (%d erasures)",
 			    v->data_dev->name, (unsigned long long)rsb, neras);
 		return -EILSEQ;
@@ -404,7 +402,8 @@ static int fec_decode_rsb(struct dm_verity *v, struct dm_verity_io *io,
 
 /* Correct errors in a block. Copies corrected block to dest. */
 int verity_fec_decode(struct dm_verity *v, struct dm_verity_io *io,
-		      enum verity_block_type type, sector_t block, u8 *dest)
+		      enum verity_block_type type, const u8 *want_digest,
+		      sector_t block, u8 *dest)
 {
 	int r;
 	struct dm_verity_fec_io *fio = fec_io(io);
@@ -447,9 +446,9 @@ int verity_fec_decode(struct dm_verity *v, struct dm_verity_io *io,
 	 * them first. Do a second attempt with erasures if the corruption is
 	 * bad enough.
 	 */
-	r = fec_decode_rsb(v, io, fio, rsb, offset, false);
+	r = fec_decode_rsb(v, io, fio, rsb, offset, want_digest, false);
 	if (r < 0) {
-		r = fec_decode_rsb(v, io, fio, rsb, offset, true);
+		r = fec_decode_rsb(v, io, fio, rsb, offset, want_digest, true);
 		if (r < 0)
 			goto done;
 	}
diff --git a/drivers/md/dm-verity-fec.h b/drivers/md/dm-verity-fec.h
index 09123a612953..a6689cdc489d 100644
--- a/drivers/md/dm-verity-fec.h
+++ b/drivers/md/dm-verity-fec.h
@@ -68,8 +68,8 @@ struct dm_verity_fec_io {
 extern bool verity_fec_is_enabled(struct dm_verity *v);
 
 extern int verity_fec_decode(struct dm_verity *v, struct dm_verity_io *io,
-			     enum verity_block_type type, sector_t block,
-			     u8 *dest);
+			     enum verity_block_type type, const u8 *want_digest,
+			     sector_t block, u8 *dest);
 
 extern unsigned int verity_fec_status_table(struct dm_verity *v, unsigned int sz,
 					char *result, unsigned int maxlen);
@@ -99,6 +99,7 @@ static inline bool verity_fec_is_enabled(struct dm_verity *v)
 static inline int verity_fec_decode(struct dm_verity *v,
 				    struct dm_verity_io *io,
 				    enum verity_block_type type,
+				    const u8 *want_digest,
 				    sector_t block, u8 *dest)
 {
 	return -EOPNOTSUPP;
diff --git a/drivers/md/dm-verity-target.c b/drivers/md/dm-verity-target.c
index 66a00a8ccb39..bf0aee73b074 100644
--- a/drivers/md/dm-verity-target.c
+++ b/drivers/md/dm-verity-target.c
@@ -117,11 +117,25 @@ static sector_t verity_position_at_level(struct dm_verity *v, sector_t block,
 int verity_hash(struct dm_verity *v, struct dm_verity_io *io,
 		const u8 *data, size_t len, u8 *digest)
 {
-	struct shash_desc *desc = &io->hash_desc;
+	struct shash_desc *desc;
 	int r;
 
+	if (likely(v->use_sha256_lib)) {
+		struct sha256_ctx *ctx = &io->hash_ctx.sha256;
+
+		/*
+		 * Fast path using SHA-256 library.  This is enabled only for
+		 * verity version 1, where the salt is at the beginning.
+		 */
+		*ctx = *v->initial_hashstate.sha256;
+		sha256_update(ctx, data, len);
+		sha256_final(ctx, digest);
+		return 0;
+	}
+
+	desc = &io->hash_ctx.shash;
 	desc->tfm = v->shash_tfm;
-	if (unlikely(v->initial_hashstate == NULL)) {
+	if (unlikely(v->initial_hashstate.shash == NULL)) {
 		/* Version 0: salt at end */
 		r = crypto_shash_init(desc) ?:
 		    crypto_shash_update(desc, data, len) ?:
@@ -129,7 +143,7 @@ int verity_hash(struct dm_verity *v, struct dm_verity_io *io,
 		    crypto_shash_final(desc, digest);
 	} else {
 		/* Version 1: salt at beginning */
-		r = crypto_shash_import(desc, v->initial_hashstate) ?:
+		r = crypto_shash_import(desc, v->initial_hashstate.shash) ?:
 		    crypto_shash_finup(desc, data, len, digest);
 	}
 	if (unlikely(r))
@@ -215,12 +229,12 @@ static int verity_handle_err(struct dm_verity *v, enum verity_block_type type,
  * Verify hash of a metadata block pertaining to the specified data block
  * ("block" argument) at a specified level ("level" argument).
  *
- * On successful return, verity_io_want_digest(v, io) contains the hash value
- * for a lower tree level or for the data block (if we're at the lowest level).
+ * On successful return, want_digest contains the hash value for a lower tree
+ * level or for the data block (if we're at the lowest level).
  *
  * If "skip_unverified" is true, unverified buffer is skipped and 1 is returned.
  * If "skip_unverified" is false, unverified buffer is hashed and verified
- * against current value of verity_io_want_digest(v, io).
+ * against current value of want_digest.
  */
 static int verity_verify_level(struct dm_verity *v, struct dm_verity_io *io,
 			       sector_t block, int level, bool skip_unverified,
@@ -259,7 +273,7 @@ static int verity_verify_level(struct dm_verity *v, struct dm_verity_io *io,
 		if (IS_ERR(data))
 			return r;
 		if (verity_fec_decode(v, io, DM_VERITY_BLOCK_TYPE_METADATA,
-				      hash_block, data) == 0) {
+				      want_digest, hash_block, data) == 0) {
 			aux = dm_bufio_get_aux_data(buf);
 			aux->hash_verified = 1;
 			goto release_ok;
@@ -279,11 +293,11 @@ static int verity_verify_level(struct dm_verity *v, struct dm_verity_io *io,
 		}
 
 		r = verity_hash(v, io, data, 1 << v->hash_dev_block_bits,
-				verity_io_real_digest(v, io));
+				io->tmp_digest);
 		if (unlikely(r < 0))
 			goto release_ret_r;
 
-		if (likely(memcmp(verity_io_real_digest(v, io), want_digest,
+		if (likely(memcmp(io->tmp_digest, want_digest,
 				  v->digest_size) == 0))
 			aux->hash_verified = 1;
 		else if (static_branch_unlikely(&use_bh_wq_enabled) && io->in_bh) {
@@ -294,7 +308,7 @@ static int verity_verify_level(struct dm_verity *v, struct dm_verity_io *io,
 			r = -EAGAIN;
 			goto release_ret_r;
 		} else if (verity_fec_decode(v, io, DM_VERITY_BLOCK_TYPE_METADATA,
-					     hash_block, data) == 0)
+					     want_digest, hash_block, data) == 0)
 			aux->hash_verified = 1;
 		else if (verity_handle_err(v,
 					   DM_VERITY_BLOCK_TYPE_METADATA,
@@ -358,7 +372,8 @@ int verity_hash_for_block(struct dm_verity *v, struct dm_verity_io *io,
 }
 
 static noinline int verity_recheck(struct dm_verity *v, struct dm_verity_io *io,
-				   sector_t cur_block, u8 *dest)
+				   const u8 *want_digest, sector_t cur_block,
+				   u8 *dest)
 {
 	struct page *page;
 	void *buffer;
@@ -382,12 +397,11 @@ static noinline int verity_recheck(struct dm_verity *v, struct dm_verity_io *io,
 		goto free_ret;
 
 	r = verity_hash(v, io, buffer, 1 << v->data_dev_block_bits,
-			verity_io_real_digest(v, io));
+			io->tmp_digest);
 	if (unlikely(r))
 		goto free_ret;
 
-	if (memcmp(verity_io_real_digest(v, io),
-		   verity_io_want_digest(v, io), v->digest_size)) {
+	if (memcmp(io->tmp_digest, want_digest, v->digest_size)) {
 		r = -EIO;
 		goto free_ret;
 	}
@@ -402,9 +416,13 @@ static noinline int verity_recheck(struct dm_verity *v, struct dm_verity_io *io,
 
 static int verity_handle_data_hash_mismatch(struct dm_verity *v,
 					    struct dm_verity_io *io,
-					    struct bio *bio, sector_t blkno,
-					    u8 *data)
+					    struct bio *bio,
+					    struct pending_block *block)
 {
+	const u8 *want_digest = block->want_digest;
+	sector_t blkno = block->blkno;
+	u8 *data = block->data;
+
 	if (static_branch_unlikely(&use_bh_wq_enabled) && io->in_bh) {
 		/*
 		 * Error handling code (FEC included) cannot be run in the
@@ -412,14 +430,14 @@ static int verity_handle_data_hash_mismatch(struct dm_verity *v,
 		 */
 		return -EAGAIN;
 	}
-	if (verity_recheck(v, io, blkno, data) == 0) {
+	if (verity_recheck(v, io, want_digest, blkno, data) == 0) {
 		if (v->validated_blocks)
 			set_bit(blkno, v->validated_blocks);
 		return 0;
 	}
 #if defined(CONFIG_DM_VERITY_FEC)
-	if (verity_fec_decode(v, io, DM_VERITY_BLOCK_TYPE_DATA, blkno,
-			      data) == 0)
+	if (verity_fec_decode(v, io, DM_VERITY_BLOCK_TYPE_DATA, want_digest,
+			      blkno, data) == 0)
 		return 0;
 #endif
 	if (bio->bi_status)
@@ -433,6 +451,58 @@ static int verity_handle_data_hash_mismatch(struct dm_verity *v,
 	return 0;
 }
 
+static void verity_clear_pending_blocks(struct dm_verity_io *io)
+{
+	int i;
+
+	for (i = io->num_pending - 1; i >= 0; i--) {
+		kunmap_local(io->pending_blocks[i].data);
+		io->pending_blocks[i].data = NULL;
+	}
+	io->num_pending = 0;
+}
+
+static int verity_verify_pending_blocks(struct dm_verity *v,
+					struct dm_verity_io *io,
+					struct bio *bio)
+{
+	const unsigned int block_size = 1 << v->data_dev_block_bits;
+	int i, r;
+
+	if (io->num_pending == 2) {
+		/* num_pending == 2 implies that the algorithm is SHA-256 */
+		sha256_finup_2x(v->initial_hashstate.sha256,
+				io->pending_blocks[0].data,
+				io->pending_blocks[1].data, block_size,
+				io->pending_blocks[0].real_digest,
+				io->pending_blocks[1].real_digest);
+	} else {
+		for (i = 0; i < io->num_pending; i++) {
+			r = verity_hash(v, io, io->pending_blocks[i].data,
+					block_size,
+					io->pending_blocks[i].real_digest);
+			if (unlikely(r))
+				return r;
+		}
+	}
+
+	for (i = 0; i < io->num_pending; i++) {
+		struct pending_block *block = &io->pending_blocks[i];
+
+		if (likely(memcmp(block->real_digest, block->want_digest,
+				  v->digest_size) == 0)) {
+			if (v->validated_blocks)
+				set_bit(block->blkno, v->validated_blocks);
+		} else {
+			r = verity_handle_data_hash_mismatch(v, io, bio, block);
+			if (unlikely(r))
+				return r;
+		}
+	}
+	verity_clear_pending_blocks(io);
+	return 0;
+}
+
 /*
  * Verify one "dm_verity_io" structure.
  */
@@ -440,10 +510,14 @@ static int verity_verify_io(struct dm_verity_io *io)
 {
 	struct dm_verity *v = io->v;
 	const unsigned int block_size = 1 << v->data_dev_block_bits;
+	const int max_pending = v->use_sha256_finup_2x ? 2 : 1;
 	struct bvec_iter iter_copy;
 	struct bvec_iter *iter;
 	struct bio *bio = dm_bio_from_per_bio_data(io, v->ti->per_io_data_size);
 	unsigned int b;
+	int r;
+
+	io->num_pending = 0;
 
 	if (static_branch_unlikely(&use_bh_wq_enabled) && io->in_bh) {
 		/*
@@ -457,21 +531,22 @@ static int verity_verify_io(struct dm_verity_io *io)
 
 	for (b = 0; b < io->n_blocks;
 	     b++, bio_advance_iter(bio, iter, block_size)) {
-		int r;
-		sector_t cur_block = io->block + b;
+		sector_t blkno = io->block + b;
+		struct pending_block *block;
 		bool is_zero;
 		struct bio_vec bv;
 		void *data;
 
 		if (v->validated_blocks && bio->bi_status == BLK_STS_OK &&
-		    likely(test_bit(cur_block, v->validated_blocks)))
+		    likely(test_bit(blkno, v->validated_blocks)))
 			continue;
 
-		r = verity_hash_for_block(v, io, cur_block,
-					  verity_io_want_digest(v, io),
+		block = &io->pending_blocks[io->num_pending];
+
+		r = verity_hash_for_block(v, io, blkno, block->want_digest,
 					  &is_zero);
 		if (unlikely(r < 0))
-			return r;
+			goto error;
 
 		bv = bio_iter_iovec(bio, *iter);
 		if (unlikely(bv.bv_len < block_size)) {
@@ -482,7 +557,8 @@ static int verity_verify_io(struct dm_verity_io *io)
 			 * data block size to be greater than PAGE_SIZE.
 			 */
 			DMERR_LIMIT("unaligned io (data block spans pages)");
-			return -EIO;
+			r = -EIO;
+			goto error;
 		}
 
 		data = bvec_kmap_local(&bv);
@@ -496,29 +572,26 @@ static int verity_verify_io(struct dm_verity_io *io)
 			kunmap_local(data);
 			continue;
 		}
-
-		r = verity_hash(v, io, data, block_size,
-				verity_io_real_digest(v, io));
-		if (unlikely(r < 0)) {
-			kunmap_local(data);
-			return r;
+		block->data = data;
+		block->blkno = blkno;
+		if (++io->num_pending == max_pending) {
+			r = verity_verify_pending_blocks(v, io, bio);
+			if (unlikely(r))
+				goto error;
 		}
+	}
 
-		if (likely(memcmp(verity_io_real_digest(v, io),
-				  verity_io_want_digest(v, io), v->digest_size) == 0)) {
-			if (v->validated_blocks)
-				set_bit(cur_block, v->validated_blocks);
-			kunmap_local(data);
-			continue;
-		}
-		r = verity_handle_data_hash_mismatch(v, io, bio, cur_block,
-						     data);
-		kunmap_local(data);
+	if (io->num_pending) {
+		r = verity_verify_pending_blocks(v, io, bio);
 		if (unlikely(r))
-			return r;
+			goto error;
 	}
 
 	return 0;
+
+error:
+	verity_clear_pending_blocks(io);
+	return r;
 }
 
 /*
@@ -1004,7 +1077,7 @@ static void verity_dtr(struct dm_target *ti)
 
 	kvfree(v->validated_blocks);
 	kfree(v->salt);
-	kfree(v->initial_hashstate);
+	kfree(v->initial_hashstate.shash);
 	kfree(v->root_digest);
 	kfree(v->zero_digest);
 	verity_free_sig(v);
@@ -1069,8 +1142,7 @@ static int verity_alloc_zero_digest(struct dm_verity *v)
 	if (!v->zero_digest)
 		return r;
 
-	io = kmalloc(sizeof(*io) + crypto_shash_descsize(v->shash_tfm),
-		     GFP_KERNEL);
+	io = kmalloc(v->ti->per_io_data_size, GFP_KERNEL);
 
 	if (!io)
 		return r; /* verity_dtr will free zero_digest */
@@ -1252,11 +1324,26 @@ static int verity_setup_hash_alg(struct dm_verity *v, const char *alg_name)
 	}
 	v->shash_tfm = shash;
 	v->digest_size = crypto_shash_digestsize(shash);
-	DMINFO("%s using \"%s\"", alg_name, crypto_shash_driver_name(shash));
 	if ((1 << v->hash_dev_block_bits) < v->digest_size * 2) {
 		ti->error = "Digest size too big";
 		return -EINVAL;
 	}
+	if (likely(v->version && strcmp(alg_name, "sha256") == 0)) {
+		/*
+		 * Fast path: use the library API for reduced overhead and
+		 * interleaved hashing support.
+		 */
+		v->use_sha256_lib = true;
+		if (sha256_finup_2x_is_optimized())
+			v->use_sha256_finup_2x = true;
+		ti->per_io_data_size =
+			offsetofend(struct dm_verity_io, hash_ctx.sha256);
+	} else {
+		/* Fallback case: use the generic crypto API. */
+		ti->per_io_data_size =
+			offsetofend(struct dm_verity_io, hash_ctx.shash) +
+			crypto_shash_descsize(shash);
+	}
 	return 0;
 }
 
@@ -1277,7 +1364,18 @@ static int verity_setup_salt_and_hashstate(struct dm_verity *v, const char *arg)
 			return -EINVAL;
 		}
 	}
-	if (v->version) { /* Version 1: salt at beginning */
+	if (likely(v->use_sha256_lib)) {
+		/* Implies version 1: salt at beginning */
+		v->initial_hashstate.sha256 =
+			kmalloc(sizeof(struct sha256_ctx), GFP_KERNEL);
+		if (!v->initial_hashstate.sha256) {
+			ti->error = "Cannot allocate initial hash state";
+			return -ENOMEM;
+		}
+		sha256_init(v->initial_hashstate.sha256);
+		sha256_update(v->initial_hashstate.sha256,
+			      v->salt, v->salt_size);
+	} else if (v->version) { /* Version 1: salt at beginning */
 		SHASH_DESC_ON_STACK(desc, v->shash_tfm);
 		int r;
 
@@ -1285,16 +1383,16 @@ static int verity_setup_salt_and_hashstate(struct dm_verity *v, const char *arg)
 		 * Compute the pre-salted hash state that can be passed to
 		 * crypto_shash_import() for each block later.
 		 */
-		v->initial_hashstate = kmalloc(
+		v->initial_hashstate.shash = kmalloc(
 			crypto_shash_statesize(v->shash_tfm), GFP_KERNEL);
-		if (!v->initial_hashstate) {
+		if (!v->initial_hashstate.shash) {
 			ti->error = "Cannot allocate initial hash state";
 			return -ENOMEM;
 		}
 		desc->tfm = v->shash_tfm;
 		r = crypto_shash_init(desc) ?:
 		    crypto_shash_update(desc, v->salt, v->salt_size) ?:
-		    crypto_shash_export(desc, v->initial_hashstate);
+		    crypto_shash_export(desc, v->initial_hashstate.shash);
 		if (r) {
 			ti->error = "Cannot set up initial hash state";
 			return r;
@@ -1556,9 +1654,6 @@ static int verity_ctr(struct dm_target *ti, unsigned int argc, char **argv)
 		goto bad;
 	}
 
-	ti->per_io_data_size = sizeof(struct dm_verity_io) +
-			       crypto_shash_descsize(v->shash_tfm);
-
 	r = verity_fec_ctr(v);
 	if (r)
 		goto bad;
diff --git a/drivers/md/dm-verity.h b/drivers/md/dm-verity.h
index 6d141abd965c..f975a9e5c5d6 100644
--- a/drivers/md/dm-verity.h
+++ b/drivers/md/dm-verity.h
@@ -16,6 +16,7 @@
 #include <linux/device-mapper.h>
 #include <linux/interrupt.h>
 #include <crypto/hash.h>
+#include <crypto/sha2.h>
 
 #define DM_VERITY_MAX_LEVELS		63
 
@@ -42,7 +43,10 @@ struct dm_verity {
 	struct crypto_shash *shash_tfm;
 	u8 *root_digest;	/* digest of the root block */
 	u8 *salt;		/* salt: its size is salt_size */
-	u8 *initial_hashstate;	/* salted initial state, if version >= 1 */
+	union {
+		struct sha256_ctx *sha256;	/* for use_sha256_lib=1 */
+		u8 *shash;			/* for use_sha256_lib=0 */
+	} initial_hashstate; /* salted initial state, if version >= 1 */
 	u8 *zero_digest;	/* digest for a zero block */
 #ifdef CONFIG_SECURITY
 	u8 *root_digest_sig;	/* signature of the root digest */
@@ -59,6 +63,8 @@ struct dm_verity {
 	unsigned char version;
 	bool hash_failed:1;	/* set if hash of any block failed */
 	bool use_bh_wq:1;	/* try to verify in BH wq before normal work-queue */
+	bool use_sha256_lib:1;	/* use SHA-256 library instead of generic crypto API */
+	bool use_sha256_finup_2x:1; /* use interleaved hashing optimization */
 	unsigned int digest_size;	/* digest size for the current hash algorithm */
 	enum verity_mode mode;	/* mode for handling verification errors */
 	enum verity_mode error_mode;/* mode for handling I/O errors */
@@ -78,6 +84,13 @@ struct dm_verity {
 	mempool_t recheck_pool;
 };
 
+struct pending_block {
+	void *data;
+	sector_t blkno;
+	u8 want_digest[HASH_MAX_DIGESTSIZE];
+	u8 real_digest[HASH_MAX_DIGESTSIZE];
+};
+
 struct dm_verity_io {
 	struct dm_verity *v;
 
@@ -94,28 +107,29 @@ struct dm_verity_io {
 	struct work_struct work;
 	struct work_struct bh_work;
 
-	u8 real_digest[HASH_MAX_DIGESTSIZE];
-	u8 want_digest[HASH_MAX_DIGESTSIZE];
+	u8 tmp_digest[HASH_MAX_DIGESTSIZE];
 
 	/*
-	 * Temporary space for hashing.  This is variable-length and must be at
-	 * the end of the struct.  struct shash_desc is just the fixed part;
-	 * it's followed by a context of size crypto_shash_descsize(shash_tfm).
+	 * This is the queue of data blocks that are pending verification.  When
+	 * the crypto layer supports interleaved hashing, we allow multiple
+	 * blocks to be queued up in order to utilize it.  This can improve
+	 * performance significantly vs. sequential hashing of each block.
 	 */
-	struct shash_desc hash_desc;
-};
+	int num_pending;
+	struct pending_block pending_blocks[2];
 
-static inline u8 *verity_io_real_digest(struct dm_verity *v,
-					struct dm_verity_io *io)
-{
-	return io->real_digest;
-}
-
-static inline u8 *verity_io_want_digest(struct dm_verity *v,
-					struct dm_verity_io *io)
-{
-	return io->want_digest;
-}
+	/*
+	 * Temporary space for hashing.  Either sha256 or shash is used,
+	 * depending on the value of use_sha256_lib.  If shash is used,
+	 * then this field is variable-length, with total size
+	 * sizeof(struct shash_desc) + crypto_shash_descsize(shash_tfm).
+	 * For this reason, this field must be the end of the struct.
+	 */
+	union {
+		struct sha256_ctx sha256;
+		struct shash_desc shash;
+	} hash_ctx;
+};
 
 extern int verity_hash(struct dm_verity *v, struct dm_verity_io *io,
 		       const u8 *data, size_t len, u8 *digest);
diff --git a/include/linux/rhashtable.h b/include/linux/rhashtable.h
index 05a221ce79a6..b87a768b955c 100644
--- a/include/linux/rhashtable.h
+++ b/include/linux/rhashtable.h
@@ -355,12 +355,25 @@ static inline void rht_unlock(struct bucket_table *tbl,
 	local_irq_restore(flags);
 }
 
-static inline struct rhash_head *__rht_ptr(
-	struct rhash_lock_head *p, struct rhash_lock_head __rcu *const *bkt)
+enum rht_lookup_freq {
+	RHT_LOOKUP_NORMAL,
+	RHT_LOOKUP_LIKELY,
+};
+
+static __always_inline struct rhash_head *__rht_ptr(
+	struct rhash_lock_head *p, struct rhash_lock_head __rcu *const *bkt,
+	const enum rht_lookup_freq freq)
 {
-	return (struct rhash_head *)
-		((unsigned long)p & ~BIT(0) ?:
-		 (unsigned long)RHT_NULLS_MARKER(bkt));
+	unsigned long p_val = (unsigned long)p & ~BIT(0);
+
+	BUILD_BUG_ON(!__builtin_constant_p(freq));
+
+	if (freq == RHT_LOOKUP_LIKELY)
+		return (struct rhash_head *)
+			(likely(p_val) ? p_val : (unsigned long)RHT_NULLS_MARKER(bkt));
+	else
+		return (struct rhash_head *)
+			(p_val ?: (unsigned long)RHT_NULLS_MARKER(bkt));
 }
 
 /*
@@ -370,10 +383,17 @@ static inline struct rhash_head *__rht_ptr(
  *   rht_ptr_exclusive() dereferences in a context where exclusive
  *            access is guaranteed, such as when destroying the table.
  */
+static __always_inline struct rhash_head *__rht_ptr_rcu(
+	struct rhash_lock_head __rcu *const *bkt,
+	const enum rht_lookup_freq freq)
+{
+	return __rht_ptr(rcu_dereference(*bkt), bkt, freq);
+}
+
 static inline struct rhash_head *rht_ptr_rcu(
 	struct rhash_lock_head __rcu *const *bkt)
 {
-	return __rht_ptr(rcu_dereference_all(*bkt), bkt);
+	return __rht_ptr_rcu(bkt, RHT_LOOKUP_NORMAL);
 }
 
 static inline struct rhash_head *rht_ptr(
@@ -381,13 +401,15 @@ static inline struct rhash_head *rht_ptr(
 	struct bucket_table *tbl,
 	unsigned int hash)
 {
-	return __rht_ptr(rht_dereference_bucket(*bkt, tbl, hash), bkt);
+	return __rht_ptr(rht_dereference_bucket(*bkt, tbl, hash), bkt,
+			 RHT_LOOKUP_NORMAL);
 }
 
 static inline struct rhash_head *rht_ptr_exclusive(
 	struct rhash_lock_head __rcu *const *bkt)
 {
-	return __rht_ptr(rcu_dereference_protected(*bkt, 1), bkt);
+	return __rht_ptr(rcu_dereference_protected(*bkt, 1), bkt,
+			 RHT_LOOKUP_NORMAL);
 }
 
 static inline void rht_assign_locked(struct rhash_lock_head __rcu **bkt,
@@ -588,7 +610,8 @@ static inline int rhashtable_compare(struct rhashtable_compare_arg *arg,
 /* Internal function, do not use. */
 static __always_inline struct rhash_head *__rhashtable_lookup(
 	struct rhashtable *ht, const void *key,
-	const struct rhashtable_params params)
+	const struct rhashtable_params params,
+	const enum rht_lookup_freq freq)
 {
 	struct rhashtable_compare_arg arg = {
 		.ht = ht,
@@ -599,12 +622,13 @@ static __always_inline struct rhash_head *__rhashtable_lookup(
 	struct rhash_head *he;
 	unsigned int hash;
 
+	BUILD_BUG_ON(!__builtin_constant_p(freq));
 	tbl = rht_dereference_rcu(ht->tbl, ht);
 restart:
 	hash = rht_key_hashfn(ht, tbl, key, params);
 	bkt = rht_bucket(tbl, hash);
 	do {
-		rht_for_each_rcu_from(he, rht_ptr_rcu(bkt), tbl, hash) {
+		rht_for_each_rcu_from(he, __rht_ptr_rcu(bkt, freq), tbl, hash) {
 			if (params.obj_cmpfn ?
 			    params.obj_cmpfn(&arg, rht_obj(ht, he)) :
 			    rhashtable_compare(&arg, rht_obj(ht, he)))
@@ -643,11 +667,22 @@ static __always_inline void *rhashtable_lookup(
 	struct rhashtable *ht, const void *key,
 	const struct rhashtable_params params)
 {
-	struct rhash_head *he = __rhashtable_lookup(ht, key, params);
+	struct rhash_head *he = __rhashtable_lookup(ht, key, params,
+						    RHT_LOOKUP_NORMAL);
 
 	return he ? rht_obj(ht, he) : NULL;
 }
 
+static __always_inline void *rhashtable_lookup_likely(
+	struct rhashtable *ht, const void *key,
+	const struct rhashtable_params params)
+{
+	struct rhash_head *he = __rhashtable_lookup(ht, key, params,
+						    RHT_LOOKUP_LIKELY);
+
+	return likely(he) ? rht_obj(ht, he) : NULL;
+}
+
 /**
  * rhashtable_lookup_fast - search hash table, without RCU read lock
  * @ht:		hash table
@@ -693,11 +728,22 @@ static __always_inline struct rhlist_head *rhltable_lookup(
 	struct rhltable *hlt, const void *key,
 	const struct rhashtable_params params)
 {
-	struct rhash_head *he = __rhashtable_lookup(&hlt->ht, key, params);
+	struct rhash_head *he = __rhashtable_lookup(&hlt->ht, key, params,
+						    RHT_LOOKUP_NORMAL);
 
 	return he ? container_of(he, struct rhlist_head, rhead) : NULL;
 }
 
+static __always_inline struct rhlist_head *rhltable_lookup_likely(
+	struct rhltable *hlt, const void *key,
+	const struct rhashtable_params params)
+{
+	struct rhash_head *he = __rhashtable_lookup(&hlt->ht, key, params,
+						    RHT_LOOKUP_LIKELY);
+
+	return likely(he) ? container_of(he, struct rhlist_head, rhead) : NULL;
+}
+
 /* Internal function, please use rhashtable_insert_fast() instead. This
  * function returns the existing element already in hashes if there is a clash,
  * otherwise it returns an error via ERR_PTR().
-- 
2.52.0

