diff --git a/allianceauth/services/signals.py b/allianceauth/services/signals.py index 34124a86..773f8ab6 100644 --- a/allianceauth/services/signals.py +++ b/allianceauth/services/signals.py @@ -1,4 +1,5 @@ import logging +from functools import partial from django.contrib.auth.models import User, Group, Permission from django.core.exceptions import ObjectDoesNotExist @@ -8,7 +9,7 @@ from django.db.models.signals import pre_delete from django.db.models.signals import pre_save from django.dispatch import receiver from .hooks import ServicesHook -from .tasks import disable_user +from .tasks import disable_user, update_groups_for_user from allianceauth.authentication.models import State, UserProfile from allianceauth.authentication.signals import state_changed @@ -19,21 +20,27 @@ logger = logging.getLogger(__name__) @receiver(m2m_changed, sender=User.groups.through) def m2m_changed_user_groups(sender, instance, action, *args, **kwargs): - logger.debug(f"Received m2m_changed from {instance} groups with action {action}") - - def trigger_service_group_update(): - logger.debug("Triggering service group update for %s" % instance) - # Iterate through Service hooks - for svc in ServicesHook.get_services(): - try: - svc.validate_user(instance) - svc.update_groups(instance) - except: - logger.exception(f'Exception running update_groups for services module {svc} on user {instance}') - - if instance.pk and (action == "post_add" or action == "post_remove" or action == "post_clear"): - logger.debug("Waiting for commit to trigger service group update for %s" % instance) - transaction.on_commit(trigger_service_group_update) + logger.debug( + "%s: Received m2m_changed from groups with action %s", instance, action + ) + if instance.pk and ( + action == "post_add" or action == "post_remove" or action == "post_clear" + ): + if isinstance(instance, User): + logger.debug( + "Waiting for commit to trigger service group update for %s", instance + ) + transaction.on_commit(partial(update_groups_for_user.delay, instance.pk)) + elif ( + isinstance(instance, Group) + and kwargs.get("model") is User + and "pk_set" in kwargs + ): + for user_pk in kwargs["pk_set"]: + logger.debug( + "%s: Waiting for commit to trigger service group update for user", user_pk + ) + transaction.on_commit(partial(update_groups_for_user.delay, user_pk)) @receiver(m2m_changed, sender=User.user_permissions.through) diff --git a/allianceauth/services/tasks.py b/allianceauth/services/tasks.py index cac663c6..58947577 100644 --- a/allianceauth/services/tasks.py +++ b/allianceauth/services/tasks.py @@ -47,3 +47,20 @@ def disable_user(user): for svc in ServicesHook.get_services(): if svc.service_active_for_user(user): svc.delete_user(user) + + +@shared_task +def update_groups_for_user(user_pk: int) -> None: + """Update groups for all services registered to a user.""" + user = User.objects.get(pk=user_pk) + logger.debug("%s: Triggering service group update for user", user) + for svc in ServicesHook.get_services(): + try: + svc.validate_user(user) + svc.update_groups(user) + except Exception: + logger.exception( + 'Exception running update_groups for services module %s on user %s', + svc, + user + ) diff --git a/allianceauth/services/tests/test_signals.py b/allianceauth/services/tests/test_signals.py index 7bd0c59e..e8cfdc57 100644 --- a/allianceauth/services/tests/test_signals.py +++ b/allianceauth/services/tests/test_signals.py @@ -1,7 +1,7 @@ from copy import deepcopy from unittest import mock -from django.test import TestCase +from django.test import override_settings, TestCase, TransactionTestCase from django.contrib.auth.models import Group, Permission from allianceauth.authentication.models import State @@ -9,6 +9,9 @@ from allianceauth.eveonline.models import EveCharacter from allianceauth.tests.auth_utils import AuthUtils +MODULE_PATH = 'allianceauth.services.signals' + + class ServicesSignalsTestCase(TestCase): def setUp(self): self.member = AuthUtils.create_user('auth_member', disconnect_signals=True) @@ -17,17 +20,12 @@ class ServicesSignalsTestCase(TestCase): ) self.none_user = AuthUtils.create_user('none_user', disconnect_signals=True) - @mock.patch('allianceauth.services.signals.transaction') - @mock.patch('allianceauth.services.signals.ServicesHook') - def test_m2m_changed_user_groups(self, services_hook, transaction): + @mock.patch(MODULE_PATH + '.transaction', spec=True) + @mock.patch(MODULE_PATH + '.update_groups_for_user', spec=True) + def test_m2m_changed_user_groups(self, update_groups_for_user, transaction): """ Test that update_groups hook function is called on user groups change """ - svc = mock.Mock() - svc.update_groups.return_value = None - svc.validate_user.return_value = None - - services_hook.get_services.return_value = [svc] # Overload transaction.on_commit so everything happens synchronously transaction.on_commit = lambda fn: fn() @@ -39,17 +37,11 @@ class ServicesSignalsTestCase(TestCase): self.member.save() # Assert - self.assertTrue(services_hook.get_services.called) + self.assertTrue(update_groups_for_user.delay.called) + args, _ = update_groups_for_user.delay.call_args + self.assertEqual(self.member.pk, args[0]) - self.assertTrue(svc.update_groups.called) - args, kwargs = svc.update_groups.call_args - self.assertEqual(self.member, args[0]) - - self.assertTrue(svc.validate_user.called) - args, kwargs = svc.validate_user.call_args - self.assertEqual(self.member, args[0]) - - @mock.patch('allianceauth.services.signals.disable_user') + @mock.patch(MODULE_PATH + '.disable_user') def test_pre_delete_user(self, disable_user): """ Test that disable_member is called when a user is deleted @@ -60,7 +52,7 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = disable_user.call_args self.assertEqual(self.none_user, args[0]) - @mock.patch('allianceauth.services.signals.disable_user') + @mock.patch(MODULE_PATH + '.disable_user') def test_pre_save_user_inactivation(self, disable_user): """ Test a user set inactive has disable_member called @@ -72,7 +64,7 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = disable_user.call_args self.assertEqual(self.member, args[0]) - @mock.patch('allianceauth.services.signals.disable_user') + @mock.patch(MODULE_PATH + '.disable_user') def test_disable_services_on_loss_of_main_character(self, disable_user): """ Test a user set inactive has disable_member called @@ -84,8 +76,8 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = disable_user.call_args self.assertEqual(self.member, args[0]) - @mock.patch('allianceauth.services.signals.transaction') - @mock.patch('allianceauth.services.signals.ServicesHook') + @mock.patch(MODULE_PATH + '.transaction') + @mock.patch(MODULE_PATH + '.ServicesHook') def test_m2m_changed_group_permissions(self, services_hook, transaction): from django.contrib.contenttypes.models import ContentType svc = mock.Mock() @@ -116,8 +108,8 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = svc.validate_user.call_args self.assertEqual(self.member, args[0]) - @mock.patch('allianceauth.services.signals.transaction') - @mock.patch('allianceauth.services.signals.ServicesHook') + @mock.patch(MODULE_PATH + '.transaction') + @mock.patch(MODULE_PATH + '.ServicesHook') def test_m2m_changed_user_permissions(self, services_hook, transaction): from django.contrib.contenttypes.models import ContentType svc = mock.Mock() @@ -145,8 +137,8 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = svc.validate_user.call_args self.assertEqual(self.member, args[0]) - @mock.patch('allianceauth.services.signals.transaction') - @mock.patch('allianceauth.services.signals.ServicesHook') + @mock.patch(MODULE_PATH + '.transaction') + @mock.patch(MODULE_PATH + '.ServicesHook') def test_m2m_changed_user_state_permissions(self, services_hook, transaction): from django.contrib.contenttypes.models import ContentType svc = mock.Mock() @@ -180,7 +172,7 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = svc.validate_user.call_args self.assertEqual(self.member, args[0]) - @mock.patch('allianceauth.services.signals.ServicesHook') + @mock.patch(MODULE_PATH + '.ServicesHook') def test_state_changed_services_validation_and_groups_update(self, services_hook): """Test a user changing state has service accounts validated and groups updated """ @@ -206,8 +198,7 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = svc.update_groups.call_args self.assertEqual(self.member, args[0]) - - @mock.patch('allianceauth.services.signals.ServicesHook') + @mock.patch(MODULE_PATH + '.ServicesHook') def test_state_changed_services_validation_and_groups_update_1(self, services_hook): """Test a user changing main has service accounts validated and sync updated """ @@ -238,7 +229,7 @@ class ServicesSignalsTestCase(TestCase): args, kwargs = svc.sync_nickname.call_args self.assertEqual(self.member, args[0]) - @mock.patch('allianceauth.services.signals.ServicesHook') + @mock.patch(MODULE_PATH + '.ServicesHook') def test_state_changed_services_validation_and_groups_update_2(self, services_hook): """Test a user changing main has service does not have accounts validated and sync updated if the new main is equal to the old main @@ -260,3 +251,71 @@ class ServicesSignalsTestCase(TestCase): self.assertFalse(services_hook.get_services.called) self.assertFalse(svc.validate_user.called) self.assertFalse(svc.sync_nickname.called) + + +@mock.patch( + "allianceauth.services.modules.mumble.auth_hooks.MumbleService.update_groups" +) +@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True) +class TestUserGroupBulkUpdate(TransactionTestCase): + def test_should_run_user_service_check_when_group_added_to_user( + self, mock_update_groups + ): + # given + user = AuthUtils.create_user("Bruce Wayne") + AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001) + group = Group.objects.create(name="Group") + mock_update_groups.reset_mock() + # when + user.groups.add(group) + # then + users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list} + self.assertSetEqual(users_updated, {user}) + + def test_should_run_user_service_check_when_multiple_groups_are_added_to_user( + self, mock_update_groups + ): + # given + user = AuthUtils.create_user("Bruce Wayne") + AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001) + group_1 = Group.objects.create(name="Group 1") + group_2 = Group.objects.create(name="Group 2") + mock_update_groups.reset_mock() + # when + user.groups.add(group_1, group_2) + # then + users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list} + self.assertSetEqual(users_updated, {user}) + + def test_should_run_user_service_check_when_user_added_to_group( + self, mock_update_groups + ): + # given + user = AuthUtils.create_user("Bruce Wayne") + AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001) + group = Group.objects.create(name="Group") + mock_update_groups.reset_mock() + # when + group.user_set.add(user) + # then + users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list} + self.assertSetEqual(users_updated, {user}) + + def test_should_run_user_service_check_when_multiple_users_are_added_to_group( + self, mock_update_groups + ): + # given + user_1 = AuthUtils.create_user("Bruce Wayne") + AuthUtils.add_main_character_2(user_1, "Bruce Wayne", 1001) + user_2 = AuthUtils.create_user("Peter Parker") + AuthUtils.add_main_character_2(user_2, "Peter Parker", 1002) + user_3 = AuthUtils.create_user("Lex Luthor") + AuthUtils.add_main_character_2(user_3, "Lex Luthor", 1011) + group = Group.objects.create(name="Group") + user_1.groups.add(group) + mock_update_groups.reset_mock() + # when + group.user_set.add(user_2, user_3) + # then + users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list} + self.assertSetEqual(users_updated, {user_2, user_3}) diff --git a/allianceauth/services/tests/test_tasks.py b/allianceauth/services/tests/test_tasks.py index 35d9329e..06257a1f 100644 --- a/allianceauth/services/tests/test_tasks.py +++ b/allianceauth/services/tests/test_tasks.py @@ -3,32 +3,50 @@ from unittest import mock from celery_once import AlreadyQueued from django.core.cache import cache -from django.test import TestCase +from django.test import override_settings, TestCase from allianceauth.tests.auth_utils import AuthUtils -from allianceauth.services.tasks import validate_services +from allianceauth.services.tasks import validate_services, update_groups_for_user from ..tasks import DjangoBackend +@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True) class ServicesTasksTestCase(TestCase): def setUp(self): self.member = AuthUtils.create_user('auth_member') @mock.patch('allianceauth.services.tasks.ServicesHook') def test_validate_services(self, services_hook): + # given svc = mock.Mock() svc.validate_user.return_value = None - services_hook.get_services.return_value = [svc] - + # when validate_services.delay(self.member.pk) - + # then self.assertTrue(services_hook.get_services.called) self.assertTrue(svc.validate_user.called) - args, kwargs = svc.validate_user.call_args + args, _ = svc.validate_user.call_args self.assertEqual(self.member, args[0]) # Assert correct user is passed to service hook function + @mock.patch('allianceauth.services.tasks.ServicesHook') + def test_update_groups_for_user(self, services_hook): + # given + svc = mock.Mock() + svc.validate_user.return_value = None + services_hook.get_services.return_value = [svc] + # when + update_groups_for_user.delay(self.member.pk) + # then + self.assertTrue(services_hook.get_services.called) + self.assertTrue(svc.validate_user.called) + args, _ = svc.validate_user.call_args + self.assertEqual(self.member, args[0]) # Assert correct user + self.assertTrue(svc.update_groups.called) + args, _ = svc.update_groups.call_args + self.assertEqual(self.member, args[0]) # Assert correct user + class TestDjangoBackend(TestCase):