Fix: Service group updates broken when adding users to groups

This commit is contained in:
Erik Kalkoken 2022-06-18 02:41:23 +00:00 committed by Ariel Rin
parent 8b2527f408
commit 84ad571aa4
4 changed files with 154 additions and 53 deletions

View File

@ -1,4 +1,5 @@
import logging import logging
from functools import partial
from django.contrib.auth.models import User, Group, Permission from django.contrib.auth.models import User, Group, Permission
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
@ -8,7 +9,7 @@ from django.db.models.signals import pre_delete
from django.db.models.signals import pre_save from django.db.models.signals import pre_save
from django.dispatch import receiver from django.dispatch import receiver
from .hooks import ServicesHook from .hooks import ServicesHook
from .tasks import disable_user from .tasks import disable_user, update_groups_for_user
from allianceauth.authentication.models import State, UserProfile from allianceauth.authentication.models import State, UserProfile
from allianceauth.authentication.signals import state_changed from allianceauth.authentication.signals import state_changed
@ -19,21 +20,27 @@ logger = logging.getLogger(__name__)
@receiver(m2m_changed, sender=User.groups.through) @receiver(m2m_changed, sender=User.groups.through)
def m2m_changed_user_groups(sender, instance, action, *args, **kwargs): def m2m_changed_user_groups(sender, instance, action, *args, **kwargs):
logger.debug(f"Received m2m_changed from {instance} groups with action {action}") logger.debug(
"%s: Received m2m_changed from groups with action %s", instance, action
def trigger_service_group_update(): )
logger.debug("Triggering service group update for %s" % instance) if instance.pk and (
# Iterate through Service hooks action == "post_add" or action == "post_remove" or action == "post_clear"
for svc in ServicesHook.get_services(): ):
try: if isinstance(instance, User):
svc.validate_user(instance) logger.debug(
svc.update_groups(instance) "Waiting for commit to trigger service group update for %s", instance
except: )
logger.exception(f'Exception running update_groups for services module {svc} on user {instance}') transaction.on_commit(partial(update_groups_for_user.delay, instance.pk))
elif (
if instance.pk and (action == "post_add" or action == "post_remove" or action == "post_clear"): isinstance(instance, Group)
logger.debug("Waiting for commit to trigger service group update for %s" % instance) and kwargs.get("model") is User
transaction.on_commit(trigger_service_group_update) and "pk_set" in kwargs
):
for user_pk in kwargs["pk_set"]:
logger.debug(
"%s: Waiting for commit to trigger service group update for user", user_pk
)
transaction.on_commit(partial(update_groups_for_user.delay, user_pk))
@receiver(m2m_changed, sender=User.user_permissions.through) @receiver(m2m_changed, sender=User.user_permissions.through)

View File

@ -47,3 +47,20 @@ def disable_user(user):
for svc in ServicesHook.get_services(): for svc in ServicesHook.get_services():
if svc.service_active_for_user(user): if svc.service_active_for_user(user):
svc.delete_user(user) svc.delete_user(user)
@shared_task
def update_groups_for_user(user_pk: int) -> None:
"""Update groups for all services registered to a user."""
user = User.objects.get(pk=user_pk)
logger.debug("%s: Triggering service group update for user", user)
for svc in ServicesHook.get_services():
try:
svc.validate_user(user)
svc.update_groups(user)
except Exception:
logger.exception(
'Exception running update_groups for services module %s on user %s',
svc,
user
)

View File

