From d504eee0e80385cecce09b190a035fdcee7ef043 Mon Sep 17 00:00:00 2001
From: Peter Jung <admin@ptr1337.dev>
Date: Wed, 8 Apr 2026 19:07:32 +0200
Subject: [PATCH] cgroup-vram

Signed-off-by: Peter Jung <admin@ptr1337.dev>
---
 drivers/gpu/drm/ttm/ttm_bo.c       | 221 +++++++++++++++++++++++++----
 drivers/gpu/drm/ttm/ttm_resource.c |  48 +++++--
 include/drm/ttm/ttm_resource.h     |   6 +-
 include/linux/cgroup.h             |  21 +++
 include/linux/cgroup_dmem.h        |  25 ++++
 kernel/cgroup/dmem.c               |  91 ++++++++++++
 6 files changed, 369 insertions(+), 43 deletions(-)

diff --git a/drivers/gpu/drm/ttm/ttm_bo.c b/drivers/gpu/drm/ttm/ttm_bo.c
index 0765d69423d2..73b19c5d06f7 100644
--- a/drivers/gpu/drm/ttm/ttm_bo.c
+++ b/drivers/gpu/drm/ttm/ttm_bo.c
@@ -489,6 +489,117 @@ int ttm_bo_evict_first(struct ttm_device *bdev, struct ttm_resource_manager *man
 	return ret;
 }
 
+struct ttm_bo_alloc_state {
+	/** @charge_pool: The memory pool the resource is charged to */
+	struct dmem_cgroup_pool_state *charge_pool;
+	/** @limit_pool: Which pool limit we should test against */
+	struct dmem_cgroup_pool_state *limit_pool;
+	/** @in_evict: Whether we are currently evicting buffers */
+	bool in_evict;
+	/** @may_try_low: Whether BOs protected by their cgroup's dmem 'low'
+	 *  limit (but not 'min') may also be considered for eviction, in addition to unprotected BOs
+	 */
+	bool may_try_low;
+};
+
+/**
+ * ttm_bo_alloc_at_place - Attempt allocating a BO's backing store in a place
+ *
+ * @bo: The buffer to allocate the backing store of
+ * @place: The place to attempt allocation in
+ * @ctx: ttm_operation_ctx associated with this allocation
+ * @force_space: If we should evict buffers to force space
+ * @res: On allocation success, the resulting struct ttm_resource.
+ * @alloc_state: Object holding allocation state such as charged cgroups.
+ *
+ * Returns:
+ * 0: Allocation succeeded and *@res holds the new resource.
+ * -EBUSY: No space available, but allocation should be retried with ttm_bo_evict_alloc.
+ * -ENOSPC: No space available, allocation should not be retried.
+ * -ERESTARTSYS: An interruptible sleep was interrupted by a signal.
+ */
+static int ttm_bo_alloc_at_place(struct ttm_buffer_object *bo,
+				 const struct ttm_place *place,
+				 struct ttm_operation_ctx *ctx,
+				 bool force_space,
+				 struct ttm_resource **res,
+				 struct ttm_bo_alloc_state *alloc_state)
+{
+	bool may_evict;
+	int ret;
+
+	may_evict = !alloc_state->in_evict && force_space &&
+		    place->mem_type != TTM_PL_SYSTEM;
+	if (!alloc_state->charge_pool) {
+		ret = ttm_resource_try_charge(bo, place, &alloc_state->charge_pool,
+					      force_space ? &alloc_state->limit_pool
+							  : NULL);
+		if (ret) {
+			/*
+			 * -EAGAIN means the charge failed, which we treat
+			 * like an allocation failure. Therefore, return an
+			 * error code indicating the allocation failed -
+			 * either -EBUSY if the allocation should be
+			 * retried with eviction, or -ENOSPC if there should
+			 * be no second attempt.
+			 */
+			if (ret == -EAGAIN)
+				ret = may_evict ? -EBUSY : -ENOSPC;
+			return ret;
+		}
+	}
+
+	/*
+	 * cgroup protection plays a special role in eviction.
+	 * Conceptually, protection of memory via the dmem cgroup controller
+	 * entitles the protected cgroup to use a certain amount of memory.
+	 * There are two types of protection - the 'low' limit is a
+	 * "best-effort" protection, whereas the 'min' limit provides a hard
+	 * guarantee that memory within the cgroup's allowance will not be
+	 * evicted under any circumstance.
+	 *
+	 * To faithfully model this concept in TTM, we also need to take cgroup
+	 * protection into account when allocating. When allocation in one
+	 * place fails, TTM will default to trying other places first before
+	 * evicting.
+	 * If the allocation is covered by dmem cgroup protection, however,
+	 * this prevents the allocation from using the memory it is "entitled"
+	 * to. To make sure unprotected allocations cannot push new protected
+	 * allocations out of places they are "entitled" to use, we should
+	 * evict buffers not covered by any cgroup protection, if this
+	 * allocation is covered by cgroup protection.
+	 *
+	 * Buffers covered by 'min' protection are a special case - the 'min'
+	 * limit is a stronger guarantee than 'low', and thus buffers protected
+	 * by 'low' but not 'min' should also be considered for eviction.
+	 * Buffers protected by 'min' will never be considered for eviction
+	 * anyway, so the regular eviction path should be triggered here.
+	 * Buffers protected by 'low' but not 'min' will take a special
+	 * eviction path that only evicts buffers covered by neither 'low' nor
+	 * 'min' protections.
+	 */
+	if (!alloc_state->in_evict) {
+		may_evict |= dmem_cgroup_below_min(NULL, alloc_state->charge_pool);
+		alloc_state->may_try_low = may_evict;
+
+		may_evict |= dmem_cgroup_below_low(NULL, alloc_state->charge_pool);
+	}
+
+	ret = ttm_resource_alloc(bo, place, res, alloc_state->charge_pool);
+	if (ret) {
+		if (ret == -ENOSPC && may_evict)
+			ret = -EBUSY;
+		return ret;
+	}
+
+	/*
+	 * Ownership of charge_pool has been transferred to the TTM resource,
+	 * don't make the caller think we still hold a reference to it.
+	 */
+	alloc_state->charge_pool = NULL;
+	return 0;
+}
+
 /**
  * struct ttm_bo_evict_walk - Parameters for the evict walk.
  */
