# Copyright 2026 Gentoo Authors
# Distributed under the terms of the GNU General Public License v2

DESCRIPTION="Manage PyTorch Python module backend (ROCm/CUDA)"
VERSION="0.1"

# Each backend is installed under ${PYTORCHDIR}/<target>/ and is activated
# by symlinking its Python packages into ${SITEDIR} and its torchrun into
# ${EROOT}/usr/bin.
PYTORCHDIR="${EROOT}/usr/lib/pytorch"
# NOTE(review): hard-coded to python3.14; this must match the Python the
# backends are built for -- confirm whenever the Python target changes.
SITEDIR="${EROOT}/usr/lib/python3.14/site-packages"

# Print the name of every installed backend, one per line, in glob (sorted)
# order. This order defines the numbers shown by 'list' and accepted by
# 'set'.
find_targets() {
	local d
	for d in "${PYTORCHDIR}"/*/; do
		# An unmatched glob leaves the literal pattern behind; skip it.
		[[ -d "${d}" ]] || continue
		d=${d%/}
		printf '%s\n' "${d##*/}"
	done
}

# Print the active backend name, taken from the ${SITEDIR}/torch symlink
# (expected form: ../../pytorch/<target>/torch). Prints nothing when that
# path is not a symlink.
get_active() {
	[[ -L "${SITEDIR}/torch" ]] || return 0
	local link
	link=$(readlink "${SITEDIR}/torch") || return 0
	link=${link#../../pytorch/}
	printf '%s\n' "${link%%/*}"
}

# Remove the symlinks created by set_symlinks.
# Every managed path is checked *before* anything is removed: if one exists
# but is not a symlink (e.g. a pip-installed torch directory or a foreign
# torchrun), abort instead of deleting it -- and instead of letting a later
# 'ln -s' create a link nested inside a leftover directory.
remove_symlinks() {
	local -a links=(
		"${SITEDIR}/torch"
		"${SITEDIR}/functorch"
		"${SITEDIR}/torchgen"
		"${EROOT}/usr/bin/torchrun"
	)
	local f
	for f in "${links[@]}"; do
		if [[ -e "${f}" && ! -L "${f}" ]]; then
			die -q "${f} exists and is not a symlink, refusing to remove it"
		fi
	done
	for f in "${links[@]}"; do
		if [[ -L "${f}" ]]; then
			rm -f "${f}" || die -q "Could not remove ${f}"
		fi
	done
}

# Create the symlinks for backend $1. The torch package is required;
# functorch, torchgen and bin/torchrun are linked only if the backend ships
# them. Dies if any link cannot be created.
set_symlinks() {
	local target="${1}"
	local prefix="${PYTORCHDIR}/${target}"
	# Relative to ${SITEDIR} (usr/lib/python3.14/site-packages).
	local rel="../../pytorch/${target}"
	local mod

	if [[ ! -d "${prefix}/torch" ]]; then
		die -q "Target '${target}' not found: ${prefix}/torch"
	fi

	ln -s "${rel}/torch" "${SITEDIR}/torch" \
		|| die -q "Could not create symlink ${SITEDIR}/torch"

	for mod in functorch torchgen; do
		if [[ -d "${prefix}/${mod}" ]]; then
			ln -s "${rel}/${mod}" "${SITEDIR}/${mod}" \
				|| die -q "Could not create symlink ${SITEDIR}/${mod}"
		fi
	done

	if [[ -x "${prefix}/bin/torchrun" ]]; then
		# Relative to ${EROOT}/usr/bin.
		ln -s "../lib/pytorch/${target}/bin/torchrun" "${EROOT}/usr/bin/torchrun" \
			|| die -q "Could not create symlink ${EROOT}/usr/bin/torchrun"
	fi

	# Explicit success: a skipped optional link must not leak a false
	# status into do_set (and from there into the exit code).
	return 0
}

### list action ###

describe_list() {
	echo "List available PyTorch backends"
}

# List installed backends, numbered from 1, marking the active one with '*'.
do_list() {
	local active i
	local -a targets
	active=$(get_active)
	mapfile -t targets < <(find_targets)

	write_list_start "Available PyTorch backends:"
	if (( ${#targets[@]} == 0 )); then
		write_kv_list_entry "(none found)" ""
		return 0
	fi
	for (( i = 0; i < ${#targets[@]}; i++ )); do
		if [[ "${targets[i]}" == "${active}" ]]; then
			write_kv_list_entry "[$(( i + 1 ))]" "${targets[i]} *"
		else
			write_kv_list_entry "[$(( i + 1 ))]" "${targets[i]}"
		fi
	done
}

### show action ###

describe_show() {
	echo "Show currently active PyTorch backend"
}

# Print the active backend, or "(unset)" when none is selected.
do_show() {
	local active
	active=$(get_active)
	echo "${active:-"(unset)"}"
}

### set action ###

describe_set() {
	echo "Set active PyTorch backend"
}

describe_set_parameters() {
	echo "<target>"
}

describe_set_options() {
	echo "target : Backend name or number (from 'list' action)"
}

# Activate backend $1, given either by name or by its 1-based number from
# 'list'. The target is fully validated before the current symlinks are
# removed, so an invalid request leaves the active backend untouched.
do_set() {
	[[ -z "${1}" ]] && die -q "You must specify a backend (rocm, cuda)"
	local target="${1}"

	if is_number "${target}"; then
		local -a targets
		local idx
		mapfile -t targets < <(find_targets)
		# Force base 10 so e.g. "08" is not rejected as a bad octal number.
		idx=$(( 10#${target} ))
		# Explicit range check: index 0 would become targets[-1], i.e. the
		# last backend, in bash.
		if (( idx < 1 || idx > ${#targets[@]} )); then
			die -q "Invalid selection"
		fi
		target=${targets[idx - 1]}
	fi

	# Only plain directory names under ${PYTORCHDIR} are valid targets.
	if [[ "${target}" == */* || "${target}" == "." || "${target}" == ".." ]]; then
		die -q "Invalid backend name: ${target}"
	fi
	if [[ ! -d "${PYTORCHDIR}/${target}/torch" ]]; then
		die -q "Target '${target}' not found: ${PYTORCHDIR}/${target}/torch"
	fi

	remove_symlinks
	set_symlinks "${target}"
}