@ -1,7 +1,7 @@
from copy import deepcopy from copy import deepcopy
from unittest import mock from unittest import mock
from django.test import TestCase from django.test import override_settings, TestCase, TransactionTestCase
from django.contrib.auth.models import Group, Permission from django.contrib.auth.models import Group, Permission
from allianceauth.authentication.models import State from allianceauth.authentication.models import State
@ -9,6 +9,9 @@ from allianceauth.eveonline.models import EveCharacter
from allianceauth.tests.auth_utils import AuthUtils from allianceauth.tests.auth_utils import AuthUtils
MODULE_PATH = 'allianceauth.services.signals'
class ServicesSignalsTestCase(TestCase): class ServicesSignalsTestCase(TestCase):
def setUp(self): def setUp(self):
self.member = AuthUtils.create_user('auth_member', disconnect_signals=True) self.member = AuthUtils.create_user('auth_member', disconnect_signals=True)
@ -17,17 +20,12 @@ class ServicesSignalsTestCase(TestCase):
) )
self.none_user = AuthUtils.create_user('none_user', disconnect_signals=True) self.none_user = AuthUtils.create_user('none_user', disconnect_signals=True)
@mock.patch('allianceauth.services.signals.transaction') @mock.patch(MODULE_PATH + '.transaction', spec=True)
@mock.patch('allianceauth.services.signals.ServicesHook') @mock.patch(MODULE_PATH + '.update_groups_for_user', spec=True)
def test_m2m_changed_user_groups(self, services_hook, transaction): def test_m2m_changed_user_groups(self, update_groups_for_user, transaction):
""" """
Test that update_groups hook function is called on user groups change Test that update_groups hook function is called on user groups change
""" """
svc = mock.Mock()
svc.update_groups.return_value = None
svc.validate_user.return_value = None
services_hook.get_services.return_value = [svc]
# Overload transaction.on_commit so everything happens synchronously # Overload transaction.on_commit so everything happens synchronously
transaction.on_commit = lambda fn: fn() transaction.on_commit = lambda fn: fn()
@ -39,17 +37,11 @@ class ServicesSignalsTestCase(TestCase):
self.member.save() self.member.save()
# Assert # Assert
self.assertTrue(services_hook.get_services.called) self.assertTrue(update_groups_for_user.delay.called)
args, _ = update_groups_for_user.delay.call_args
self.assertEqual(self.member.pk, args[0])
self.assertTrue(svc.update_groups.called) @mock.patch(MODULE_PATH + '.disable_user')
args, kwargs = svc.update_groups.call_args
self.assertEqual(self.member, args[0])
self.assertTrue(svc.validate_user.called)
args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.disable_user')
def test_pre_delete_user(self, disable_user): def test_pre_delete_user(self, disable_user):
""" """
Test that disable_member is called when a user is deleted Test that disable_member is called when a user is deleted
@ -60,7 +52,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = disable_user.call_args args, kwargs = disable_user.call_args
self.assertEqual(self.none_user, args[0]) self.assertEqual(self.none_user, args[0])
@mock.patch('allianceauth.services.signals.disable_user') @mock.patch(MODULE_PATH + '.disable_user')
def test_pre_save_user_inactivation(self, disable_user): def test_pre_save_user_inactivation(self, disable_user):
""" """
Test a user set inactive has disable_member called Test a user set inactive has disable_member called
@ -72,7 +64,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = disable_user.call_args args, kwargs = disable_user.call_args
self.assertEqual(self.member, args[0]) self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.disable_user') @mock.patch(MODULE_PATH + '.disable_user')
def test_disable_services_on_loss_of_main_character(self, disable_user): def test_disable_services_on_loss_of_main_character(self, disable_user):
""" """
Test a user set inactive has disable_member called Test a user set inactive has disable_member called
@ -84,8 +76,8 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = disable_user.call_args args, kwargs = disable_user.call_args
self.assertEqual(self.member, args[0]) self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.transaction') @mock.patch(MODULE_PATH + '.transaction')
@mock.patch('allianceauth.services.signals.ServicesHook') @mock.patch(MODULE_PATH + '.ServicesHook')
def test_m2m_changed_group_permissions(self, services_hook, transaction): def test_m2m_changed_group_permissions(self, services_hook, transaction):
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
svc = mock.Mock() svc = mock.Mock()
@ -116,8 +108,8 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.validate_user.call_args args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0]) self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.transaction') @mock.patch(MODULE_PATH + '.transaction')
@mock.patch('allianceauth.services.signals.ServicesHook') @mock.patch(MODULE_PATH + '.ServicesHook')
def test_m2m_changed_user_permissions(self, services_hook, transaction): def test_m2m_changed_user_permissions(self, services_hook, transaction):
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
svc = mock.Mock() svc = mock.Mock()
@ -145,8 +137,8 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.validate_user.call_args args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0]) self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.transaction') @mock.patch(MODULE_PATH + '.transaction')
@mock.patch('allianceauth.services.signals.ServicesHook') @mock.patch(MODULE_PATH + '.ServicesHook')
def test_m2m_changed_user_state_permissions(self, services_hook, transaction): def test_m2m_changed_user_state_permissions(self, services_hook, transaction):
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
svc = mock.Mock() svc = mock.Mock()
@ -180,7 +172,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.validate_user.call_args args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0]) self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.ServicesHook') @mock.patch(MODULE_PATH + '.ServicesHook')
def test_state_changed_services_validation_and_groups_update(self, services_hook): def test_state_changed_services_validation_and_groups_update(self, services_hook):
"""Test a user changing state has service accounts validated and groups updated """Test a user changing state has service accounts validated and groups updated
""" """
@ -206,8 +198,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.update_groups.call_args args, kwargs = svc.update_groups.call_args
self.assertEqual(self.member, args[0]) self.assertEqual(self.member, args[0])
@mock.patch(MODULE_PATH + '.ServicesHook')
@mock.patch('allianceauth.services.signals.ServicesHook')
def test_state_changed_services_validation_and_groups_update_1(self, services_hook): def test_state_changed_services_validation_and_groups_update_1(self, services_hook):
"""Test a user changing main has service accounts validated and sync updated """Test a user changing main has service accounts validated and sync updated
""" """
@ -238,7 +229,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.sync_nickname.call_args args, kwargs = svc.sync_nickname.call_args
self.assertEqual(self.member, args[0]) self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.ServicesHook') @mock.patch(MODULE_PATH + '.ServicesHook')
def test_state_changed_services_validation_and_groups_update_2(self, services_hook): def test_state_changed_services_validation_and_groups_update_2(self, services_hook):
"""Test a user changing main has service does not have accounts validated """Test a user changing main has service does not have accounts validated
and sync updated if the new main is equal to the old main and sync updated if the new main is equal to the old main
@ -260,3 +251,71 @@ class ServicesSignalsTestCase(TestCase):
self.assertFalse(services_hook.get_services.called) self.assertFalse(services_hook.get_services.called)
self.assertFalse(svc.validate_user.called) self.assertFalse(svc.validate_user.called)
self.assertFalse(svc.sync_nickname.called) self.assertFalse(svc.sync_nickname.called)
@mock.patch(
"allianceauth.services.modules.mumble.auth_hooks.MumbleService.update_groups"
)
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
class TestUserGroupBulkUpdate(TransactionTestCase):
def test_should_run_user_service_check_when_group_added_to_user(
self, mock_update_groups
):
# given
user = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
group = Group.objects.create(name="Group")
mock_update_groups.reset_mock()
# when
user.groups.add(group)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user})
def test_should_run_user_service_check_when_multiple_groups_are_added_to_user(
self, mock_update_groups
):
# given
user = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
group_1 = Group.objects.create(name="Group 1")
group_2 = Group.objects.create(name="Group 2")
mock_update_groups.reset_mock()
# when
user.groups.add(group_1, group_2)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user})
def test_should_run_user_service_check_when_user_added_to_group(
self, mock_update_groups
):
# given
user = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
group = Group.objects.create(name="Group")
mock_update_groups.reset_mock()
# when
group.user_set.add(user)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user})
def test_should_run_user_service_check_when_multiple_users_are_added_to_group(
self, mock_update_groups
):
# given
user_1 = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user_1, "Bruce Wayne", 1001)
user_2 = AuthUtils.create_user("Peter Parker")
AuthUtils.add_main_character_2(user_2, "Peter Parker", 1002)
user_3 = AuthUtils.create_user("Lex Luthor")
AuthUtils.add_main_character_2(user_3, "Lex Luthor", 1011)
group = Group.objects.create(name="Group")
user_1.groups.add(group)
mock_update_groups.reset_mock()
# when
group.user_set.add(user_2, user_3)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user_2, user_3})

