From e9a8ec55c2838418aae92daf75d7e6f4d317764c Mon Sep 17 00:00:00 2001
From: Christian Heusel <christian@heusel.eu>
Date: Mon, 27 Oct 2025 02:07:31 +0100
Subject: [PATCH] build: Add compatibility with LLVM 21 (#2030)

This adds LLVM 21 compatibility by adding the relevant code changes
guarded by `#ifdev`-statements for backwards compatibility.

Avoid deprecated LLVM calls for 21+

Additionally also add CI checks for this.

---------

Signed-off-by: Christian Heusel <christian@heusel.eu>
Signed-off-by: Larry Gritz <lg@larrygritz.com>
Co-authored-by: Larry Gritz <lg@larrygritz.com>
---
 .github/workflows/build-steps.yml |   3 +-
 .github/workflows/ci.yml          |   3 +-
 INSTALL.md                        |   2 +-
 src/cmake/externalpackages.cmake  |   2 +-
 src/liboslcomp/oslcomp.cpp        |  13 ++++
 src/liboslexec/llvm_instance.cpp  |   4 ++
 src/liboslexec/llvm_util.cpp      | 101 ++++++++++++++++++++++--------
 7 files changed, 96 insertions(+), 32 deletions(-)

diff --git a/src/cmake/externalpackages.cmake b/src/cmake/externalpackages.cmake
index 59a992d376..3906b8dec1 100644
--- a/src/cmake/externalpackages.cmake
+++ b/src/cmake/externalpackages.cmake
@@ -58,7 +58,7 @@ checked_find_package (pugixml REQUIRED
 # LLVM library setup
 checked_find_package (LLVM REQUIRED
                       VERSION_MIN 11.0
-                      VERSION_MAX 20.9
+                      VERSION_MAX 21.9
                       PRINT LLVM_SYSTEM_LIBRARIES CLANG_LIBRARIES
                             LLVM_SHARED_MODE)
 # ensure include directory is added (in case of non-standard locations
diff --git a/src/liboslcomp/oslcomp.cpp b/src/liboslcomp/oslcomp.cpp
index 14bee9a1fc..027f47d01a 100644
--- a/src/liboslcomp/oslcomp.cpp
+++ b/src/liboslcomp/oslcomp.cpp
@@ -171,19 +171,32 @@ OSLCompilerImpl::preprocess_buffer(const std::string& buffer,
     llvm::raw_string_ostream errstream(preproc_errors);
     clang::DiagnosticOptions* diagOptions = new clang::DiagnosticOptions();
     clang::TextDiagnosticPrinter* diagPrinter
+#if OSL_LLVM_VERSION < 210
         = new clang::TextDiagnosticPrinter(errstream, diagOptions);
+#else
+        = new clang::TextDiagnosticPrinter(errstream, *diagOptions);
+#endif
     llvm::IntrusiveRefCntPtr<clang::DiagnosticIDs> diagIDs(
         new clang::DiagnosticIDs);
     clang::DiagnosticsEngine* diagEngine
+#if OSL_LLVM_VERSION < 210
         = new clang::DiagnosticsEngine(diagIDs, diagOptions, diagPrinter);
+#else
+        = new clang::DiagnosticsEngine(diagIDs, *diagOptions, diagPrinter);
+#endif
     inst.setDiagnostics(diagEngine);
 
     const std::shared_ptr<clang::TargetOptions> targetopts
         = std::make_shared<clang::TargetOptions>(inst.getTargetOpts());
     targetopts->Triple = llvm::sys::getDefaultTargetTriple();
     clang::TargetInfo* target
+#if OSL_LLVM_VERSION < 210
         = clang::TargetInfo::CreateTargetInfo(inst.getDiagnostics(),
                                               targetopts);
+#else
+        = clang::TargetInfo::CreateTargetInfo(inst.getDiagnostics(),
+                                              *targetopts);
+#endif
 
     inst.setTarget(target);
 
diff --git a/src/liboslexec/llvm_instance.cpp b/src/liboslexec/llvm_instance.cpp
index 974f95b173..b5526f9fa1 100644
--- a/src/liboslexec/llvm_instance.cpp
+++ b/src/liboslexec/llvm_instance.cpp
@@ -2225,7 +2225,11 @@ BackendLLVM::run()
             // The triple is empty with recent versions of LLVM (e.g., 15) for reasons that aren't
             // clear. So we must set them to the expected values.
             // See: https://llvm.org/docs/NVPTXUsage.html
+#    if OSL_LLVM_VERSION < 210
             ll.module()->setTargetTriple("nvptx64-nvidia-cuda");
+#    else
+            ll.module()->setTargetTriple(llvm::Triple("nvptx64-nvidia-cuda"));
+#    endif
             ll.module()->setDataLayout(
                 "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64");
 
diff --git a/src/liboslexec/llvm_util.cpp b/src/liboslexec/llvm_util.cpp
index 2d95256759..4cda958ea9 100644
--- a/src/liboslexec/llvm_util.cpp
+++ b/src/liboslexec/llvm_util.cpp
@@ -476,27 +476,13 @@ LLVM_Util::LLVM_Util(const PerThreadInfo& per_thread_info, int debuglevel,
     m_llvm_type_longlong = (llvm::Type*)llvm::Type::getInt64Ty(*m_llvm_context);
     m_llvm_type_void     = (llvm::Type*)llvm::Type::getVoidTy(*m_llvm_context);
 
-    m_llvm_type_int_ptr      = llvm::PointerType::get(m_llvm_type_int, 0);
-    m_llvm_type_int8_ptr     = llvm::PointerType::get(m_llvm_type_int8, 0);
-    m_llvm_type_int64_ptr    = llvm::PointerType::get(m_llvm_type_int64, 0);
-    m_llvm_type_bool_ptr     = llvm::PointerType::get(m_llvm_type_bool, 0);
-    m_llvm_type_char_ptr     = llvm::PointerType::get(m_llvm_type_char, 0);
-    m_llvm_type_void_ptr     = m_llvm_type_char_ptr;
-    m_llvm_type_float_ptr    = llvm::PointerType::get(m_llvm_type_float, 0);
-    m_llvm_type_longlong_ptr = llvm::PointerType::get(m_llvm_type_int64, 0);
-    m_llvm_type_double_ptr   = llvm::PointerType::get(m_llvm_type_double, 0);
-
     // A triple is a struct composed of 3 floats
     std::vector<llvm::Type*> triplefields(3, m_llvm_type_float);
     m_llvm_type_triple = type_struct(triplefields, "Vec3");
-    m_llvm_type_triple_ptr
-        = (llvm::PointerType*)llvm::PointerType::get(m_llvm_type_triple, 0);
 
     // A matrix is a struct composed 16 floats
     std::vector<llvm::Type*> matrixfields(16, m_llvm_type_float);
     m_llvm_type_matrix = type_struct(matrixfields, "Matrix4");
-    m_llvm_type_matrix_ptr
-        = (llvm::PointerType*)llvm::PointerType::get(m_llvm_type_matrix, 0);
 
     // Setup up wide aliases
     // TODO:  why are there casts to the base class llvm::Type *?
@@ -511,6 +497,48 @@ LLVM_Util::LLVM_Util(const PerThreadInfo& per_thread_info, int debuglevel,
     m_llvm_type_wide_longlong = llvm_vector_type(m_llvm_type_longlong,
                                                  m_vector_width);
 
+    // A twide riple is a struct composed of 3 wide floats
+    std::vector<llvm::Type*> triple_wide_fields(3, m_llvm_type_wide_float);
+    m_llvm_type_wide_triple = type_struct(triple_wide_fields, "WideVec3");
+
+    // A wide matrix is a struct composed 16 wide floats
+    std::vector<llvm::Type*> matrix_wide_fields(16, m_llvm_type_wide_float);
+    m_llvm_type_wide_matrix = type_struct(matrix_wide_fields, "WideMatrix4");
+
+#if OSL_LLVM_VERSION >= 210
+    // All opaque pointers now. Eventually, all the typed ones can go away.
+    m_llvm_type_void_ptr       = llvm::PointerType::get(*m_llvm_context, 0);
+    m_llvm_type_int_ptr        = m_llvm_type_void_ptr;
+    m_llvm_type_int8_ptr       = m_llvm_type_void_ptr;
+    m_llvm_type_int64_ptr      = m_llvm_type_void_ptr;
+    m_llvm_type_bool_ptr       = m_llvm_type_void_ptr;
+    m_llvm_type_char_ptr       = m_llvm_type_void_ptr;
+    m_llvm_type_float_ptr      = m_llvm_type_void_ptr;
+    m_llvm_type_longlong_ptr   = m_llvm_type_void_ptr;
+    m_llvm_type_double_ptr     = m_llvm_type_void_ptr;
+    m_llvm_type_triple_ptr     = m_llvm_type_void_ptr;
+    m_llvm_type_matrix_ptr     = m_llvm_type_void_ptr;
+    m_llvm_type_wide_char_ptr  = m_llvm_type_void_ptr;
+    m_llvm_type_wide_void_ptr  = m_llvm_type_void_ptr;
+    m_llvm_type_wide_int_ptr   = m_llvm_type_void_ptr;
+    m_llvm_type_wide_bool_ptr  = m_llvm_type_void_ptr;
+    m_llvm_type_wide_float_ptr = m_llvm_type_void_ptr;
+#else
+    // Old style typed pointers. These are marked as deprecated in LLVM 21,
+    // and will be removed in some subsequent version.
+    m_llvm_type_int_ptr      = llvm::PointerType::get(m_llvm_type_int, 0);
+    m_llvm_type_int8_ptr     = llvm::PointerType::get(m_llvm_type_int8, 0);
+    m_llvm_type_int64_ptr    = llvm::PointerType::get(m_llvm_type_int64, 0);
+    m_llvm_type_bool_ptr     = llvm::PointerType::get(m_llvm_type_bool, 0);
+    m_llvm_type_char_ptr     = llvm::PointerType::get(m_llvm_type_char, 0);
+    m_llvm_type_void_ptr     = m_llvm_type_char_ptr;
+    m_llvm_type_float_ptr    = llvm::PointerType::get(m_llvm_type_float, 0);
+    m_llvm_type_longlong_ptr = llvm::PointerType::get(m_llvm_type_int64, 0);
+    m_llvm_type_double_ptr   = llvm::PointerType::get(m_llvm_type_double, 0);
+    m_llvm_type_triple_ptr
+        = (llvm::PointerType*)llvm::PointerType::get(m_llvm_type_triple, 0);
+    m_llvm_type_matrix_ptr
+        = (llvm::PointerType*)llvm::PointerType::get(m_llvm_type_matrix, 0);
     m_llvm_type_wide_char_ptr = llvm::PointerType::get(m_llvm_type_wide_char,
                                                        0);
     m_llvm_type_wide_void_ptr = llvm_vector_type(m_llvm_type_void_ptr,
@@ -520,14 +548,7 @@ LLVM_Util::LLVM_Util(const PerThreadInfo& per_thread_info, int debuglevel,
                                                        0);
     m_llvm_type_wide_float_ptr = llvm::PointerType::get(m_llvm_type_wide_float,
                                                         0);
-
-    // A triple is a struct composed of 3 floats
-    std::vector<llvm::Type*> triple_wide_fields(3, m_llvm_type_wide_float);
-    m_llvm_type_wide_triple = type_struct(triple_wide_fields, "WideVec3");
-
-    // A matrix is a struct composed 16 floats
-    std::vector<llvm::Type*> matrix_wide_fields(16, m_llvm_type_wide_float);
-    m_llvm_type_wide_matrix = type_struct(matrix_wide_fields, "WideMatrix4");
+#endif
 
     ustring_rep(m_ustring_rep);  // setup ustring-related types
 }
@@ -545,14 +566,20 @@ LLVM_Util::ustring_rep(UstringRep rep)
         OSL_ASSERT(m_ustring_rep == UstringRep::hash);
         m_llvm_type_ustring = llvm::Type::getInt64Ty(*m_llvm_context);
     }
-    m_llvm_type_ustring_ptr = llvm::PointerType::get(m_llvm_type_ustring, 0);
 
     // Batched versions haven't been updated to handle hash yet.
     // For now leave them using the real ustring regardless of UstringRep
     m_llvm_type_wide_ustring = llvm_vector_type(m_llvm_type_real_ustring,
                                                 m_vector_width);
+
+#if OSL_LLVM_VERSION >= 210
+    m_llvm_type_ustring_ptr      = m_llvm_type_void_ptr;
+    m_llvm_type_wide_ustring_ptr = m_llvm_type_void_ptr;
+#else
+    m_llvm_type_ustring_ptr = llvm::PointerType::get(m_llvm_type_ustring, 0);
     m_llvm_type_wide_ustring_ptr
         = llvm::PointerType::get(m_llvm_type_wide_ustring, 0);
+#endif
 }
 
 
@@ -1790,8 +1817,13 @@ LLVM_Util::nvptx_target_machine()
                    && "PTX compile error: LLVM Target is not initialized");
 
         m_nvptx_target_machine = llvm_target->createTargetMachine(
-            ModuleTriple.str(), CUDA_TARGET_ARCH, "+ptx50", options,
-            llvm::Reloc::Static, llvm::CodeModel::Small,
+#if OSL_LLVM_VERSION >= 210
+            llvm::Triple(ModuleTriple.str()),
+#else
+            ModuleTriple.str(),
+#endif
+            CUDA_TARGET_ARCH, "+ptx50", options, llvm::Reloc::Static,
+            llvm::CodeModel::Small,
 #if OSL_LLVM_VERSION >= 180
             llvm::CodeGenOptLevel::Default
 #else
@@ -2911,7 +2943,11 @@ LLVM_Util::type_struct_field_at_index(llvm::Type* type, int index)
 llvm::PointerType*
 LLVM_Util::type_ptr(llvm::Type* type)
 {
+#if OSL_LLVM_VERSION >= 210
+    return m_llvm_type_void_ptr;
+#else
     return llvm::PointerType::get(type, 0);
+#endif
 }
 
 llvm::Type*
@@ -2959,8 +2995,12 @@ llvm::PointerType*
 LLVM_Util::type_function_ptr(llvm::Type* rettype, cspan<llvm::Type*> params,
                              bool varargs)
 {
+#if OSL_LLVM_VERSION >= 210
+    return m_llvm_type_void_ptr;
+#else
     llvm::FunctionType* functype = type_function(rettype, params, varargs);
     return llvm::PointerType::getUnqual(functype);
+#endif
 }
 
 
@@ -3784,8 +3824,7 @@ llvm::Value*
 LLVM_Util::ptr_to_cast(llvm::Value* val, llvm::Type* type,
                        const std::string& llname)
 {
-    return builder().CreatePointerCast(val, llvm::PointerType::get(type, 0),
-                                       llname);
+    return builder().CreatePointerCast(val, type_ptr(type), llname);
 }
 
 
@@ -3803,14 +3842,22 @@ llvm::Value*
 LLVM_Util::ptr_cast(llvm::Value* val, const TypeDesc& type,
                     const std::string& llname)
 {
+#if OSL_LLVM_VERSION >= 210
+    return ptr_cast(val, m_llvm_type_void_ptr, llname);
+#else
     return ptr_cast(val, llvm::PointerType::get(llvm_type(type), 0), llname);
+#endif
 }
 
 
 llvm::Value*
 LLVM_Util::wide_ptr_cast(llvm::Value* val, const TypeDesc& type)
 {
+#if OSL_LLVM_VERSION >= 210
+    return ptr_cast(val, m_llvm_type_void_ptr);
+#else
     return ptr_cast(val, llvm::PointerType::get(llvm_vector_type(type), 0));
+#endif
 }
 
 
