Prevent assignment of reserved groups to AuthTSgroup mappings.

Implemented in TS group updates to prevent their creation / delete once
reserved, and the admin site for when a reserved group name is created
but before the TS group sync occurs.
This commit is contained in:
Adarnof 2021-12-08 23:41:10 -05:00
parent d11832913d
commit 72740b9e4d
3 changed files with 105 additions and 35 deletions

View File

@ -1,7 +1,8 @@
from django.contrib import admin from django.contrib import admin
from django.contrib.auth.models import Group
from .models import AuthTS, Teamspeak3User, StateGroup from .models import AuthTS, Teamspeak3User, StateGroup, TSgroup
from ...admin import ServicesUserAdmin from ...admin import ServicesUserAdmin
from allianceauth.groupmanagement.models import ReservedGroupName
@admin.register(Teamspeak3User) @admin.register(Teamspeak3User)
@ -25,6 +26,16 @@ class AuthTSgroupAdmin(admin.ModelAdmin):
fields = ('auth_group', 'ts_group') fields = ('auth_group', 'ts_group')
filter_horizontal = ('ts_group',) filter_horizontal = ('ts_group',)
def formfield_for_foreignkey(self, db_field, request, **kwargs):
if db_field.name == 'auth_group':
kwargs['queryset'] = Group.objects.exclude(name__in=ReservedGroupName.objects.values_list('name', flat=True))
return super().formfield_for_foreignkey(db_field, request, **kwargs)
def formfield_for_manytomany(self, db_field, request, **kwargs):
if db_field.name == 'ts_group':
kwargs['queryset'] = TSgroup.objects.exclude(ts_group_name__in=ReservedGroupName.objects.values_list('name', flat=True))
return super().formfield_for_manytomany(db_field, request, **kwargs)
def _ts_group(self, obj): def _ts_group(self, obj):
return [x for x in obj.ts_group.all().order_by('ts_group_id')] return [x for x in obj.ts_group.all().order_by('ts_group_id')]

View File

@ -157,32 +157,25 @@ class Teamspeak3Manager:
logger.info(f"Removed user id {uid} from group id {groupid} on TS3 server.") logger.info(f"Removed user id {uid} from group id {groupid} on TS3 server.")
def _sync_ts_group_db(self): def _sync_ts_group_db(self):
logger.debug("_sync_ts_group_db function called.")
try: try:
remote_groups = self._group_list() remote_groups = self._group_list()
local_groups = TSgroup.objects.all() managed_groups = {g:remote_groups[g] for g in remote_groups if g in set(remote_groups.keys()) - set(ReservedGroupName.objects.values_list('name', flat=True))}
logger.debug("Comparing remote groups to TSgroup objects: %s" % local_groups) remove = TSgroup.objects.exclude(ts_group_id__in=managed_groups.values())
for key in remote_groups:
logger.debug(f"Typecasting remote_group value at position {key} to int: {remote_groups[key]}") if remove:
remote_groups[key] = int(remote_groups[key]) logger.debug(f"Deleting {remove.count()} TSgroup models: not found on server, or reserved name.")
remove.delete()
add = {g:managed_groups[g] for g in managed_groups if managed_groups[g] in set(managed_groups.values()) - set(TSgroup.objects.values_list("ts_group_id", flat=True))}
if add:
logger.debug(f"Adding {len(add)} new TSgroup models.")
models = [TSgroup(ts_group_name=name, ts_group_id=add[name]) for name in add]
TSgroup.objects.bulk_create(models)
for group in local_groups:
logger.debug("Checking local group %s" % group)
if group.ts_group_id not in remote_groups.values():
logger.debug(
f"Local group id {group.ts_group_id} not found on server. Deleting model {group}")
TSgroup.objects.filter(ts_group_id=group.ts_group_id).delete()
for key in remote_groups:
g = TSgroup(ts_group_id=remote_groups[key], ts_group_name=key)
q = TSgroup.objects.filter(ts_group_id=g.ts_group_id)
if not q:
logger.debug("Local group does not exist for TS group {}. Creating TSgroup model {}".format(
remote_groups[key], g))
g.save()
except TeamspeakError as e: except TeamspeakError as e:
logger.error("Error occured while syncing TS group db: %s" % str(e)) logger.error(f"Error occurred while syncing TS group db: {str(e)}")
except: except:
logger.exception("An unhandled exception has occured while syncing TS groups.") logger.exception("An unhandled exception has occurred while syncing TS groups.")
def add_user(self, user, fmt_name): def add_user(self, user, fmt_name):
username_clean = self.__santatize_username(fmt_name[:30]) username_clean = self.__santatize_username(fmt_name[:30])