@@ -504,22 +615,61 @@ struct ttm_bo_evict_walk {
 	/** @evicted: Number of successful evictions. */
 	unsigned long evicted;
 
-	/** @limit_pool: Which pool limit we should test against */
-	struct dmem_cgroup_pool_state *limit_pool;
 	/** @try_low: Whether we should attempt to evict BO's with low watermark threshold */
 	bool try_low;
 	/** @hit_low: If we cannot evict a bo when @try_low is false (first pass) */
 	bool hit_low;
+
+	/** @alloc_state: State associated with the allocation attempt. */
+	struct ttm_bo_alloc_state *alloc_state;
 };
 
 static s64 ttm_bo_evict_cb(struct ttm_lru_walk *walk, struct ttm_buffer_object *bo)
 {
 	struct ttm_bo_evict_walk *evict_walk =
 		container_of(walk, typeof(*evict_walk), walk);
+	struct dmem_cgroup_pool_state *limit_pool, *ancestor = NULL;
+	bool evict_valuable;
 	s64 lret;
 
-	if (!dmem_cgroup_state_evict_valuable(evict_walk->limit_pool, bo->resource->css,
-					      evict_walk->try_low, &evict_walk->hit_low))
+	/*
+	 * If may_try_low is not set, then we're trying to evict unprotected
+	 * buffers in favor of a protected allocation for charge_pool. Explicitly skip
+	 * buffers belonging to the same cgroup here - that cgroup is definitely protected,
+	 * even though dmem_cgroup_state_evict_valuable would allow the eviction because a
+	 * cgroup is always allowed to evict from itself even if it is protected.
+	 */
+	if (!evict_walk->alloc_state->may_try_low &&
+			bo->resource->css == evict_walk->alloc_state->charge_pool)
+		return 0;
+
+	limit_pool = evict_walk->alloc_state->limit_pool;
+	/*
+	 * If there is no explicit limit pool, find the root of the shared subtree between
+	 * evictor and evictee. This is important so that recursive protection rules can
+	 * apply properly: Recursive protection distributes cgroup protection afforded
+	 * to a parent cgroup but not used explicitly by a child cgroup between all child
+	 * cgroups (see docs of effective_protection in mm/page_counter.c). However, when
+	 * direct siblings compete for memory, siblings that were explicitly protected
+	 * should get prioritized over siblings that weren't. This only happens correctly
+	 * when the root of the shared subtree is passed to
+	 * dmem_cgroup_state_evict_valuable. Otherwise, the effective-protection
+	 * calculation cannot distinguish direct siblings from unrelated subtrees and the
+	 * calculated protection ends up wrong.
+	 */
+	if (!limit_pool) {
+		ancestor = dmem_cgroup_get_common_ancestor(bo->resource->css,
+							   evict_walk->alloc_state->charge_pool);
+		limit_pool = ancestor;
+	}
+
+	evict_valuable = dmem_cgroup_state_evict_valuable(limit_pool, bo->resource->css,
+							  evict_walk->try_low,
+							  &evict_walk->hit_low);
+	if (ancestor)
+		dmem_cgroup_pool_state_put(ancestor);
+
+	if (!evict_valuable)
 		return 0;
 
 	if (bo->pin_count || !bo->bdev->funcs->eviction_valuable(bo, evict_walk->place))
@@ -538,8 +688,9 @@ static s64 ttm_bo_evict_cb(struct ttm_lru_walk *walk, struct ttm_buffer_object *
 
 	evict_walk->evicted++;
 	if (evict_walk->res)
-		lret = ttm_resource_alloc(evict_walk->evictor, evict_walk->place,
-					  evict_walk->res, NULL);
+		lret = ttm_bo_alloc_at_place(evict_walk->evictor, evict_walk->place,
+					     walk->arg.ctx, false, evict_walk->res,
+					     evict_walk->alloc_state);
 	if (lret == 0)
 		return 1;
 out:
@@ -561,7 +712,7 @@ static int ttm_bo_evict_alloc(struct ttm_device *bdev,
 			      struct ttm_operation_ctx *ctx,
 			      struct ww_acquire_ctx *ticket,
 			      struct ttm_resource **res,
-			      struct dmem_cgroup_pool_state *limit_pool)
+			      struct ttm_bo_alloc_state *state)
 {
 	struct ttm_bo_evict_walk evict_walk = {
 		.walk = {
@@ -574,15 +725,21 @@ static int ttm_bo_evict_alloc(struct ttm_device *bdev,
 		.place = place,
 		.evictor = evictor,
 		.res = res,
-		.limit_pool = limit_pool,
+		.alloc_state = state,
 	};
 	s64 lret;
 
+	state->in_evict = true;
+
 	evict_walk.walk.arg.trylock_only = true;
 	lret = ttm_lru_walk_for_evict(&evict_walk.walk, bdev, man, 1);
 
-	/* One more attempt if we hit low limit? */
-	if (!lret && evict_walk.hit_low) {
+	/* If we failed to find enough BOs to evict, but we skipped over
+	 * some BOs because they were covered by dmem low protection, retry
+	 * evicting these protected BOs too, except if we're told not to
+	 * consider protected BOs at all.
+	 */
+	if (!lret && evict_walk.hit_low && state->may_try_low) {
 		evict_walk.try_low = true;
 		lret = ttm_lru_walk_for_evict(&evict_walk.walk, bdev, man, 1);
 	}
@@ -603,11 +760,13 @@ static int ttm_bo_evict_alloc(struct ttm_device *bdev,
 	} while (!lret && evict_walk.evicted);
 
 	/* We hit the low limit? Try once more */
-	if (!lret && evict_walk.hit_low && !evict_walk.try_low) {
+	if (!lret && evict_walk.hit_low && !evict_walk.try_low &&
+			state->may_try_low) {
 		evict_walk.try_low = true;
 		goto retry;
 	}
 out:
+	state->in_evict = false;
 	if (lret < 0)
 		return lret;
 	if (lret == 0)
@@ -725,9 +884,8 @@ static int ttm_bo_alloc_resource(struct ttm_buffer_object *bo,
 
 	for (i = 0; i < placement->num_placement; ++i) {
 		const struct ttm_place *place = &placement->placement[i];
-		struct dmem_cgroup_pool_state *limit_pool = NULL;
+		struct ttm_bo_alloc_state alloc_state = {};
 		struct ttm_resource_manager *man;
-		bool may_evict;
 
 		man = ttm_manager_type(bdev, place->mem_type);
 		if (!man || !ttm_resource_manager_used(man))
@@ -737,25 +895,30 @@ static int ttm_bo_alloc_resource(struct ttm_buffer_object *bo,
 				    TTM_PL_FLAG_FALLBACK))
 			continue;
 
-		may_evict = (force_space && place->mem_type != TTM_PL_SYSTEM);
-		ret = ttm_resource_alloc(bo, place, res, force_space ? &limit_pool : NULL);
-		if (ret) {
-			if (ret != -ENOSPC && ret != -EAGAIN) {
-				dmem_cgroup_pool_state_put(limit_pool);
-				return ret;
-			}
-			if (!may_evict) {
-				dmem_cgroup_pool_state_put(limit_pool);
-				continue;
-			}
+		ret = ttm_bo_alloc_at_place(bo, place, ctx, force_space,
+				res, &alloc_state);
 
+		if (ret == -ENOSPC) {
+			dmem_cgroup_uncharge(alloc_state.charge_pool, bo->base.size);
+			dmem_cgroup_pool_state_put(alloc_state.limit_pool);
+			continue;
+		} else if (ret == -EBUSY) {
 			ret = ttm_bo_evict_alloc(bdev, man, place, bo, ctx,
-						 ticket, res, limit_pool);
-			dmem_cgroup_pool_state_put(limit_pool);
-			if (ret == -EBUSY)
-				continue;
-			if (ret)
+						 ticket, res, &alloc_state);
+
+			dmem_cgroup_pool_state_put(alloc_state.limit_pool);
+
+			if (ret) {
+				dmem_cgroup_uncharge(alloc_state.charge_pool,
+						bo->base.size);
+				if (ret == -EBUSY)
+					continue;
 				return ret;
+			}
+		} else if (ret) {
+			dmem_cgroup_uncharge(alloc_state.charge_pool, bo->base.size);
+			dmem_cgroup_pool_state_put(alloc_state.limit_pool);
+			return ret;
 		}
 
 		ret = ttm_bo_add_pipelined_eviction_fences(bo, man, ctx->no_wait_gpu);
diff --git a/drivers/gpu/drm/ttm/ttm_resource.c b/drivers/gpu/drm/ttm/ttm_resource.c
index 192fca24f37e..a8a836f6e376 100644
--- a/drivers/gpu/drm/ttm/ttm_resource.c
+++ b/drivers/gpu/drm/ttm/ttm_resource.c
@@ -373,30 +373,52 @@ void ttm_resource_fini(struct ttm_resource_manager *man,
 }
 EXPORT_SYMBOL(ttm_resource_fini);
 
+/**
+ * ttm_resource_try_charge - charge a resource manager's cgroup pool
+ * @bo: buffer for which an allocation should be charged
+ * @place: where the allocation is attempted to be placed
+ * @ret_pool: on charge success, the pool that was charged
+ * @ret_limit_pool: on charge failure, the pool responsible for the failure
+ *
+ * Should be used to charge cgroups before attempting resource allocation.
+ * When charging succeeds, the value of ret_pool should be passed to
+ * ttm_resource_alloc.
+ *
+ * Returns: 0 on charge success, negative errno on failure.
+ */
+int ttm_resource_try_charge(struct ttm_buffer_object *bo,
+			    const struct ttm_place *place,
+			    struct dmem_cgroup_pool_state **ret_pool,
+			    struct dmem_cgroup_pool_state **ret_limit_pool)
+{
+	struct ttm_resource_manager *man =
+		ttm_manager_type(bo->bdev, place->mem_type);
+
+	if (!man->cg) {
+		*ret_pool = NULL;
+		if (ret_limit_pool)
+			*ret_limit_pool = NULL;
+		return 0;
+	}
+
+	return dmem_cgroup_try_charge(man->cg, bo->base.size, ret_pool,
+				      ret_limit_pool);
+}
+
 int ttm_resource_alloc(struct ttm_buffer_object *bo,
 		       const struct ttm_place *place,
 		       struct ttm_resource **res_ptr,
-		       struct dmem_cgroup_pool_state **ret_limit_pool)
+		       struct dmem_cgroup_pool_state *charge_pool)
 {
 	struct ttm_resource_manager *man =
 		ttm_manager_type(bo->bdev, place->mem_type);
-	struct dmem_cgroup_pool_state *pool = NULL;
 	int ret;
 
-	if (man->cg) {
-		ret = dmem_cgroup_try_charge(man->cg, bo->base.size, &pool, ret_limit_pool);
-		if (ret)
-			return ret;
-	}
-
 	ret = man->func->alloc(man, bo, place, res_ptr);
-	if (ret) {
-		if (pool)
-			dmem_cgroup_uncharge(pool, bo->base.size);
+	if (ret)
 		return ret;
-	}
 
-	(*res_ptr)->css = pool;
+	(*res_ptr)->css = charge_pool;
 
 	spin_lock(&bo->bdev->lru_lock);
 	ttm_resource_add_bulk_move(*res_ptr, bo);
diff --git a/include/drm/ttm/ttm_resource.h b/include/drm/ttm/ttm_resource.h
index 33e80f30b8b8..549b5b796884 100644
--- a/include/drm/ttm/ttm_resource.h
+++ b/include/drm/ttm/ttm_resource.h
@@ -456,10 +456,14 @@ void ttm_resource_init(struct ttm_buffer_object *bo,
 void ttm_resource_fini(struct ttm_resource_manager *man,
 		       struct ttm_resource *res);
 
+int ttm_resource_try_charge(struct ttm_buffer_object *bo,
+			    const struct ttm_place *place,
+			    struct dmem_cgroup_pool_state **ret_pool,
+			    struct dmem_cgroup_pool_state **ret_limit_pool);
 int ttm_resource_alloc(struct ttm_buffer_object *bo,
 		       const struct ttm_place *place,
 		       struct ttm_resource **res,
-		       struct dmem_cgroup_pool_state **ret_limit_pool);
+		       struct dmem_cgroup_pool_state *charge_pool);
 void ttm_resource_free(struct ttm_buffer_object *bo, struct ttm_resource **res);
 bool ttm_resource_intersects(struct ttm_device *bdev,
 			     struct ttm_resource *res,
diff --git a/include/linux/cgroup.h b/include/linux/cgroup.h
index bc892e3b37ee..560ae995e3a5 100644
--- a/include/linux/cgroup.h
+++ b/include/linux/cgroup.h
@@ -561,6 +561,27 @@ static inline struct cgroup *cgroup_ancestor(struct cgroup *cgrp,
 	return cgrp->ancestors[ancestor_level];
 }
 
+/**
+ * cgroup_common_ancestor - find common ancestor of two cgroups
+ * @a: first cgroup to find common ancestor of
+ * @b: second cgroup to find common ancestor of
+ *
+ * Find the lowest (deepest) cgroup that is an ancestor of both @a and @b, if
+ * one exists, and return a pointer to it. If no such cgroup exists, return NULL.
+ *
+ * This function is safe to call as long as both @a and @b are accessible.
+ */
+static inline struct cgroup *cgroup_common_ancestor(struct cgroup *a,
+						    struct cgroup *b)
+{
+	int level;
+
+	for (level = min(a->level, b->level); level >= 0; level--)
+		if (a->ancestors[level] == b->ancestors[level])
+			return a->ancestors[level];
+	return NULL;
+}
+
 /**
  * task_under_cgroup_hierarchy - test task's membership of cgroup ancestry
  * @task: the task to be tested
diff --git a/include/linux/cgroup_dmem.h b/include/linux/cgroup_dmem.h
index dd4869f1d736..9d72457c4cb9 100644
--- a/include/linux/cgroup_dmem.h
+++ b/include/linux/cgroup_dmem.h
@@ -24,6 +24,12 @@ void dmem_cgroup_uncharge(struct dmem_cgroup_pool_state *pool, u64 size);
 bool dmem_cgroup_state_evict_valuable(struct dmem_cgroup_pool_state *limit_pool,
 				      struct dmem_cgroup_pool_state *test_pool,
 				      bool ignore_low, bool *ret_hit_low);
+bool dmem_cgroup_below_min(struct dmem_cgroup_pool_state *root,
+			   struct dmem_cgroup_pool_state *test);
+bool dmem_cgroup_below_low(struct dmem_cgroup_pool_state *root,
+			   struct dmem_cgroup_pool_state *test);
+struct dmem_cgroup_pool_state *dmem_cgroup_get_common_ancestor(struct dmem_cgroup_pool_state *a,
+							       struct dmem_cgroup_pool_state *b);
 
 void dmem_cgroup_pool_state_put(struct dmem_cgroup_pool_state *pool);
 #else
@@ -59,6 +65,25 @@ bool dmem_cgroup_state_evict_valuable(struct dmem_cgroup_pool_state *limit_pool,
 	return true;
 }
 
+static inline bool dmem_cgroup_below_min(struct dmem_cgroup_pool_state *root,
+					 struct dmem_cgroup_pool_state *test)
+{
+	return false;
+}
+
+static inline bool dmem_cgroup_below_low(struct dmem_cgroup_pool_state *root,
+					 struct dmem_cgroup_pool_state *test)
+{
+	return false;
+}
+
+static inline
+struct dmem_cgroup_pool_state *dmem_cgroup_get_common_ancestor(struct dmem_cgroup_pool_state *a,
+							       struct dmem_cgroup_pool_state *b)
+{
+	return NULL;
+}
+
 static inline void dmem_cgroup_pool_state_put(struct dmem_cgroup_pool_state *pool)
 { }
 
diff --git a/kernel/cgroup/dmem.c b/kernel/cgroup/dmem.c
index 9d95824dc6fa..29197d3801ac 100644
--- a/kernel/cgroup/dmem.c
+++ b/kernel/cgroup/dmem.c
@@ -694,6 +694,97 @@ int dmem_cgroup_try_charge(struct dmem_cgroup_region *region, u64 size,
 }
 EXPORT_SYMBOL_GPL(dmem_cgroup_try_charge);
 
+/**
+ * dmem_cgroup_below_min() - Tests whether current usage is within min limit.
+ *
+ * @root: Root of the subtree to calculate protection for, or NULL to calculate global protection.
+ * @test: The pool to test the usage/min limit of.
+ *
+ * Return: true if usage is below min and the cgroup is protected, false otherwise.
+ */
+bool dmem_cgroup_below_min(struct dmem_cgroup_pool_state *root,
+			   struct dmem_cgroup_pool_state *test)
+{
+	if (root == test || !pool_parent(test))
+		return false;
+
+	if (!root) {
+		for (root = test; pool_parent(root); root = pool_parent(root))
+			{}
+	}
+
+	/*
+	 * In mem_cgroup_below_min(), the memcg counterpart, this call is missing.
+	 * mem_cgroup_below_min() gets called during traversal of the cgroup tree, where
+	 * protection is already calculated as part of the traversal. dmem cgroup eviction
+	 * does not traverse the cgroup tree, so we need to recalculate effective protection
+	 * here.
+	 */
+	dmem_cgroup_calculate_protection(root, test);
+	return page_counter_read(&test->cnt) <= READ_ONCE(test->cnt.emin);
+}
+EXPORT_SYMBOL_GPL(dmem_cgroup_below_min);
+
+/**
+ * dmem_cgroup_below_low() - Tests whether current usage is within low limit.
+ *
+ * @root: Root of the subtree to calculate protection for, or NULL to calculate global protection.
+ * @test: The pool to test the usage/low limit of.
+ *
+ * Return: true if usage is below low and the cgroup is protected, false otherwise.
+ */
+bool dmem_cgroup_below_low(struct dmem_cgroup_pool_state *root,
+			   struct dmem_cgroup_pool_state *test)
+{
+	if (root == test || !pool_parent(test))
+		return false;
+
+	if (!root) {
+		for (root = test; pool_parent(root); root = pool_parent(root))
+			{}
+	}
+
+	/*
+	 * In mem_cgroup_below_low(), the memcg counterpart, this call is missing.
+	 * mem_cgroup_below_low() gets called during traversal of the cgroup tree, where
+	 * protection is already calculated as part of the traversal. dmem cgroup eviction
+	 * does not traverse the cgroup tree, so we need to recalculate effective protection
+	 * here.
+	 */
+	dmem_cgroup_calculate_protection(root, test);
+	return page_counter_read(&test->cnt) <= READ_ONCE(test->cnt.elow);
+}
+EXPORT_SYMBOL_GPL(dmem_cgroup_below_low);
+
+/**
+ * dmem_cgroup_get_common_ancestor(): Find the first common ancestor of two pools.
+ * @a: First pool to find the common ancestor of.
+ * @b: Second pool to find the common ancestor of.
+ *
+ * Return: The first pool that is a parent of both @a and @b, or NULL if either @a or @b are NULL,
+ * or if such a pool does not exist. A reference to the returned pool is grabbed and must be
+ * released by the caller when it is done using the pool.
+ */
+struct dmem_cgroup_pool_state *dmem_cgroup_get_common_ancestor(struct dmem_cgroup_pool_state *a,
+							       struct dmem_cgroup_pool_state *b)
+{
+	struct cgroup *ancestor_cgroup;
+	struct cgroup_subsys_state *ancestor_css;
+
+	if (!a || !b)
+		return NULL;
+
+	ancestor_cgroup = cgroup_common_ancestor(a->cs->css.cgroup, b->cs->css.cgroup);
+	if (!ancestor_cgroup)
+		return NULL;
+
+	ancestor_css = cgroup_e_css(ancestor_cgroup, &dmem_cgrp_subsys);
+	css_get(ancestor_css);
+
+	return get_cg_pool_unlocked(css_to_dmemcs(ancestor_css), a->region);
+}
+EXPORT_SYMBOL_GPL(dmem_cgroup_get_common_ancestor);
+
 static int dmem_cgroup_region_capacity_show(struct seq_file *sf, void *v)
 {
 	struct dmem_cgroup_region *region;
-- 
2.53.0

