diff --git a/allianceauth/authentication/core/__init__.py b/allianceauth/authentication/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/allianceauth/authentication/core/celery_workers.py b/allianceauth/authentication/core/celery_workers.py new file mode 100644 index 00000000..6da60924 --- /dev/null +++ b/allianceauth/authentication/core/celery_workers.py @@ -0,0 +1,48 @@ +"""API for interacting with celery workers.""" + +import itertools +import logging +from typing import Optional + +from amqp.exceptions import ChannelError +from celery import current_app + +from django.conf import settings + +logger = logging.getLogger(__name__) + + +def active_tasks_count() -> Optional[int]: + """Return count of currently active tasks + or None if celery workers are not online. + """ + inspect = current_app.control.inspect() + return _tasks_count(inspect.active()) + + +def _tasks_count(data: dict) -> Optional[int]: + """Return count of tasks in data from celery inspect API.""" + try: + tasks = itertools.chain(*data.values()) + except AttributeError: + return None + return len(list(tasks)) + + +def queued_tasks_count() -> Optional[int]: + """Return count of queued tasks. Return None if there was an error.""" + try: + with current_app.connection_or_acquire() as conn: + result = conn.default_channel.queue_declare( + queue=getattr(settings, "CELERY_DEFAULT_QUEUE", "celery"), passive=True + ) + return result.message_count + + except ChannelError: + # Queue doesn't exist, probably empty + return 0 + + except Exception: + logger.exception("Failed to get celery queue length") + + return None diff --git a/allianceauth/authentication/task_statistics/counters.py b/allianceauth/authentication/task_statistics/counters.py index 06e2af83..bdb6d034 100644 --- a/allianceauth/authentication/task_statistics/counters.py +++ b/allianceauth/authentication/task_statistics/counters.py @@ -4,13 +4,11 @@ import datetime as dt from typing import NamedTuple, Optional from .event_series import EventSeries -from .helpers import ItemCounter # Global series for counting task events. succeeded_tasks = EventSeries("SUCCEEDED_TASKS") retried_tasks = EventSeries("RETRIED_TASKS") failed_tasks = EventSeries("FAILED_TASKS") -running_tasks = ItemCounter("running_tasks") class _TaskCounts(NamedTuple): @@ -20,7 +18,6 @@ class _TaskCounts(NamedTuple): total: int earliest_task: Optional[dt.datetime] hours: int - running: int def dashboard_results(hours: int) -> _TaskCounts: @@ -38,7 +35,6 @@ def dashboard_results(hours: int) -> _TaskCounts: earliest_events += earliest_if_exists(retried_tasks, earliest) failed_count = failed_tasks.count(earliest=earliest) earliest_events += earliest_if_exists(failed_tasks, earliest) - running_count = running_tasks.value() return _TaskCounts( succeeded=succeeded_count, retried=retried_count, @@ -46,5 +42,4 @@ def dashboard_results(hours: int) -> _TaskCounts: total=succeeded_count + retried_count + failed_count, earliest_task=min(earliest_events) if earliest_events else None, hours=hours, - running=running_count, ) diff --git a/allianceauth/authentication/task_statistics/helpers.py b/allianceauth/authentication/task_statistics/helpers.py index b75fb39c..464cee8f 100644 --- a/allianceauth/authentication/task_statistics/helpers.py +++ b/allianceauth/authentication/task_statistics/helpers.py @@ -1,12 +1,9 @@ """Helpers for Task Statistics.""" import logging -from typing import Optional from redis import Redis, RedisError -from django.core.cache import cache - from allianceauth.utils.cache import get_redis_client logger = logging.getLogger(__name__) @@ -37,62 +34,6 @@ class _RedisStub: pass -class ItemCounter: - """A process safe item counter. - - Args: - - name: Unique name for the counter - - minimum: Counter can not go below the minimum, when set - - redis: A Redis client. Will use AA's cache client by default - """ - - CACHE_KEY_BASE = "allianceauth-item-counter" - DEFAULT_CACHE_TIMEOUT = 24 * 3600 - - def __init__( - self, name: str, minimum: Optional[int] = None, redis: Optional[Redis] = None - ) -> None: - if not name: - raise ValueError("Must define a name") - - self._name = str(name) - self._minimum = minimum - self._redis = get_redis_client_or_stub() if not redis else redis - - @property - def _cache_key(self) -> str: - return f"{self.CACHE_KEY_BASE}-{self._name}" - - def reset(self, init_value: int = 0): - """Reset counter to initial value.""" - with self._redis.lock(f"{self.CACHE_KEY_BASE}-reset"): - if self._minimum is not None and init_value < self._minimum: - raise ValueError("Can not reset below minimum") - - cache.set(self._cache_key, init_value, self.DEFAULT_CACHE_TIMEOUT) - - def incr(self, delta: int = 1): - """Increment counter by delta.""" - try: - cache.incr(self._cache_key, delta) - except ValueError: - pass - - def decr(self, delta: int = 1): - """Decrement counter by delta.""" - with self._redis.lock(f"{self.CACHE_KEY_BASE}-decr"): - if self._minimum is not None and self.value() == self._minimum: - return - try: - cache.decr(self._cache_key, delta) - except ValueError: - pass - - def value(self) -> Optional[int]: - """Return current value or None if not yet initialized.""" - return cache.get(self._cache_key) - - def get_redis_client_or_stub() -> Redis: """Return AA's default cache client or a stub if Redis is not available.""" redis = get_redis_client() diff --git a/allianceauth/authentication/task_statistics/signals.py b/allianceauth/authentication/task_statistics/signals.py index 17665d65..e9d7babc 100644 --- a/allianceauth/authentication/task_statistics/signals.py +++ b/allianceauth/authentication/task_statistics/signals.py @@ -1,15 +1,12 @@ """Signals for Task Statistics.""" from celery.signals import ( - task_failure, task_internal_error, task_postrun, task_prerun, task_retry, - task_success, worker_ready, + task_failure, task_internal_error, task_retry, task_success, worker_ready, ) from django.conf import settings -from .counters import ( - failed_tasks, retried_tasks, running_tasks, succeeded_tasks, -) +from .counters import failed_tasks, retried_tasks, succeeded_tasks def reset_counters(): @@ -17,7 +14,6 @@ def reset_counters(): succeeded_tasks.clear() failed_tasks.clear() retried_tasks.clear() - running_tasks.reset() def is_enabled() -> bool: @@ -55,15 +51,3 @@ def record_task_failed(*args, **kwargs): def record_task_internal_error(*args, **kwargs): if is_enabled(): failed_tasks.add() - - -@task_prerun.connect -def record_task_prerun(*args, **kwargs): - if is_enabled(): - running_tasks.incr() - - -@task_postrun.connect -def record_task_postrun(*args, **kwargs): - if is_enabled(): - running_tasks.decr() diff --git a/allianceauth/authentication/task_statistics/tests/test_counters.py b/allianceauth/authentication/task_statistics/tests/test_counters.py index 284f86ca..2d2555aa 100644 --- a/allianceauth/authentication/task_statistics/tests/test_counters.py +++ b/allianceauth/authentication/task_statistics/tests/test_counters.py @@ -4,11 +4,7 @@ from django.test import TestCase from django.utils.timezone import now from allianceauth.authentication.task_statistics.counters import ( - dashboard_results, - succeeded_tasks, - retried_tasks, - failed_tasks, - running_tasks, + dashboard_results, failed_tasks, retried_tasks, succeeded_tasks, ) @@ -32,7 +28,6 @@ class TestDashboardResults(TestCase): failed_tasks.add(now() - dt.timedelta(hours=1, seconds=1)) failed_tasks.add() - running_tasks.reset(8) # when results = dashboard_results(hours=1) # then @@ -41,14 +36,12 @@ class TestDashboardResults(TestCase): self.assertEqual(results.failed, 1) self.assertEqual(results.total, 6) self.assertEqual(results.earliest_task, earliest_task) - self.assertEqual(results.running, 8) def test_should_work_with_no_data(self): # given succeeded_tasks.clear() retried_tasks.clear() failed_tasks.clear() - running_tasks.reset() # when results = dashboard_results(hours=1) # then @@ -57,4 +50,3 @@ class TestDashboardResults(TestCase): self.assertEqual(results.failed, 0) self.assertEqual(results.total, 0) self.assertIsNone(results.earliest_task) - self.assertEqual(results.running, 0) diff --git a/allianceauth/authentication/task_statistics/tests/test_helpers.py b/allianceauth/authentication/task_statistics/tests/test_helpers.py index 757ae38e..51dae201 100644 --- a/allianceauth/authentication/task_statistics/tests/test_helpers.py +++ b/allianceauth/authentication/task_statistics/tests/test_helpers.py @@ -4,125 +4,11 @@ from unittest.mock import patch from redis import RedisError from allianceauth.authentication.task_statistics.helpers import ( - ItemCounter, _RedisStub, get_redis_client_or_stub, + _RedisStub, get_redis_client_or_stub, ) MODULE_PATH = "allianceauth.authentication.task_statistics.helpers" -COUNTER_NAME = "test-counter" - - -class TestItemCounter(TestCase): - def test_can_create_counter(self): - # when - counter = ItemCounter(COUNTER_NAME) - # then - self.assertIsInstance(counter, ItemCounter) - - def test_can_reset_counter_to_default(self): - # given - counter = ItemCounter(COUNTER_NAME) - # when - counter.reset() - # then - self.assertEqual(counter.value(), 0) - - def test_can_reset_counter_to_custom_value(self): - # given - counter = ItemCounter(COUNTER_NAME) - # when - counter.reset(42) - # then - self.assertEqual(counter.value(), 42) - - def test_can_increment_counter_by_default(self): - # given - counter = ItemCounter(COUNTER_NAME) - counter.reset(0) - # when - counter.incr() - # then - self.assertEqual(counter.value(), 1) - - def test_can_increment_counter_by_custom_value(self): - # given - counter = ItemCounter(COUNTER_NAME) - counter.reset(0) - # when - counter.incr(8) - # then - self.assertEqual(counter.value(), 8) - - def test_can_decrement_counter_by_default(self): - # given - counter = ItemCounter(COUNTER_NAME) - counter.reset(9) - # when - counter.decr() - # then - self.assertEqual(counter.value(), 8) - - def test_can_decrement_counter_by_custom_value(self): - # given - counter = ItemCounter(COUNTER_NAME) - counter.reset(9) - # when - counter.decr(8) - # then - self.assertEqual(counter.value(), 1) - - def test_can_decrement_counter_below_zero(self): - # given - counter = ItemCounter(COUNTER_NAME) - counter.reset(0) - # when - counter.decr(1) - # then - self.assertEqual(counter.value(), -1) - - def test_can_not_decrement_counter_below_minimum(self): - # given - counter = ItemCounter(COUNTER_NAME, minimum=0) - counter.reset(0) - # when - counter.decr(1) - # then - self.assertEqual(counter.value(), 0) - - def test_can_not_reset_counter_below_minimum(self): - # given - counter = ItemCounter(COUNTER_NAME, minimum=0) - # when/then - with self.assertRaises(ValueError): - counter.reset(-1) - - def test_can_not_init_without_name(self): - # when/then - with self.assertRaises(ValueError): - ItemCounter(name="") - - def test_can_ignore_invalid_values_when_incrementing(self): - # given - counter = ItemCounter(COUNTER_NAME) - counter.reset(0) - # when - with patch(MODULE_PATH + ".cache.incr") as m: - m.side_effect = ValueError - counter.incr() - # then - self.assertEqual(counter.value(), 0) - - def test_can_ignore_invalid_values_when_decrementing(self): - # given - counter = ItemCounter(COUNTER_NAME) - counter.reset(1) - # when - with patch(MODULE_PATH + ".cache.decr") as m: - m.side_effect = ValueError - counter.decr() - # then - self.assertEqual(counter.value(), 1) - class TestGetRedisClient(TestCase): def test_should_return_mock_if_redis_not_available_1(self): diff --git a/allianceauth/authentication/tests/core/__init__.py b/allianceauth/authentication/tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/allianceauth/authentication/tests/core/test_celery_workers.py b/allianceauth/authentication/tests/core/test_celery_workers.py new file mode 100644 index 00000000..95bec08d --- /dev/null +++ b/allianceauth/authentication/tests/core/test_celery_workers.py @@ -0,0 +1,85 @@ +from unittest.mock import patch + +from amqp.exceptions import ChannelError + +from django.test import TestCase + +from allianceauth.authentication.core.celery_workers import ( + active_tasks_count, queued_tasks_count, +) + +MODULE_PATH = "allianceauth.authentication.core.celery_workers" + + +@patch(MODULE_PATH + ".current_app") +class TestActiveTasksCount(TestCase): + def test_should_return_correct_count_when_no_active_tasks(self, mock_current_app): + # given + mock_current_app.control.inspect.return_value.active.return_value = { + "queue": [] + } + # when + result = active_tasks_count() + # then + self.assertEqual(result, 0) + + def test_should_return_correct_task_count_for_active_tasks(self, mock_current_app): + # given + mock_current_app.control.inspect.return_value.active.return_value = { + "queue": [1, 2, 3] + } + # when + result = active_tasks_count() + # then + self.assertEqual(result, 3) + + def test_should_return_correct_task_count_for_multiple_queues( + self, mock_current_app + ): + # given + mock_current_app.control.inspect.return_value.active.return_value = { + "queue_1": [1, 2], + "queue_2": [3, 4], + } + # when + result = active_tasks_count() + # then + self.assertEqual(result, 4) + + def test_should_return_none_when_celery_not_available(self, mock_current_app): + # given + mock_current_app.control.inspect.return_value.active.return_value = None + # when + result = active_tasks_count() + # then + self.assertIsNone(result) + + +@patch(MODULE_PATH + ".current_app") +class TestQueuedTasksCount(TestCase): + def test_should_return_queue_length_when_queue_exists(self, mock_current_app): + # given + mock_conn = ( + mock_current_app.connection_or_acquire.return_value.__enter__.return_value + ) + mock_conn.default_channel.queue_declare.return_value.message_count = 7 + # when + result = queued_tasks_count() + # then + self.assertEqual(result, 7) + + def test_should_return_0_when_queue_does_not_exists(self, mock_current_app): + # given + mock_current_app.connection_or_acquire.side_effect = ChannelError + # when + result = queued_tasks_count() + # then + self.assertEqual(result, 0) + + def test_should_return_None_on_other_errors(self, mock_current_app): + # given + mock_current_app.connection_or_acquire.side_effect = RuntimeError + # when + result = queued_tasks_count() + # then + self.assertIsNone(result) diff --git a/allianceauth/authentication/tests/test_templatetags.py b/allianceauth/authentication/tests/test_templatetags.py index 927c5504..450da33e 100644 --- a/allianceauth/authentication/tests/test_templatetags.py +++ b/allianceauth/authentication/tests/test_templatetags.py @@ -9,12 +9,8 @@ from django.core.cache import cache from django.test import TestCase from allianceauth.templatetags.admin_status import ( - status_overview, - _fetch_list_from_gitlab, - _current_notifications, - _current_version_summary, - _fetch_notification_issues_from_gitlab, - _latests_versions + _current_notifications, _current_version_summary, _fetch_list_from_gitlab, + _fetch_notification_issues_from_gitlab, _latests_versions, status_overview, ) MODULE_PATH = 'allianceauth.templatetags' @@ -56,14 +52,10 @@ TEST_VERSION = '2.6.5' class TestStatusOverviewTag(TestCase): @patch(MODULE_PATH + '.admin_status.__version__', TEST_VERSION) - @patch(MODULE_PATH + '.admin_status._fetch_celery_queue_length') @patch(MODULE_PATH + '.admin_status._current_version_summary') @patch(MODULE_PATH + '.admin_status._current_notifications') def test_status_overview( - self, - mock_current_notifications, - mock_current_version_info, - mock_fetch_celery_queue_length + self, mock_current_notifications, mock_current_version_info ): # given notifications = { @@ -82,7 +74,6 @@ class TestStatusOverviewTag(TestCase): 'latest_beta_version': '2.4.4a1', } mock_current_version_info.return_value = version_info - mock_fetch_celery_queue_length.return_value = 3 # when result = status_overview() # then @@ -96,7 +87,6 @@ class TestStatusOverviewTag(TestCase): self.assertEqual(result["latest_minor_version"], '2.4.0') self.assertEqual(result["latest_patch_version"], '2.4.5') self.assertEqual(result["latest_beta_version"], '2.4.4a1') - self.assertEqual(result["task_queue_length"], 3) class TestNotifications(TestCase): diff --git a/allianceauth/authentication/tests/test_views.py b/allianceauth/authentication/tests/test_views.py new file mode 100644 index 00000000..09aa4807 --- /dev/null +++ b/allianceauth/authentication/tests/test_views.py @@ -0,0 +1,39 @@ +import json +from unittest.mock import patch + +from django.test import RequestFactory, TestCase + +from allianceauth.authentication.views import task_counts +from allianceauth.tests.auth_utils import AuthUtils + +MODULE_PATH = "allianceauth.authentication.views" + + +def jsonresponse_to_dict(response) -> dict: + return json.loads(response.content) + + +@patch(MODULE_PATH + ".queued_tasks_count") +@patch(MODULE_PATH + ".active_tasks_count") +class TestRunningTasksCount(TestCase): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.factory = RequestFactory() + cls.user = AuthUtils.create_user("bruce_wayne") + + def test_should_return_data( + self, mock_active_tasks_count, mock_queued_tasks_count + ): + # given + mock_active_tasks_count.return_value = 2 + mock_queued_tasks_count.return_value = 3 + request = self.factory.get("/") + request.user = self.user + # when + response = task_counts(request) + # then + self.assertEqual(response.status_code, 200) + self.assertDictEqual( + jsonresponse_to_dict(response), {"tasks_running": 2, "tasks_queued": 3} + ) diff --git a/allianceauth/authentication/urls.py b/allianceauth/authentication/urls.py index 6b1e9a5e..a7dc66e3 100644 --- a/allianceauth/authentication/urls.py +++ b/allianceauth/authentication/urls.py @@ -38,4 +38,5 @@ urlpatterns = [ name='token_refresh' ), path('dashboard/', views.dashboard, name='dashboard'), + path('task-counts/', views.task_counts, name='task_counts'), ] diff --git a/allianceauth/authentication/views.py b/allianceauth/authentication/views.py index 74e05c28..15c746c7 100644 --- a/allianceauth/authentication/views.py +++ b/allianceauth/authentication/views.py @@ -1,31 +1,31 @@ import logging +from django_registration.backends.activation.views import ( + REGISTRATION_SALT, ActivationView as BaseActivationView, + RegistrationView as BaseRegistrationView, +) +from django_registration.signals import user_registered + from django.conf import settings from django.contrib import messages -from django.contrib.auth import login, authenticate +from django.contrib.auth import authenticate, login from django.contrib.auth.decorators import login_required from django.contrib.auth.models import User from django.core import signing -from django.core.mail import EmailMultiAlternatives from django.http import JsonResponse from django.shortcuts import redirect, render from django.template.loader import render_to_string from django.urls import reverse, reverse_lazy from django.utils.translation import gettext_lazy as _ -from allianceauth.eveonline.models import EveCharacter from esi.decorators import token_required from esi.models import Token -from django_registration.backends.activation.views import ( - RegistrationView as BaseRegistrationView, - ActivationView as BaseActivationView, - REGISTRATION_SALT -) -from django_registration.signals import user_registered +from allianceauth.eveonline.models import EveCharacter -from .models import CharacterOwnership +from .core.celery_workers import active_tasks_count, queued_tasks_count from .forms import RegistrationForm +from .models import CharacterOwnership if 'allianceauth.eveonline.autogroups' in settings.INSTALLED_APPS: _has_auto_groups = True @@ -61,6 +61,7 @@ def dashboard(request): } return render(request, 'authentication/dashboard.html', context) + @login_required def token_management(request): tokens = request.user.token_set.all() @@ -70,6 +71,7 @@ def token_management(request): } return render(request, 'authentication/tokens.html', context) + @login_required def token_delete(request, token_id=None): try: @@ -83,6 +85,7 @@ def token_delete(request, token_id=None): messages.warning(request, "Token does not exist") return redirect('authentication:token_management') + @login_required def token_refresh(request, token_id=None): try: @@ -127,7 +130,7 @@ def main_character_change(request, token): def add_character(request, token): if CharacterOwnership.objects.filter(character__character_id=token.character_id).filter( owner_hash=token.character_owner_hash).filter(user=request.user).exists(): - messages.success(request, _('Added %(name)s to your account.'% ({'name': token.character_name}))) + messages.success(request, _('Added %(name)s to your account.' % ({'name': token.character_name}))) else: messages.error(request, _('Failed to add %(name)s to your account: they already have an account.' % ({'name': token.character_name}))) return redirect('authentication:dashboard') @@ -268,8 +271,11 @@ class ActivationView(BaseActivationView): def validate_key(self, activation_key): try: - dump = signing.loads(activation_key, salt=REGISTRATION_SALT, - max_age=settings.ACCOUNT_ACTIVATION_DAYS * 86400) + dump = signing.loads( + activation_key, + salt=REGISTRATION_SALT, + max_age=settings.ACCOUNT_ACTIVATION_DAYS * 86400 + ) return dump except signing.BadSignature: return None @@ -299,3 +305,12 @@ def activation_complete(request): def registration_closed(request): messages.error(request, _('Registration of new accounts is not allowed at this time.')) return redirect('authentication:login') + + +def task_counts(request) -> JsonResponse: + """Return task counts as JSON for an AJAX call.""" + data = { + "tasks_running": active_tasks_count(), + "tasks_queued": queued_tasks_count() + } + return JsonResponse(data) diff --git a/allianceauth/templates/allianceauth/admin-status/overview.html b/allianceauth/templates/allianceauth/admin-status/overview.html index 6f04af23..adfd6e58 100644 --- a/allianceauth/templates/allianceauth/admin-status/overview.html +++ b/allianceauth/templates/allianceauth/admin-status/overview.html @@ -92,12 +92,8 @@ {% include "allianceauth/admin-status/celery_bar_partial.html" with label="failed" level="danger" tasks_count=tasks_failed %}

