diff --git a/allianceauth/services/modules/teamspeak3/admin.py b/allianceauth/services/modules/teamspeak3/admin.py index a8b614d8..dbba1cb1 100644 --- a/allianceauth/services/modules/teamspeak3/admin.py +++ b/allianceauth/services/modules/teamspeak3/admin.py @@ -1,7 +1,8 @@ from django.contrib import admin - -from .models import AuthTS, Teamspeak3User, StateGroup +from django.contrib.auth.models import Group +from .models import AuthTS, Teamspeak3User, StateGroup, TSgroup from ...admin import ServicesUserAdmin +from allianceauth.groupmanagement.models import ReservedGroupName @admin.register(Teamspeak3User) @@ -25,6 +26,16 @@ class AuthTSgroupAdmin(admin.ModelAdmin): fields = ('auth_group', 'ts_group') filter_horizontal = ('ts_group',) + def formfield_for_foreignkey(self, db_field, request, **kwargs): + if db_field.name == 'auth_group': + kwargs['queryset'] = Group.objects.exclude(name__in=ReservedGroupName.objects.values_list('name', flat=True)) + return super().formfield_for_foreignkey(db_field, request, **kwargs) + + def formfield_for_manytomany(self, db_field, request, **kwargs): + if db_field.name == 'ts_group': + kwargs['queryset'] = TSgroup.objects.exclude(ts_group_name__in=ReservedGroupName.objects.values_list('name', flat=True)) + return super().formfield_for_manytomany(db_field, request, **kwargs) + def _ts_group(self, obj): return [x for x in obj.ts_group.all().order_by('ts_group_id')] diff --git a/allianceauth/services/modules/teamspeak3/manager.py b/allianceauth/services/modules/teamspeak3/manager.py index 9f7eea6c..9cad7cae 100755 --- a/allianceauth/services/modules/teamspeak3/manager.py +++ b/allianceauth/services/modules/teamspeak3/manager.py @@ -157,32 +157,25 @@ class Teamspeak3Manager: logger.info(f"Removed user id {uid} from group id {groupid} on TS3 server.") def _sync_ts_group_db(self): - logger.debug("_sync_ts_group_db function called.") try: remote_groups = self._group_list() - local_groups = TSgroup.objects.all() - logger.debug("Comparing remote groups to TSgroup objects: %s" % local_groups) - for key in remote_groups: - logger.debug(f"Typecasting remote_group value at position {key} to int: {remote_groups[key]}") - remote_groups[key] = int(remote_groups[key]) + managed_groups = {g:remote_groups[g] for g in remote_groups if g in set(remote_groups.keys()) - set(ReservedGroupName.objects.values_list('name', flat=True))} + remove = TSgroup.objects.exclude(ts_group_id__in=managed_groups.values()) + + if remove: + logger.debug(f"Deleting {remove.count()} TSgroup models: not found on server, or reserved name.") + remove.delete() + + add = {g:managed_groups[g] for g in managed_groups if managed_groups[g] in set(managed_groups.values()) - set(TSgroup.objects.values_list("ts_group_id", flat=True))} + if add: + logger.debug(f"Adding {len(add)} new TSgroup models.") + models = [TSgroup(ts_group_name=name, ts_group_id=add[name]) for name in add] + TSgroup.objects.bulk_create(models) - for group in local_groups: - logger.debug("Checking local group %s" % group) - if group.ts_group_id not in remote_groups.values(): - logger.debug( - f"Local group id {group.ts_group_id} not found on server. Deleting model {group}") - TSgroup.objects.filter(ts_group_id=group.ts_group_id).delete() - for key in remote_groups: - g = TSgroup(ts_group_id=remote_groups[key], ts_group_name=key) - q = TSgroup.objects.filter(ts_group_id=g.ts_group_id) - if not q: - logger.debug("Local group does not exist for TS group {}. Creating TSgroup model {}".format( - remote_groups[key], g)) - g.save() except TeamspeakError as e: - logger.error("Error occured while syncing TS group db: %s" % str(e)) + logger.error(f"Error occurred while syncing TS group db: {str(e)}") except: - logger.exception("An unhandled exception has occured while syncing TS groups.") + logger.exception("An unhandled exception has occurred while syncing TS groups.") def add_user(self, user, fmt_name): username_clean = self.__santatize_username(fmt_name[:30]) diff --git a/allianceauth/services/modules/teamspeak3/tests.py b/allianceauth/services/modules/teamspeak3/tests.py index 4abc68e4..bfed1906 100644 --- a/allianceauth/services/modules/teamspeak3/tests.py +++ b/allianceauth/services/modules/teamspeak3/tests.py @@ -5,16 +5,17 @@ from django import urls from django.contrib.auth.models import User, Group, Permission from django.core.exceptions import ObjectDoesNotExist from django.db.models import signals +from django.contrib.admin import AdminSite from allianceauth.tests.auth_utils import AuthUtils from .auth_hooks import Teamspeak3Service from .models import Teamspeak3User, AuthTS, TSgroup, StateGroup from .tasks import Teamspeak3Tasks from .signals import m2m_changed_authts_group, post_save_authts, post_delete_authts +from .admin import AuthTSgroupAdmin from .manager import Teamspeak3Manager from .util.ts3 import TeamspeakError -from allianceauth.authentication.models import State from allianceauth.groupmanagement.models import ReservedGroupName MODULE_PATH = 'allianceauth.services.modules.teamspeak3' @@ -316,9 +317,9 @@ class Teamspeak3SignalsTestCase(TestCase): class Teamspeak3ManagerTestCase(TestCase): - - def setUp(self): - self.reserved = ReservedGroupName.objects.create(name='reserved', reason='tests', created_by='Bob, praise be!') + @classmethod + def setUpTestData(cls): + cls.reserved = ReservedGroupName.objects.create(name='reserved', reason='tests', created_by='Bob, praise be!') @staticmethod def my_side_effect(*args, **kwargs): @@ -338,8 +339,8 @@ class Teamspeak3ManagerTestCase(TestCase): manager._server = server # create test data - user = User.objects.create_user("dummy") - user.profile.state = State.objects.filter(name="Member").first() + user = AuthUtils.create_user("dummy") + AuthUtils.assign_state(user, AuthUtils.get_member_state()) # perform test manager.add_user(user, "Dummy User") @@ -348,8 +349,7 @@ class Teamspeak3ManagerTestCase(TestCase): @mock.patch.object(Teamspeak3Manager, '_user_group_list') @mock.patch.object(Teamspeak3Manager, '_add_user_to_group') @mock.patch.object(Teamspeak3Manager, '_remove_user_from_group') - @mock.patch.object(Teamspeak3Manager, 'server') - def test_update_groups_add(self, server, remove, add, groups, userid): + def test_update_groups_add(self, remove, add, groups, userid): """Add to one group""" userid.return_value = 1 groups.return_value = {'test': 1} @@ -363,8 +363,7 @@ class Teamspeak3ManagerTestCase(TestCase): @mock.patch.object(Teamspeak3Manager, '_user_group_list') @mock.patch.object(Teamspeak3Manager, '_add_user_to_group') @mock.patch.object(Teamspeak3Manager, '_remove_user_from_group') - @mock.patch.object(Teamspeak3Manager, 'server') - def test_update_groups_remove(self, server, remove, add, groups, userid): + def test_update_groups_remove(self, remove, add, groups, userid): """Remove from one group""" userid.return_value = 1 groups.return_value = {'test': 1, 'dummy': 2} @@ -378,8 +377,7 @@ class Teamspeak3ManagerTestCase(TestCase): @mock.patch.object(Teamspeak3Manager, '_user_group_list') @mock.patch.object(Teamspeak3Manager, '_add_user_to_group') @mock.patch.object(Teamspeak3Manager, '_remove_user_from_group') - @mock.patch.object(Teamspeak3Manager, 'server') - def test_update_groups_remove_reserved(self, server, remove, add, groups, userid): + def test_update_groups_remove_reserved(self, remove, add, groups, userid): """Remove from one group, but do not touch reserved group""" userid.return_value = 1 groups.return_value = {'test': 1, 'dummy': 2, self.reserved.name: 3} @@ -388,3 +386,71 @@ class Teamspeak3ManagerTestCase(TestCase): self.assertEqual(add.call_count, 0) self.assertEqual(remove.call_count, 1) self.assertEqual(remove.call_args[0][1], 2) + + @mock.patch.object(Teamspeak3Manager, '_group_list') + def test_sync_group_db_create(self, group_list): + """Populate the list of all TSgroups""" + group_list.return_value = {'allowed':1, 'also allowed': 2} + Teamspeak3Manager()._sync_ts_group_db() + self.assertEqual(TSgroup.objects.all().count(), 2) + + @mock.patch.object(Teamspeak3Manager, '_group_list') + def test_sync_group_db_delete(self, group_list): + """Populate the list of all TSgroups, and delete one which no longer exists""" + TSgroup.objects.create(ts_group_name='deleted', ts_group_id=3) + group_list.return_value = {'allowed': 1, 'also allowed': 2} + Teamspeak3Manager()._sync_ts_group_db() + self.assertEqual(TSgroup.objects.all().count(), 2) + self.assertFalse(TSgroup.objects.filter(ts_group_name='deleted').exists()) + + @mock.patch.object(Teamspeak3Manager, '_group_list') + def test_sync_group_db_dont_create_reserved(self, group_list): + """Populate the list of all TSgroups, ignoring a reserved group name""" + group_list.return_value = {'allowed': 1, 'reserved': 4} + Teamspeak3Manager()._sync_ts_group_db() + self.assertEqual(TSgroup.objects.all().count(), 1) + self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists()) + + @mock.patch.object(Teamspeak3Manager, '_group_list') + def test_sync_group_db_delete_reserved(self, group_list): + """Populate the list of all TSgroups, deleting the TSgroup model for one which has become reserved""" + TSgroup.objects.create(ts_group_name='reserved', ts_group_id=4) + group_list.return_value = {'allowed': 1, 'reserved': 4} + Teamspeak3Manager()._sync_ts_group_db() + self.assertEqual(TSgroup.objects.all().count(), 1) + self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists()) + + +class MockRequest: + pass + + +class MockSuperUser: + def has_perm(self, perm, obj=None): + return True + + +request = MockRequest() +request.user = MockSuperUser() + + +class Teamspeak3AdminTestCase(TestCase): + @classmethod + def setUpTestData(cls): + cls.site = AdminSite() + cls.admin = AuthTSgroupAdmin(AuthTS, cls.site) + cls.group = Group.objects.create(name='test') + cls.ts_group = TSgroup.objects.create(ts_group_name='test') + + def test_field_queryset_no_reserved_names(self): + """Ensure all groups are listed when no reserved names""" + form = self.admin.get_form(request) + self.assertEqual(form.base_fields['auth_group']._get_queryset().count(), 1) + self.assertEqual(form.base_fields['ts_group']._get_queryset().count(), 1) + + def test_field_queryset_reserved_names(self): + """Ensure reserved group names are filtered out""" + ReservedGroupName.objects.bulk_create([ReservedGroupName(name='test', reason='tests', created_by='Bob')]) + form = self.admin.get_form(request) + self.assertEqual(form.base_fields['auth_group']._get_queryset().count(), 0) + self.assertEqual(form.base_fields['ts_group']._get_queryset().count(), 0)