View File

@ -5,16 +5,17 @@ from django import urls
from django.contrib.auth.models import User, Group, Permission from django.contrib.auth.models import User, Group, Permission
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals from django.db.models import signals
from django.contrib.admin import AdminSite
from allianceauth.tests.auth_utils import AuthUtils from allianceauth.tests.auth_utils import AuthUtils
from .auth_hooks import Teamspeak3Service from .auth_hooks import Teamspeak3Service
from .models import Teamspeak3User, AuthTS, TSgroup, StateGroup from .models import Teamspeak3User, AuthTS, TSgroup, StateGroup
from .tasks import Teamspeak3Tasks from .tasks import Teamspeak3Tasks
from .signals import m2m_changed_authts_group, post_save_authts, post_delete_authts from .signals import m2m_changed_authts_group, post_save_authts, post_delete_authts
from .admin import AuthTSgroupAdmin
from .manager import Teamspeak3Manager from .manager import Teamspeak3Manager
from .util.ts3 import TeamspeakError from .util.ts3 import TeamspeakError
from allianceauth.authentication.models import State
from allianceauth.groupmanagement.models import ReservedGroupName from allianceauth.groupmanagement.models import ReservedGroupName
MODULE_PATH = 'allianceauth.services.modules.teamspeak3' MODULE_PATH = 'allianceauth.services.modules.teamspeak3'
@ -316,9 +317,9 @@ class Teamspeak3SignalsTestCase(TestCase):
class Teamspeak3ManagerTestCase(TestCase): class Teamspeak3ManagerTestCase(TestCase):
@classmethod
def setUp(self): def setUpTestData(cls):
self.reserved = ReservedGroupName.objects.create(name='reserved', reason='tests', created_by='Bob, praise be!') cls.reserved = ReservedGroupName.objects.create(name='reserved', reason='tests', created_by='Bob, praise be!')
@staticmethod @staticmethod
def my_side_effect(*args, **kwargs): def my_side_effect(*args, **kwargs):
@ -338,8 +339,8 @@ class Teamspeak3ManagerTestCase(TestCase):
manager._server = server manager._server = server
# create test data # create test data
user = User.objects.create_user("dummy") user = AuthUtils.create_user("dummy")
user.profile.state = State.objects.filter(name="Member").first() AuthUtils.assign_state(user, AuthUtils.get_member_state())
# perform test # perform test
manager.add_user(user, "Dummy User") manager.add_user(user, "Dummy User")
@ -348,8 +349,7 @@ class Teamspeak3ManagerTestCase(TestCase):
@mock.patch.object(Teamspeak3Manager, '_user_group_list') @mock.patch.object(Teamspeak3Manager, '_user_group_list')
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group') @mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group') @mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
@mock.patch.object(Teamspeak3Manager, 'server') def test_update_groups_add(self, remove, add, groups, userid):
def test_update_groups_add(self, server, remove, add, groups, userid):
"""Add to one group""" """Add to one group"""
userid.return_value = 1 userid.return_value = 1
groups.return_value = {'test': 1} groups.return_value = {'test': 1}
@ -363,8 +363,7 @@ class Teamspeak3ManagerTestCase(TestCase):
@mock.patch.object(Teamspeak3Manager, '_user_group_list') @mock.patch.object(Teamspeak3Manager, '_user_group_list')
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group') @mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group') @mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
@mock.patch.object(Teamspeak3Manager, 'server') def test_update_groups_remove(self, remove, add, groups, userid):
def test_update_groups_remove(self, server, remove, add, groups, userid):
"""Remove from one group""" """Remove from one group"""
userid.return_value = 1 userid.return_value = 1
groups.return_value = {'test': 1, 'dummy': 2} groups.return_value = {'test': 1, 'dummy': 2}
@ -378,8 +377,7 @@ class Teamspeak3ManagerTestCase(TestCase):
@mock.patch.object(Teamspeak3Manager, '_user_group_list') @mock.patch.object(Teamspeak3Manager, '_user_group_list')
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group') @mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group') @mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
@mock.patch.object(Teamspeak3Manager, 'server') def test_update_groups_remove_reserved(self, remove, add, groups, userid):
def test_update_groups_remove_reserved(self, server, remove, add, groups, userid):
"""Remove from one group, but do not touch reserved group""" """Remove from one group, but do not touch reserved group"""
userid.return_value = 1 userid.return_value = 1
groups.return_value = {'test': 1, 'dummy': 2, self.reserved.name: 3} groups.return_value = {'test': 1, 'dummy': 2, self.reserved.name: 3}
@ -388,3 +386,71 @@ class Teamspeak3ManagerTestCase(TestCase):
self.assertEqual(add.call_count, 0) self.assertEqual(add.call_count, 0)
self.assertEqual(remove.call_count, 1) self.assertEqual(remove.call_count, 1)
self.assertEqual(remove.call_args[0][1], 2) self.assertEqual(remove.call_args[0][1], 2)
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_create(self, group_list):
"""Populate the list of all TSgroups"""
group_list.return_value = {'allowed':1, 'also allowed': 2}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 2)
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_delete(self, group_list):
"""Populate the list of all TSgroups, and delete one which no longer exists"""
TSgroup.objects.create(ts_group_name='deleted', ts_group_id=3)
group_list.return_value = {'allowed': 1, 'also allowed': 2}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 2)
self.assertFalse(TSgroup.objects.filter(ts_group_name='deleted').exists())
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_dont_create_reserved(self, group_list):
"""Populate the list of all TSgroups, ignoring a reserved group name"""
group_list.return_value = {'allowed': 1, 'reserved': 4}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 1)
self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists())
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_delete_reserved(self, group_list):
"""Populate the list of all TSgroups, deleting the TSgroup model for one which has become reserved"""
TSgroup.objects.create(ts_group_name='reserved', ts_group_id=4)
group_list.return_value = {'allowed': 1, 'reserved': 4}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 1)
self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists())
class MockRequest:
pass
class MockSuperUser:
def has_perm(self, perm, obj=None):
return True
request = MockRequest()
request.user = MockSuperUser()
class Teamspeak3AdminTestCase(TestCase):
@classmethod
def setUpTestData(cls):
cls.site = AdminSite()
cls.admin = AuthTSgroupAdmin(AuthTS, cls.site)
cls.group = Group.objects.create(name='test')
cls.ts_group = TSgroup.objects.create(ts_group_name='test')
def test_field_queryset_no_reserved_names(self):
"""Ensure all groups are listed when no reserved names"""
form = self.admin.get_form(request)
self.assertEqual(form.base_fields['auth_group']._get_queryset().count(), 1)
self.assertEqual(form.base_fields['ts_group']._get_queryset().count(), 1)
def test_field_queryset_reserved_names(self):
"""Ensure reserved group names are filtered out"""
ReservedGroupName.objects.bulk_create([ReservedGroupName(name='test', reason='tests', created_by='Bob')])
form = self.admin.get_form(request)
self.assertEqual(form.base_fields['auth_group']._get_queryset().count(), 0)
self.assertEqual(form.base_fields['ts_group']._get_queryset().count(), 0)