- {% blocktranslate with running_count=tasks_running|default_if_none:"?"|intcomma %} - {{ running_count }} running | - {% endblocktranslate %} - {% blocktranslate with queue_length=task_queue_length|default_if_none:"?"|intcomma %} - {{ queue_length }} queued - {% endblocktranslate %} + ? {% translate 'running' %} | + ? {% translate 'queued' %}

@@ -105,3 +101,36 @@
+ + diff --git a/allianceauth/templatetags/admin_status.py b/allianceauth/templatetags/admin_status.py index 9a896926..f9f8347b 100644 --- a/allianceauth/templatetags/admin_status.py +++ b/allianceauth/templatetags/admin_status.py @@ -1,9 +1,6 @@ import logging -from typing import Optional -import amqp.exceptions import requests -from celery.app import app_or_default from packaging.version import InvalidVersion, Version as Pep440Version from django import template @@ -11,8 +8,9 @@ from django.conf import settings from django.core.cache import cache from allianceauth import __version__ - -from ..authentication.task_statistics.counters import dashboard_results +from allianceauth.authentication.task_statistics.counters import ( + dashboard_results, +) register = template.Library() @@ -48,18 +46,15 @@ def status_overview() -> dict: response = { "notifications": list(), "current_version": __version__, - "task_queue_length": None, "tasks_succeeded": 0, "tasks_retried": 0, "tasks_failed": 0, "tasks_total": 0, "tasks_hours": 0, "earliest_task": None, - "tasks_running": 0 } response.update(_current_notifications()) response.update(_current_version_summary()) - response.update({'task_queue_length': _fetch_celery_queue_length()}) response.update(_celery_stats()) return response @@ -74,27 +69,9 @@ def _celery_stats() -> dict: "tasks_total": results.total, "tasks_hours": results.hours, "earliest_task": results.earliest_task, - "tasks_running": results.running, } -def _fetch_celery_queue_length() -> Optional[int]: - try: - app = app_or_default(None) - with app.connection_or_acquire() as conn: - result = conn.default_channel.queue_declare( - queue=getattr(settings, 'CELERY_DEFAULT_QUEUE', 'celery'), - passive=True - ) - return result.message_count - except amqp.exceptions.ChannelError: - # Queue doesn't exist, probably empty - return 0 - except Exception: - logger.exception("Failed to get celery queue length") - return None - - def _current_notifications() -> dict: """returns the newest 5 announcement issues""" try: diff --git a/allianceauth/utils/counters.py b/allianceauth/utils/counters.py new file mode 100644 index 00000000..095ea13d --- /dev/null +++ b/allianceauth/utils/counters.py @@ -0,0 +1,65 @@ +"""Counters.""" + +from typing import Optional + +from redis import Redis + +from django.core.cache import cache + +from .cache import get_redis_client + + +class ItemCounter: + """A process safe item counter. + + Args: + - name: Unique name for the counter + - minimum: Counter can not go below the minimum, when set + - redis: A Redis client. Will use AA's cache client by default + """ + + CACHE_KEY_BASE = "allianceauth-item-counter" + DEFAULT_CACHE_TIMEOUT = 24 * 3600 + + def __init__( + self, name: str, minimum: Optional[int] = None, redis: Optional[Redis] = None + ) -> None: + if not name: + raise ValueError("Must define a name") + + self._name = str(name) + self._minimum = minimum + self._redis = get_redis_client() if not redis else redis + + @property + def _cache_key(self) -> str: + return f"{self.CACHE_KEY_BASE}-{self._name}" + + def reset(self, init_value: int = 0): + """Reset counter to initial value.""" + with self._redis.lock(f"{self.CACHE_KEY_BASE}-reset"): + if self._minimum is not None and init_value < self._minimum: + raise ValueError("Can not reset below minimum") + + cache.set(self._cache_key, init_value, self.DEFAULT_CACHE_TIMEOUT) + + def incr(self, delta: int = 1): + """Increment counter by delta.""" + try: + cache.incr(self._cache_key, delta) + except ValueError: + pass + + def decr(self, delta: int = 1): + """Decrement counter by delta.""" + with self._redis.lock(f"{self.CACHE_KEY_BASE}-decr"): + if self._minimum is not None and self.value() == self._minimum: + return + try: + cache.decr(self._cache_key, delta) + except ValueError: + pass + + def value(self) -> Optional[int]: + """Return current value or None if not yet initialized.""" + return cache.get(self._cache_key) diff --git a/allianceauth/utils/tests/test_counters.py b/allianceauth/utils/tests/test_counters.py new file mode 100644 index 00000000..cda3c08b --- /dev/null +++ b/allianceauth/utils/tests/test_counters.py @@ -0,0 +1,120 @@ +from unittest import TestCase +from unittest.mock import patch + +from allianceauth.utils.counters import ItemCounter + +MODULE_PATH = "allianceauth.utils.counters" + +COUNTER_NAME = "test-counter" + + +class TestItemCounter(TestCase): + def test_can_create_counter(self): + # when + counter = ItemCounter(COUNTER_NAME) + # then + self.assertIsInstance(counter, ItemCounter) + + def test_can_reset_counter_to_default(self): + # given + counter = ItemCounter(COUNTER_NAME) + # when + counter.reset() + # then + self.assertEqual(counter.value(), 0) + + def test_can_reset_counter_to_custom_value(self): + # given + counter = ItemCounter(COUNTER_NAME) + # when + counter.reset(42) + # then + self.assertEqual(counter.value(), 42) + + def test_can_increment_counter_by_default(self): + # given + counter = ItemCounter(COUNTER_NAME) + counter.reset(0) + # when + counter.incr() + # then + self.assertEqual(counter.value(), 1) + + def test_can_increment_counter_by_custom_value(self): + # given + counter = ItemCounter(COUNTER_NAME) + counter.reset(0) + # when + counter.incr(8) + # then + self.assertEqual(counter.value(), 8) + + def test_can_decrement_counter_by_default(self): + # given + counter = ItemCounter(COUNTER_NAME) + counter.reset(9) + # when + counter.decr() + # then + self.assertEqual(counter.value(), 8) + + def test_can_decrement_counter_by_custom_value(self): + # given + counter = ItemCounter(COUNTER_NAME) + counter.reset(9) + # when + counter.decr(8) + # then + self.assertEqual(counter.value(), 1) + + def test_can_decrement_counter_below_zero(self): + # given + counter = ItemCounter(COUNTER_NAME) + counter.reset(0) + # when + counter.decr(1) + # then + self.assertEqual(counter.value(), -1) + + def test_can_not_decrement_counter_below_minimum(self): + # given + counter = ItemCounter(COUNTER_NAME, minimum=0) + counter.reset(0) + # when + counter.decr(1) + # then + self.assertEqual(counter.value(), 0) + + def test_can_not_reset_counter_below_minimum(self): + # given + counter = ItemCounter(COUNTER_NAME, minimum=0) + # when/then + with self.assertRaises(ValueError): + counter.reset(-1) + + def test_can_not_init_without_name(self): + # when/then + with self.assertRaises(ValueError): + ItemCounter(name="") + + def test_can_ignore_invalid_values_when_incrementing(self): + # given + counter = ItemCounter(COUNTER_NAME) + counter.reset(0) + # when + with patch(MODULE_PATH + ".cache.incr") as m: + m.side_effect = ValueError + counter.incr() + # then + self.assertEqual(counter.value(), 0) + + def test_can_ignore_invalid_values_when_decrementing(self): + # given + counter = ItemCounter(COUNTER_NAME) + counter.reset(1) + # when + with patch(MODULE_PATH + ".cache.decr") as m: + m.side_effect = ValueError + counter.decr() + # then + self.assertEqual(counter.value(), 1)