Skip _vllm_fa3_C target body when FA3_ARCHS is empty.

vllm-flash-attn intersects "9.0a;" (Hopper) with CUDA_ARCHS to compute
FA3_ARCHS, then clears the per-file gencode flags when FA3_ARCHS is
empty.  But the define_gpu_extension_target(_vllm_fa3_C ...) call adds
all FA3 .cu files to the build target unconditionally, so nvcc compiles
them anyway with its default arch — wasting 30-60 minutes of build time
on Ampere/Ada/older systems for kernels that can't run on the GPU.

Wrap the real target-definition block in `if(FA3_ARCHS)` and define an
empty `add_custom_target(_vllm_fa3_C)` stub in the else-branch — the same
pattern as the "DeepGEMM will not compile" path in
cmake/external_projects/deepgemm.cmake.  vllm's setup.py always builds
with an explicit `--target=_vllm_fa3_C` (passed to `cmake --build`, which
drives ninja) regardless of arch, so the target must exist (even as a
no-op) or the build aborts.

Tested against vllm-project/flash-attention @ f5bc33c (pin from vllm-0.21.0).

--- CMakeLists.txt.orig	2026-05-17 00:34:19.601587314 +0200
+++ CMakeLists.txt	2026-05-17 01:05:47.579431319 +0200
@@ -267,45 +267,52 @@
             CUDA_ARCHS "${FA3_ARCHS}")
     endif()
 
-    define_gpu_extension_target(
-        _vllm_fa3_C
-        DESTINATION vllm_flash_attn
-        LANGUAGE ${VLLM_GPU_LANG}
-        SOURCES
-            hopper/flash_fwd_combine.cu
-            hopper/flash_prepare_scheduler.cu
-            hopper/flash_api.cpp
-            hopper/flash_api_torch_lib.cpp
-            ${FA3_GEN_SRCS}
-        COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
-        ARCHITECTURES "" # LucasW: this is ignored for cuda and set on a per-file basis
-        USE_SABI 3
-        WITH_SOABI)
+    if(FA3_ARCHS)
+        define_gpu_extension_target(
+            _vllm_fa3_C
+            DESTINATION vllm_flash_attn
+            LANGUAGE ${VLLM_GPU_LANG}
+            SOURCES
+                hopper/flash_fwd_combine.cu
+                hopper/flash_prepare_scheduler.cu
+                hopper/flash_api.cpp
+                hopper/flash_api_torch_lib.cpp
+                ${FA3_GEN_SRCS}
+            COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
+            ARCHITECTURES "" # LucasW: this is ignored for cuda and set on a per-file basis
+            USE_SABI 3
+            WITH_SOABI)
 
-    target_include_directories(_vllm_fa3_C PRIVATE
-        hopper
-        csrc/common
-        csrc/cutlass/include)
+        target_include_directories(_vllm_fa3_C PRIVATE
+            hopper
+            csrc/common
+            csrc/cutlass/include)
 
-    # custom definitions
-    target_compile_definitions(_vllm_fa3_C PRIVATE
-        FLASHATTENTION_DISABLE_BACKWARD
-        FLASHATTENTION_DISABLE_DROPOUT
-        # FLASHATTENTION_DISABLE_ALIBI
-        # FLASHATTENTION_DISABLE_SOFTCAP
-        FLASHATTENTION_DISABLE_UNEVEN_K
-        # FLASHATTENTION_DISABLE_LOCAL
-        FLASHATTENTION_DISABLE_PYBIND
-        FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
-        FLASHATTENTION_PACKGQA_ONLY # Custom flag to save on binary size
-        FLASHATTENTION_DISABLE_CLUSTER # disabled for varlen in any case
-        FLASHATTENTION_DISABLE_SM8x
-        # FLASHATTENTION_DISABLE_HDIMDIFF64
-        # FLASHATTENTION_DISABLE_HDIMDIFF192
-        FLASHATTENTION_DISABLE_APPENDKV
-        CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED
-        CUTLASS_ENABLE_GDC_FOR_SM90
-    )
+        # custom definitions
+        target_compile_definitions(_vllm_fa3_C PRIVATE
+            FLASHATTENTION_DISABLE_BACKWARD
+            FLASHATTENTION_DISABLE_DROPOUT
+            # FLASHATTENTION_DISABLE_ALIBI
+            # FLASHATTENTION_DISABLE_SOFTCAP
+            FLASHATTENTION_DISABLE_UNEVEN_K
+            # FLASHATTENTION_DISABLE_LOCAL
+            FLASHATTENTION_DISABLE_PYBIND
+            FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
+            FLASHATTENTION_PACKGQA_ONLY # Custom flag to save on binary size
+            FLASHATTENTION_DISABLE_CLUSTER # disabled for varlen in any case
+            FLASHATTENTION_DISABLE_SM8x
+            # FLASHATTENTION_DISABLE_HDIMDIFF64
+            # FLASHATTENTION_DISABLE_HDIMDIFF192
+            FLASHATTENTION_DISABLE_APPENDKV
+            CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED
+            CUTLASS_ENABLE_GDC_FOR_SM90
+        )
+    else()
+        message(STATUS "FA3 will not compile: no Hopper arch in CUDA_ARCHS=${CUDA_ARCHS}")
+        # Empty stub so vllm setup.py's `cmake --build --target=_vllm_fa3_C` succeeds.
+        # Same pattern as cmake/external_projects/deepgemm.cmake.
+        add_custom_target(_vllm_fa3_C)
+    endif()
 elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
     message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
 endif ()