View File

@ -3,32 +3,50 @@ from unittest import mock
from celery_once import AlreadyQueued from celery_once import AlreadyQueued
from django.core.cache import cache from django.core.cache import cache
from django.test import TestCase from django.test import override_settings, TestCase
from allianceauth.tests.auth_utils import AuthUtils from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.services.tasks import validate_services from allianceauth.services.tasks import validate_services, update_groups_for_user
from ..tasks import DjangoBackend from ..tasks import DjangoBackend
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
class ServicesTasksTestCase(TestCase): class ServicesTasksTestCase(TestCase):
def setUp(self): def setUp(self):
self.member = AuthUtils.create_user('auth_member') self.member = AuthUtils.create_user('auth_member')
@mock.patch('allianceauth.services.tasks.ServicesHook') @mock.patch('allianceauth.services.tasks.ServicesHook')
def test_validate_services(self, services_hook): def test_validate_services(self, services_hook):
# given
svc = mock.Mock() svc = mock.Mock()
svc.validate_user.return_value = None svc.validate_user.return_value = None
services_hook.get_services.return_value = [svc] services_hook.get_services.return_value = [svc]
# when
validate_services.delay(self.member.pk) validate_services.delay(self.member.pk)
# then
self.assertTrue(services_hook.get_services.called) self.assertTrue(services_hook.get_services.called)
self.assertTrue(svc.validate_user.called) self.assertTrue(svc.validate_user.called)
args, kwargs = svc.validate_user.call_args args, _ = svc.validate_user.call_args
self.assertEqual(self.member, args[0]) # Assert correct user is passed to service hook function self.assertEqual(self.member, args[0]) # Assert correct user is passed to service hook function
@mock.patch('allianceauth.services.tasks.ServicesHook')
def test_update_groups_for_user(self, services_hook):
# given
svc = mock.Mock()
svc.validate_user.return_value = None
services_hook.get_services.return_value = [svc]
# when
update_groups_for_user.delay(self.member.pk)
# then
self.assertTrue(services_hook.get_services.called)
self.assertTrue(svc.validate_user.called)
args, _ = svc.validate_user.call_args
self.assertEqual(self.member, args[0]) # Assert correct user
self.assertTrue(svc.update_groups.called)
args, _ = svc.update_groups.call_args
self.assertEqual(self.member, args[0]) # Assert correct user
class TestDjangoBackend(TestCase): class TestDjangoBackend(TestCase):