mirror of
https://gitlab.com/allianceauth/allianceauth.git
synced 2026-02-04 14:16:21 +01:00
Compare commits
86 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
297f98f046 | ||
|
|
27dad05927 | ||
|
|
697e9dd772 | ||
|
|
65f2efc890 | ||
|
|
def30900b4 | ||
|
|
d7fabccddd | ||
|
|
45289e1d17 | ||
|
|
7b9bf08aa3 | ||
|
|
def6431052 | ||
|
|
22a270aedb | ||
|
|
c930f7bbeb | ||
|
|
64ee273953 | ||
|
|
3706a1aedf | ||
|
|
47f1b77320 | ||
|
|
8dec242a93 | ||
|
|
2ff200c566 | ||
|
|
091a2637ea | ||
|
|
113977b19f | ||
|
|
8f39b50b6d | ||
|
|
95b309c358 | ||
|
|
cf3df3b715 | ||
|
|
d815028c4d | ||
|
|
ac5570abe2 | ||
|
|
84ad571aa4 | ||
|
|
38e7705ae7 | ||
|
|
0b6af014fa | ||
|
|
2401f2299d | ||
|
|
919768c8bb | ||
|
|
24db21463b | ||
|
|
1e029af83a | ||
|
|
2b31be789d | ||
|
|
bf1b4bb549 | ||
|
|
dd42b807f0 | ||
|
|
542fbafd98 | ||
|
|
37b9f5c882 | ||
|
|
5bde9a6952 | ||
|
|
23ad9d02d3 | ||
|
|
f99878cc29 | ||
|
|
e64431b06c | ||
|
|
0b2993c1c3 | ||
|
|
75bccf1b0f | ||
|
|
945bc92898 | ||
|
|
ec7d14a839 | ||
|
|
dd1a368ff6 | ||
|
|
54085617dc | ||
|
|
8cdc5af453 | ||
|
|
da93940e13 | ||
|
|
f53b43d9dc | ||
|
|
497a167ca7 | ||
|
|
852c5a3037 | ||
|
|
90f6777a7a | ||
|
|
a8d890abaf | ||
|
|
79379b444c | ||
|
|
ace1de5c68 | ||
|
|
5d6128e9ea | ||
|
|
131cc5ed0a | ||
|
|
9297bed43f | ||
|
|
b2fddc683a | ||
|
|
9af634d16a | ||
|
|
a68163caa3 | ||
|
|
00770fd034 | ||
|
|
01164777ed | ||
|
|
00f5e3e1e0 | ||
|
|
8b2527f408 | ||
|
|
b7500e4e4e | ||
|
|
4f4bd0c419 | ||
|
|
8ae4e02012 | ||
|
|
cc9a07197d | ||
|
|
f18dd1029b | ||
|
|
fd8d43571a | ||
|
|
13e88492f1 | ||
|
|
38df580a56 | ||
|
|
ba39318313 | ||
|
|
d8c6035405 | ||
|
|
2ef3da916b | ||
|
|
d32d8b26ce | ||
|
|
f348b1a34c | ||
|
|
86aaa3edda | ||
|
|
26017056c7 | ||
|
|
e39a3c072b | ||
|
|
827291dda4 | ||
|
|
8de2c3bfcb | ||
|
|
6688f73565 | ||
|
|
72740b9e4d | ||
|
|
d11832913d | ||
|
|
dfe62db8ee |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -76,3 +76,4 @@ celerybeat-schedule
|
||||
.flake8
|
||||
.pylintrc
|
||||
Makefile
|
||||
.isort.cfg
|
||||
|
||||
@@ -54,7 +54,9 @@ test-3.7-core:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.8-core:
|
||||
<<: *only-default
|
||||
@@ -64,7 +66,9 @@ test-3.8-core:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.9-core:
|
||||
<<: *only-default
|
||||
@@ -74,7 +78,9 @@ test-3.9-core:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.10-core:
|
||||
<<: *only-default
|
||||
@@ -84,7 +90,9 @@ test-3.10-core:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.11-core:
|
||||
<<: *only-default
|
||||
@@ -94,7 +102,9 @@ test-3.11-core:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
allow_failure: true
|
||||
|
||||
test-3.7-all:
|
||||
@@ -105,7 +115,9 @@ test-3.7-all:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.8-all:
|
||||
<<: *only-default
|
||||
@@ -115,7 +127,9 @@ test-3.8-all:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.9-all:
|
||||
<<: *only-default
|
||||
@@ -125,7 +139,9 @@ test-3.9-all:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.10-all:
|
||||
<<: *only-default
|
||||
@@ -135,7 +151,9 @@ test-3.10-all:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
|
||||
test-3.11-all:
|
||||
<<: *only-default
|
||||
@@ -145,9 +163,17 @@ test-3.11-all:
|
||||
artifacts:
|
||||
when: always
|
||||
reports:
|
||||
cobertura: coverage.xml
|
||||
coverage_report:
|
||||
coverage_format: cobertura
|
||||
path: coverage.xml
|
||||
allow_failure: true
|
||||
|
||||
test-docs:
|
||||
<<: *only-default
|
||||
image: python:3.9-bullseye
|
||||
script:
|
||||
- tox -e docs
|
||||
|
||||
deploy_production:
|
||||
stage: deploy
|
||||
image: python:3.10-bullseye
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# This will make sure the app is always imported when
|
||||
# Django starts so that shared_task will use this app.
|
||||
|
||||
__version__ = '2.9.4'
|
||||
__version__ = '2.15.1'
|
||||
__title__ = 'Alliance Auth'
|
||||
__url__ = 'https://gitlab.com/allianceauth/allianceauth'
|
||||
NAME = f'{__title__} v{__version__}'
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from .models import AnalyticsTokens, AnalyticsIdentifier
|
||||
from .tasks import send_ga_tracking_web_view
|
||||
@@ -10,6 +11,8 @@ import re
|
||||
class AnalyticsMiddleware(MiddlewareMixin):
|
||||
def process_response(self, request, response):
|
||||
"""Django Middleware: Process Page Views and creates Analytics Celery Tasks"""
|
||||
if getattr(settings, "ANALYTICS_DISABLED", False):
|
||||
return response
|
||||
analyticstokens = AnalyticsTokens.objects.all()
|
||||
client_id = AnalyticsIdentifier.objects.get(id=1).identifier.hex
|
||||
try:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# Generated by Django 3.1.13 on 2021-10-15 05:02
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
|
||||
# We can't import the Person model directly as it may be a newer
|
||||
# version than this migration expects. We use the historical version.
|
||||
# Add /admin/ and /user_notifications_count/ path to ignore
|
||||
|
||||
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
|
||||
admin = AnalyticsPath.objects.create(ignore_path=r"^\/admin\/.*")
|
||||
@@ -17,8 +17,19 @@ def modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
|
||||
|
||||
|
||||
def undo_modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
|
||||
# nothing should need to migrate away here?
|
||||
return True
|
||||
#
|
||||
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
|
||||
Tokens = apps.get_model('analytics', 'AnalyticsTokens')
|
||||
|
||||
token = Tokens.objects.get(token="UA-186249766-2")
|
||||
try:
|
||||
admin = AnalyticsPath.objects.get(ignore_path=r"^\/admin\/.*", analyticstokens=token)
|
||||
user_notifications_count = AnalyticsPath.objects.get(ignore_path=r"^\/user_notifications_count\/.*", analyticstokens=token)
|
||||
admin.delete()
|
||||
user_notifications_count.delete()
|
||||
except ObjectDoesNotExist:
|
||||
# Its fine if it doesnt exist, we just dont want them building up when re-migrating
|
||||
pass
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
40
allianceauth/analytics/migrations/0006_more_ignore_paths.py
Normal file
40
allianceauth/analytics/migrations/0006_more_ignore_paths.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Generated by Django 3.2.8 on 2021-10-19 01:47
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
|
||||
# Add the /account/activate path to ignore
|
||||
|
||||
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
|
||||
account_activate = AnalyticsPath.objects.create(ignore_path=r"^\/account\/activate\/.*")
|
||||
|
||||
Tokens = apps.get_model('analytics', 'AnalyticsTokens')
|
||||
token = Tokens.objects.get(token="UA-186249766-2")
|
||||
token.ignore_paths.add(account_activate)
|
||||
|
||||
|
||||
def undo_modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
|
||||
#
|
||||
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
|
||||
Tokens = apps.get_model('analytics', 'AnalyticsTokens')
|
||||
|
||||
token = Tokens.objects.get(token="UA-186249766-2")
|
||||
|
||||
try:
|
||||
account_activate = AnalyticsPath.objects.get(ignore_path=r"^\/account\/activate\/.*", analyticstokens=token)
|
||||
account_activate.delete()
|
||||
except ObjectDoesNotExist:
|
||||
# Its fine if it doesnt exist, we just dont want them building up when re-migrating
|
||||
pass
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('analytics', '0005_alter_analyticspath_ignore_path'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(modify_aa_team_token_add_page_ignore_paths, undo_modify_aa_team_token_add_page_ignore_paths)
|
||||
]
|
||||
@@ -1,7 +1,8 @@
|
||||
from allianceauth.analytics.tasks import analytics_event
|
||||
from celery.signals import task_failure, task_success
|
||||
|
||||
import logging
|
||||
from celery.signals import task_failure, task_success
|
||||
from django.conf import settings
|
||||
from allianceauth.analytics.tasks import analytics_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -11,6 +12,8 @@ def process_failure_signal(
|
||||
sender, task_id, signal,
|
||||
args, kwargs, einfo, **kw):
|
||||
logger.debug("Celery task_failure signal %s" % sender.__class__.__name__)
|
||||
if getattr(settings, "ANALYTICS_DISABLED", False):
|
||||
return
|
||||
|
||||
category = sender.__module__
|
||||
|
||||
@@ -30,6 +33,8 @@ def process_failure_signal(
|
||||
@task_success.connect
|
||||
def celery_success_signal(sender, result=None, **kw):
|
||||
logger.debug("Celery task_success signal %s" % sender.__class__.__name__)
|
||||
if getattr(settings, "ANALYTICS_DISABLED", False):
|
||||
return
|
||||
|
||||
category = sender.__module__
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ if getattr(settings, "ANALYTICS_ENABLE_DEBUG", False) and settings.DEBUG:
|
||||
# Force sending of analytics data during in a debug/test environemt
|
||||
# Usefull for developers working on this feature.
|
||||
logger.warning(
|
||||
"You have 'ANALYTICS_ENABLE_DEBUG' Enabled! "
|
||||
"This debug instance will send analytics data!")
|
||||
"You have 'ANALYTICS_ENABLE_DEBUG' Enabled! "
|
||||
"This debug instance will send analytics data!")
|
||||
DEBUG_URL = COLLECTION_URL
|
||||
|
||||
ANALYTICS_URL = COLLECTION_URL
|
||||
@@ -40,13 +40,12 @@ def analytics_event(category: str,
|
||||
Send a Google Analytics Event for each token stored
|
||||
Includes check for if its enabled/disabled
|
||||
|
||||
Parameters
|
||||
-------
|
||||
`category` (str): Celery Namespace
|
||||
`action` (str): Task Name
|
||||
`label` (str): Optional, Task Success/Exception
|
||||
`value` (int): Optional, If bulk, Query size, can be a binary True/False
|
||||
`event_type` (str): Optional, Celery or Stats only, Default to Celery
|
||||
Args:
|
||||
`category` (str): Celery Namespace
|
||||
`action` (str): Task Name
|
||||
`label` (str): Optional, Task Success/Exception
|
||||
`value` (int): Optional, If bulk, Query size, can be a binary True/False
|
||||
`event_type` (str): Optional, Celery or Stats only, Default to Celery
|
||||
"""
|
||||
analyticstokens = AnalyticsTokens.objects.all()
|
||||
client_id = AnalyticsIdentifier.objects.get(id=1).identifier.hex
|
||||
@@ -60,20 +59,21 @@ def analytics_event(category: str,
|
||||
|
||||
if allowed is True:
|
||||
tracking_id = token.token
|
||||
send_ga_tracking_celery_event.s(tracking_id=tracking_id,
|
||||
client_id=client_id,
|
||||
category=category,
|
||||
action=action,
|
||||
label=label,
|
||||
value=value).\
|
||||
apply_async(priority=9)
|
||||
send_ga_tracking_celery_event.s(
|
||||
tracking_id=tracking_id,
|
||||
client_id=client_id,
|
||||
category=category,
|
||||
action=action,
|
||||
label=label,
|
||||
value=value).apply_async(priority=9)
|
||||
|
||||
|
||||
@shared_task()
|
||||
def analytics_daily_stats():
|
||||
"""Celery Task: Do not call directly
|
||||
|
||||
Gathers a series of daily statistics and sends analytics events containing them"""
|
||||
Gathers a series of daily statistics and sends analytics events containing them
|
||||
"""
|
||||
users = install_stat_users()
|
||||
tokens = install_stat_tokens()
|
||||
addons = install_stat_addons()
|
||||
|
||||
109
allianceauth/analytics/tests/test_integration.py
Normal file
109
allianceauth/analytics/tests/test_integration.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import requests_mock
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
from allianceauth.analytics.tasks import ANALYTICS_URL
|
||||
from allianceauth.eveonline.tasks import update_character
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
@requests_mock.mock()
|
||||
class TestAnalyticsForViews(NoSocketsTestCase):
|
||||
@override_settings(ANALYTICS_DISABLED=False)
|
||||
def test_should_run_analytics(self, requests_mocker):
|
||||
# given
|
||||
requests_mocker.post(ANALYTICS_URL)
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
self.client.force_login(user)
|
||||
# when
|
||||
response = self.client.get("/dashboard/")
|
||||
# then
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(requests_mocker.called)
|
||||
|
||||
@override_settings(ANALYTICS_DISABLED=True)
|
||||
def test_should_not_run_analytics(self, requests_mocker):
|
||||
# given
|
||||
requests_mocker.post(ANALYTICS_URL)
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
self.client.force_login(user)
|
||||
# when
|
||||
response = self.client.get("/dashboard/")
|
||||
# then
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertFalse(requests_mocker.called)
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
@requests_mock.mock()
|
||||
class TestAnalyticsForTasks(NoSocketsTestCase):
|
||||
@override_settings(ANALYTICS_DISABLED=False)
|
||||
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
|
||||
def test_should_run_analytics_for_successful_task(
|
||||
self, requests_mocker, mock_update_character
|
||||
):
|
||||
# given
|
||||
requests_mocker.post(ANALYTICS_URL)
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
|
||||
# when
|
||||
update_character.delay(character.character_id)
|
||||
# then
|
||||
self.assertTrue(mock_update_character.called)
|
||||
self.assertTrue(requests_mocker.called)
|
||||
payload = parse_qs(requests_mocker.last_request.text)
|
||||
self.assertListEqual(payload["el"], ["Success"])
|
||||
|
||||
@override_settings(ANALYTICS_DISABLED=True)
|
||||
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
|
||||
def test_should_not_run_analytics_for_successful_task(
|
||||
self, requests_mocker, mock_update_character
|
||||
):
|
||||
# given
|
||||
requests_mocker.post(ANALYTICS_URL)
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
|
||||
# when
|
||||
update_character.delay(character.character_id)
|
||||
# then
|
||||
self.assertTrue(mock_update_character.called)
|
||||
self.assertFalse(requests_mocker.called)
|
||||
|
||||
@override_settings(ANALYTICS_DISABLED=False)
|
||||
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
|
||||
def test_should_run_analytics_for_failed_task(
|
||||
self, requests_mocker, mock_update_character
|
||||
):
|
||||
# given
|
||||
requests_mocker.post(ANALYTICS_URL)
|
||||
mock_update_character.side_effect = RuntimeError
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
|
||||
# when
|
||||
update_character.delay(character.character_id)
|
||||
# then
|
||||
self.assertTrue(mock_update_character.called)
|
||||
self.assertTrue(requests_mocker.called)
|
||||
payload = parse_qs(requests_mocker.last_request.text)
|
||||
self.assertNotEqual(payload["el"], ["Success"])
|
||||
|
||||
@override_settings(ANALYTICS_DISABLED=True)
|
||||
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
|
||||
def test_should_not_run_analytics_for_failed_task(
|
||||
self, requests_mocker, mock_update_character
|
||||
):
|
||||
# given
|
||||
requests_mocker.post(ANALYTICS_URL)
|
||||
mock_update_character.side_effect = RuntimeError
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
|
||||
# when
|
||||
update_character.delay(character.character_id)
|
||||
# then
|
||||
self.assertTrue(mock_update_character.called)
|
||||
self.assertFalse(requests_mocker.called)
|
||||
@@ -1,12 +1,22 @@
|
||||
import requests_mock
|
||||
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from allianceauth.analytics.tasks import (
|
||||
analytics_event,
|
||||
send_ga_tracking_celery_event,
|
||||
send_ga_tracking_web_view)
|
||||
from django.test.testcases import TestCase
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
|
||||
class TestAnalyticsTasks(TestCase):
|
||||
def test_analytics_event(self):
|
||||
GOOGLE_ANALYTICS_DEBUG_URL = 'https://www.google-analytics.com/debug/collect'
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
@requests_mock.Mocker()
|
||||
class TestAnalyticsTasks(NoSocketsTestCase):
|
||||
def test_analytics_event(self, requests_mocker):
|
||||
requests_mocker.register_uri('POST', GOOGLE_ANALYTICS_DEBUG_URL)
|
||||
analytics_event(
|
||||
category='allianceauth.analytics',
|
||||
action='send_tests',
|
||||
@@ -14,15 +24,19 @@ class TestAnalyticsTasks(TestCase):
|
||||
value=1,
|
||||
event_type='Stats')
|
||||
|
||||
def test_send_ga_tracking_web_view_sent(self):
|
||||
# This test sends if the event SENDS to google
|
||||
# Not if it was successful
|
||||
def test_send_ga_tracking_web_view_sent(self, requests_mocker):
|
||||
"""This test sends if the event SENDS to google.
|
||||
Not if it was successful.
|
||||
"""
|
||||
# given
|
||||
requests_mocker.register_uri('POST', GOOGLE_ANALYTICS_DEBUG_URL)
|
||||
tracking_id = 'UA-186249766-2'
|
||||
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
|
||||
page = '/index/'
|
||||
title = 'Hello World'
|
||||
locale = 'en'
|
||||
useragent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
|
||||
# when
|
||||
response = send_ga_tracking_web_view(
|
||||
tracking_id,
|
||||
client_id,
|
||||
@@ -30,15 +44,23 @@ class TestAnalyticsTasks(TestCase):
|
||||
title,
|
||||
locale,
|
||||
useragent)
|
||||
# then
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_send_ga_tracking_web_view_success(self):
|
||||
def test_send_ga_tracking_web_view_success(self, requests_mocker):
|
||||
# given
|
||||
requests_mocker.register_uri(
|
||||
'POST',
|
||||
GOOGLE_ANALYTICS_DEBUG_URL,
|
||||
json={"hitParsingResult":[{'valid': True}]}
|
||||
)
|
||||
tracking_id = 'UA-186249766-2'
|
||||
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
|
||||
page = '/index/'
|
||||
title = 'Hello World'
|
||||
locale = 'en'
|
||||
useragent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
|
||||
# when
|
||||
json_response = send_ga_tracking_web_view(
|
||||
tracking_id,
|
||||
client_id,
|
||||
@@ -46,15 +68,42 @@ class TestAnalyticsTasks(TestCase):
|
||||
title,
|
||||
locale,
|
||||
useragent).json()
|
||||
# then
|
||||
self.assertTrue(json_response["hitParsingResult"][0]["valid"])
|
||||
|
||||
def test_send_ga_tracking_web_view_invalid_token(self):
|
||||
def test_send_ga_tracking_web_view_invalid_token(self, requests_mocker):
|
||||
# given
|
||||
requests_mocker.register_uri(
|
||||
'POST',
|
||||
GOOGLE_ANALYTICS_DEBUG_URL,
|
||||
json={
|
||||
"hitParsingResult":[
|
||||
{
|
||||
'valid': False,
|
||||
'parserMessage': [
|
||||
{
|
||||
'messageType': 'INFO',
|
||||
'description': 'IP Address from this hit was anonymized to 1.132.110.0.',
|
||||
'messageCode': 'VALUE_MODIFIED'
|
||||
},
|
||||
{
|
||||
'messageType': 'ERROR',
|
||||
'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.",
|
||||
'messageCode': 'VALUE_INVALID', 'parameter': 'tid'
|
||||
}
|
||||
],
|
||||
'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
tracking_id = 'UA-IntentionallyBadTrackingID-2'
|
||||
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
|
||||
page = '/index/'
|
||||
title = 'Hello World'
|
||||
locale = 'en'
|
||||
useragent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
|
||||
# when
|
||||
json_response = send_ga_tracking_web_view(
|
||||
tracking_id,
|
||||
client_id,
|
||||
@@ -62,18 +111,25 @@ class TestAnalyticsTasks(TestCase):
|
||||
title,
|
||||
locale,
|
||||
useragent).json()
|
||||
# then
|
||||
self.assertFalse(json_response["hitParsingResult"][0]["valid"])
|
||||
self.assertEqual(json_response["hitParsingResult"][0]["parserMessage"][1]["description"], "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.")
|
||||
self.assertEqual(
|
||||
json_response["hitParsingResult"][0]["parserMessage"][1]["description"],
|
||||
"The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details."
|
||||
)
|
||||
|
||||
# [{'valid': False, 'parserMessage': [{'messageType': 'INFO', 'description': 'IP Address from this hit was anonymized to 1.132.110.0.', 'messageCode': 'VALUE_MODIFIED'}, {'messageType': 'ERROR', 'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.", 'messageCode': 'VALUE_INVALID', 'parameter': 'tid'}], 'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'}]
|
||||
|
||||
def test_send_ga_tracking_celery_event_sent(self):
|
||||
def test_send_ga_tracking_celery_event_sent(self, requests_mocker):
|
||||
# given
|
||||
requests_mocker.register_uri('POST', GOOGLE_ANALYTICS_DEBUG_URL)
|
||||
tracking_id = 'UA-186249766-2'
|
||||
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
|
||||
category = 'test'
|
||||
action = 'test'
|
||||
label = 'test'
|
||||
value = '1'
|
||||
# when
|
||||
response = send_ga_tracking_celery_event(
|
||||
tracking_id,
|
||||
client_id,
|
||||
@@ -81,15 +137,23 @@ class TestAnalyticsTasks(TestCase):
|
||||
action,
|
||||
label,
|
||||
value)
|
||||
# then
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_send_ga_tracking_celery_event_success(self):
|
||||
def test_send_ga_tracking_celery_event_success(self, requests_mocker):
|
||||
# given
|
||||
requests_mocker.register_uri(
|
||||
'POST',
|
||||
GOOGLE_ANALYTICS_DEBUG_URL,
|
||||
json={"hitParsingResult":[{'valid': True}]}
|
||||
)
|
||||
tracking_id = 'UA-186249766-2'
|
||||
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
|
||||
category = 'test'
|
||||
action = 'test'
|
||||
label = 'test'
|
||||
value = '1'
|
||||
# when
|
||||
json_response = send_ga_tracking_celery_event(
|
||||
tracking_id,
|
||||
client_id,
|
||||
@@ -97,15 +161,42 @@ class TestAnalyticsTasks(TestCase):
|
||||
action,
|
||||
label,
|
||||
value).json()
|
||||
# then
|
||||
self.assertTrue(json_response["hitParsingResult"][0]["valid"])
|
||||
|
||||
def test_send_ga_tracking_celery_event_invalid_token(self):
|
||||
def test_send_ga_tracking_celery_event_invalid_token(self, requests_mocker):
|
||||
# given
|
||||
requests_mocker.register_uri(
|
||||
'POST',
|
||||
GOOGLE_ANALYTICS_DEBUG_URL,
|
||||
json={
|
||||
"hitParsingResult":[
|
||||
{
|
||||
'valid': False,
|
||||
'parserMessage': [
|
||||
{
|
||||
'messageType': 'INFO',
|
||||
'description': 'IP Address from this hit was anonymized to 1.132.110.0.',
|
||||
'messageCode': 'VALUE_MODIFIED'
|
||||
},
|
||||
{
|
||||
'messageType': 'ERROR',
|
||||
'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.",
|
||||
'messageCode': 'VALUE_INVALID', 'parameter': 'tid'
|
||||
}
|
||||
],
|
||||
'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
tracking_id = 'UA-IntentionallyBadTrackingID-2'
|
||||
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
|
||||
category = 'test'
|
||||
action = 'test'
|
||||
label = 'test'
|
||||
value = '1'
|
||||
# when
|
||||
json_response = send_ga_tracking_celery_event(
|
||||
tracking_id,
|
||||
client_id,
|
||||
@@ -113,7 +204,9 @@ class TestAnalyticsTasks(TestCase):
|
||||
action,
|
||||
label,
|
||||
value).json()
|
||||
# then
|
||||
self.assertFalse(json_response["hitParsingResult"][0]["valid"])
|
||||
self.assertEqual(json_response["hitParsingResult"][0]["parserMessage"][1]["description"], "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.")
|
||||
|
||||
# [{'valid': False, 'parserMessage': [{'messageType': 'INFO', 'description': 'IP Address from this hit was anonymized to 1.132.110.0.', 'messageCode': 'VALUE_MODIFIED'}, {'messageType': 'ERROR', 'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.", 'messageCode': 'VALUE_INVALID', 'parameter': 'tid'}], 'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'}]
|
||||
self.assertEqual(
|
||||
json_response["hitParsingResult"][0]["parserMessage"][1]["description"],
|
||||
"The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details."
|
||||
)
|
||||
|
||||
@@ -1,26 +1,44 @@
|
||||
from django.contrib import admin
|
||||
from django.contrib.auth.admin import UserAdmin as BaseUserAdmin
|
||||
from django.contrib.auth.models import User as BaseUser, \
|
||||
Permission as BasePermission, Group
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import Permission as BasePermission
|
||||
from django.contrib.auth.models import User as BaseUser
|
||||
from django.db.models import Count, Q
|
||||
from allianceauth.services.hooks import ServicesHook
|
||||
from django.db.models.signals import pre_save, post_save, pre_delete, \
|
||||
post_delete, m2m_changed
|
||||
from django.db.models.functions import Lower
|
||||
from django.db.models.signals import (
|
||||
m2m_changed,
|
||||
post_delete,
|
||||
post_save,
|
||||
pre_delete,
|
||||
pre_save
|
||||
)
|
||||
from django.dispatch import receiver
|
||||
from django.forms import ModelForm
|
||||
from django.utils.html import format_html
|
||||
from django.urls import reverse
|
||||
from django.utils.html import format_html
|
||||
from django.utils.text import slugify
|
||||
|
||||
from allianceauth.authentication.models import State, get_guest_state,\
|
||||
CharacterOwnership, UserProfile, OwnershipRecord
|
||||
from allianceauth.hooks import get_hooks
|
||||
from allianceauth.eveonline.models import EveCharacter, EveCorporationInfo,\
|
||||
EveAllianceInfo, EveFactionInfo
|
||||
from allianceauth.authentication.models import (
|
||||
CharacterOwnership,
|
||||
OwnershipRecord,
|
||||
State,
|
||||
UserProfile,
|
||||
get_guest_state
|
||||
)
|
||||
from allianceauth.eveonline.models import (
|
||||
EveAllianceInfo,
|
||||
EveCharacter,
|
||||
EveCorporationInfo,
|
||||
EveFactionInfo
|
||||
)
|
||||
from allianceauth.eveonline.tasks import update_character
|
||||
from .app_settings import AUTHENTICATION_ADMIN_USERS_MAX_GROUPS, \
|
||||
AUTHENTICATION_ADMIN_USERS_MAX_CHARS
|
||||
from allianceauth.hooks import get_hooks
|
||||
from allianceauth.services.hooks import ServicesHook
|
||||
|
||||
from .app_settings import (
|
||||
AUTHENTICATION_ADMIN_USERS_MAX_CHARS,
|
||||
AUTHENTICATION_ADMIN_USERS_MAX_GROUPS
|
||||
)
|
||||
from .forms import UserChangeForm, UserProfileForm
|
||||
|
||||
|
||||
def make_service_hooks_update_groups_action(service):
|
||||
@@ -59,19 +77,10 @@ def make_service_hooks_sync_nickname_action(service):
|
||||
return sync_nickname
|
||||
|
||||
|
||||
class QuerysetModelForm(ModelForm):
|
||||
# allows specifying FK querysets through kwarg
|
||||
def __init__(self, querysets=None, *args, **kwargs):
|
||||
querysets = querysets or {}
|
||||
super().__init__(*args, **kwargs)
|
||||
for field, qs in querysets.items():
|
||||
self.fields[field].queryset = qs
|
||||
|
||||
|
||||
class UserProfileInline(admin.StackedInline):
|
||||
model = UserProfile
|
||||
readonly_fields = ('state',)
|
||||
form = QuerysetModelForm
|
||||
form = UserProfileForm
|
||||
verbose_name = ''
|
||||
verbose_name_plural = 'Profile'
|
||||
|
||||
@@ -99,6 +108,7 @@ class UserProfileInline(admin.StackedInline):
|
||||
return False
|
||||
|
||||
|
||||
@admin.display(description="")
|
||||
def user_profile_pic(obj):
|
||||
"""profile pic column data for user objects
|
||||
|
||||
@@ -111,13 +121,10 @@ def user_profile_pic(obj):
|
||||
'<img src="{}" class="img-circle">',
|
||||
user_obj.profile.main_character.portrait_url(size=32)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
user_profile_pic.short_description = ''
|
||||
return None
|
||||
|
||||
|
||||
@admin.display(description="user / main", ordering="username")
|
||||
def user_username(obj):
|
||||
"""user column data for user objects
|
||||
|
||||
@@ -139,18 +146,17 @@ def user_username(obj):
|
||||
user_obj.username,
|
||||
user_obj.profile.main_character.character_name
|
||||
)
|
||||
else:
|
||||
return format_html(
|
||||
'<strong><a href="{}">{}</a></strong>',
|
||||
link,
|
||||
user_obj.username,
|
||||
)
|
||||
|
||||
|
||||
user_username.short_description = 'user / main'
|
||||
user_username.admin_order_field = 'username'
|
||||
return format_html(
|
||||
'<strong><a href="{}">{}</a></strong>',
|
||||
link,
|
||||
user_obj.username,
|
||||
)
|
||||
|
||||
|
||||
@admin.display(
|
||||
description="Corporation / Alliance (Main)",
|
||||
ordering="profile__main_character__corporation_name"
|
||||
)
|
||||
def user_main_organization(obj):
|
||||
"""main organization column data for user objects
|
||||
|
||||
@@ -159,21 +165,15 @@ def user_main_organization(obj):
|
||||
"""
|
||||
user_obj = obj.user if hasattr(obj, 'user') else obj
|
||||
if not user_obj.profile.main_character:
|
||||
result = ''
|
||||
else:
|
||||
result = user_obj.profile.main_character.corporation_name
|
||||
if user_obj.profile.main_character.alliance_id:
|
||||
result += f'<br>{user_obj.profile.main_character.alliance_name}'
|
||||
elif user_obj.profile.main_character.faction_name:
|
||||
result += f'<br>{user_obj.profile.main_character.faction_name}'
|
||||
return ''
|
||||
result = user_obj.profile.main_character.corporation_name
|
||||
if user_obj.profile.main_character.alliance_id:
|
||||
result += f'<br>{user_obj.profile.main_character.alliance_name}'
|
||||
elif user_obj.profile.main_character.faction_name:
|
||||
result += f'<br>{user_obj.profile.main_character.faction_name}'
|
||||
return format_html(result)
|
||||
|
||||
|
||||
user_main_organization.short_description = 'Corporation / Alliance (Main)'
|
||||
user_main_organization.admin_order_field = \
|
||||
'profile__main_character__corporation_name'
|
||||
|
||||
|
||||
class MainCorporationsFilter(admin.SimpleListFilter):
|
||||
"""Custom filter to filter on corporations from mains only
|
||||
|
||||
@@ -196,15 +196,13 @@ class MainCorporationsFilter(admin.SimpleListFilter):
|
||||
def queryset(self, request, qs):
|
||||
if self.value() is None:
|
||||
return qs.all()
|
||||
else:
|
||||
if qs.model == User:
|
||||
return qs.filter(
|
||||
profile__main_character__corporation_id=self.value()
|
||||
)
|
||||
else:
|
||||
return qs.filter(
|
||||
user__profile__main_character__corporation_id=self.value()
|
||||
)
|
||||
if qs.model == User:
|
||||
return qs.filter(
|
||||
profile__main_character__corporation_id=self.value()
|
||||
)
|
||||
return qs.filter(
|
||||
user__profile__main_character__corporation_id=self.value()
|
||||
)
|
||||
|
||||
|
||||
class MainAllianceFilter(admin.SimpleListFilter):
|
||||
@@ -217,12 +215,14 @@ class MainAllianceFilter(admin.SimpleListFilter):
|
||||
parameter_name = 'main_alliance_id__exact'
|
||||
|
||||
def lookups(self, request, model_admin):
|
||||
qs = EveCharacter.objects\
|
||||
.exclude(alliance_id=None)\
|
||||
.exclude(userprofile=None)\
|
||||
.values('alliance_id', 'alliance_name')\
|
||||
.distinct()\
|
||||
qs = (
|
||||
EveCharacter.objects
|
||||
.exclude(alliance_id=None)
|
||||
.exclude(userprofile=None)
|
||||
.values('alliance_id', 'alliance_name')
|
||||
.distinct()
|
||||
.order_by(Lower('alliance_name'))
|
||||
)
|
||||
return tuple(
|
||||
(x['alliance_id'], x['alliance_name']) for x in qs
|
||||
)
|
||||
@@ -230,13 +230,11 @@ class MainAllianceFilter(admin.SimpleListFilter):
|
||||
def queryset(self, request, qs):
|
||||
if self.value() is None:
|
||||
return qs.all()
|
||||
else:
|
||||
if qs.model == User:
|
||||
return qs.filter(profile__main_character__alliance_id=self.value())
|
||||
else:
|
||||
return qs.filter(
|
||||
user__profile__main_character__alliance_id=self.value()
|
||||
)
|
||||
if qs.model == User:
|
||||
return qs.filter(profile__main_character__alliance_id=self.value())
|
||||
return qs.filter(
|
||||
user__profile__main_character__alliance_id=self.value()
|
||||
)
|
||||
|
||||
|
||||
class MainFactionFilter(admin.SimpleListFilter):
|
||||
@@ -249,12 +247,14 @@ class MainFactionFilter(admin.SimpleListFilter):
|
||||
parameter_name = 'main_faction_id__exact'
|
||||
|
||||
def lookups(self, request, model_admin):
|
||||
qs = EveCharacter.objects\
|
||||
.exclude(faction_id=None)\
|
||||
.exclude(userprofile=None)\
|
||||
.values('faction_id', 'faction_name')\
|
||||
.distinct()\
|
||||
qs = (
|
||||
EveCharacter.objects
|
||||
.exclude(faction_id=None)
|
||||
.exclude(userprofile=None)
|
||||
.values('faction_id', 'faction_name')
|
||||
.distinct()
|
||||
.order_by(Lower('faction_name'))
|
||||
)
|
||||
return tuple(
|
||||
(x['faction_id'], x['faction_name']) for x in qs
|
||||
)
|
||||
@@ -262,15 +262,14 @@ class MainFactionFilter(admin.SimpleListFilter):
|
||||
def queryset(self, request, qs):
|
||||
if self.value() is None:
|
||||
return qs.all()
|
||||
else:
|
||||
if qs.model == User:
|
||||
return qs.filter(profile__main_character__faction_id=self.value())
|
||||
else:
|
||||
return qs.filter(
|
||||
user__profile__main_character__faction_id=self.value()
|
||||
)
|
||||
if qs.model == User:
|
||||
return qs.filter(profile__main_character__faction_id=self.value())
|
||||
return qs.filter(
|
||||
user__profile__main_character__faction_id=self.value()
|
||||
)
|
||||
|
||||
|
||||
@admin.display(description="Update main character model from ESI")
|
||||
def update_main_character_model(modeladmin, request, queryset):
|
||||
tasks_count = 0
|
||||
for obj in queryset:
|
||||
@@ -279,21 +278,48 @@ def update_main_character_model(modeladmin, request, queryset):
|
||||
tasks_count += 1
|
||||
|
||||
modeladmin.message_user(
|
||||
request,
|
||||
f'Update from ESI started for {tasks_count} characters'
|
||||
request, f'Update from ESI started for {tasks_count} characters'
|
||||
)
|
||||
|
||||
|
||||
update_main_character_model.short_description = \
|
||||
'Update main character model from ESI'
|
||||
|
||||
|
||||
class UserAdmin(BaseUserAdmin):
|
||||
"""Extending Django's UserAdmin model
|
||||
|
||||
Behavior of groups and characters columns can be configured via settings
|
||||
"""
|
||||
|
||||
inlines = BaseUserAdmin.inlines + [UserProfileInline]
|
||||
ordering = ('username', )
|
||||
list_select_related = ('profile__state', 'profile__main_character')
|
||||
show_full_result_count = True
|
||||
list_display = (
|
||||
user_profile_pic,
|
||||
user_username,
|
||||
'_state',
|
||||
'_groups',
|
||||
user_main_organization,
|
||||
'_characters',
|
||||
'is_active',
|
||||
'date_joined',
|
||||
'_role'
|
||||
)
|
||||
list_display_links = None
|
||||
list_filter = (
|
||||
'profile__state',
|
||||
'groups',
|
||||
MainCorporationsFilter,
|
||||
MainAllianceFilter,
|
||||
MainFactionFilter,
|
||||
'is_active',
|
||||
'date_joined',
|
||||
'is_staff',
|
||||
'is_superuser'
|
||||
)
|
||||
search_fields = ('username', 'character_ownerships__character__character_name')
|
||||
readonly_fields = ('date_joined', 'last_login')
|
||||
filter_horizontal = ('groups', 'user_permissions',)
|
||||
form = UserChangeForm
|
||||
|
||||
class Media:
|
||||
css = {
|
||||
"all": ("authentication/css/admin.css",)
|
||||
@@ -303,9 +329,21 @@ class UserAdmin(BaseUserAdmin):
|
||||
qs = super().get_queryset(request)
|
||||
return qs.prefetch_related("character_ownerships__character", "groups")
|
||||
|
||||
def get_actions(self, request):
|
||||
actions = super(BaseUserAdmin, self).get_actions(request)
|
||||
def get_form(self, request, obj=None, **kwargs):
|
||||
"""Inject current request into change form object."""
|
||||
|
||||
MyForm = super().get_form(request, obj, **kwargs)
|
||||
if obj:
|
||||
class MyFormInjected(MyForm):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
kwargs['request'] = request
|
||||
return MyForm(*args, **kwargs)
|
||||
|
||||
return MyFormInjected
|
||||
return MyForm
|
||||
|
||||
def get_actions(self, request):
|
||||
actions = super().get_actions(request)
|
||||
actions[update_main_character_model.__name__] = (
|
||||
update_main_character_model,
|
||||
update_main_character_model.__name__,
|
||||
@@ -349,38 +387,6 @@ class UserAdmin(BaseUserAdmin):
|
||||
)
|
||||
return result
|
||||
|
||||
inlines = BaseUserAdmin.inlines + [UserProfileInline]
|
||||
ordering = ('username', )
|
||||
list_select_related = ('profile__state', 'profile__main_character')
|
||||
show_full_result_count = True
|
||||
list_display = (
|
||||
user_profile_pic,
|
||||
user_username,
|
||||
'_state',
|
||||
'_groups',
|
||||
user_main_organization,
|
||||
'_characters',
|
||||
'is_active',
|
||||
'date_joined',
|
||||
'_role'
|
||||
)
|
||||
list_display_links = None
|
||||
list_filter = (
|
||||
'profile__state',
|
||||
'groups',
|
||||
MainCorporationsFilter,
|
||||
MainAllianceFilter,
|
||||
MainFactionFilter,
|
||||
'is_active',
|
||||
'date_joined',
|
||||
'is_staff',
|
||||
'is_superuser'
|
||||
)
|
||||
search_fields = (
|
||||
'username',
|
||||
'character_ownerships__character__character_name'
|
||||
)
|
||||
|
||||
def _characters(self, obj):
|
||||
character_ownerships = list(obj.character_ownerships.all())
|
||||
characters = [obj.character.character_name for obj in character_ownerships]
|
||||
@@ -389,22 +395,16 @@ class UserAdmin(BaseUserAdmin):
|
||||
AUTHENTICATION_ADMIN_USERS_MAX_CHARS
|
||||
)
|
||||
|
||||
_characters.short_description = 'characters'
|
||||
|
||||
@admin.display(ordering="profile__state")
|
||||
def _state(self, obj):
|
||||
return obj.profile.state.name
|
||||
|
||||
_state.short_description = 'state'
|
||||
_state.admin_order_field = 'profile__state'
|
||||
|
||||
def _groups(self, obj):
|
||||
my_groups = sorted(group.name for group in list(obj.groups.all()))
|
||||
return self._list_2_html_w_tooltips(
|
||||
my_groups, AUTHENTICATION_ADMIN_USERS_MAX_GROUPS
|
||||
)
|
||||
|
||||
_groups.short_description = 'groups'
|
||||
|
||||
def _role(self, obj):
|
||||
if obj.is_superuser:
|
||||
role = 'Superuser'
|
||||
@@ -414,8 +414,6 @@ class UserAdmin(BaseUserAdmin):
|
||||
role = 'User'
|
||||
return role
|
||||
|
||||
_role.short_description = 'role'
|
||||
|
||||
def has_change_permission(self, request, obj=None):
|
||||
return request.user.has_perm('auth.change_user')
|
||||
|
||||
@@ -425,12 +423,28 @@ class UserAdmin(BaseUserAdmin):
|
||||
def has_delete_permission(self, request, obj=None):
|
||||
return request.user.has_perm('auth.delete_user')
|
||||
|
||||
def get_object(self, *args , **kwargs):
|
||||
obj = super().get_object(*args , **kwargs)
|
||||
self.obj = obj # storing current object for use in formfield_for_manytomany
|
||||
return obj
|
||||
|
||||
def formfield_for_manytomany(self, db_field, request, **kwargs):
|
||||
"""overriding this formfield to have sorted lists in the form"""
|
||||
if db_field.name == "groups":
|
||||
kwargs["queryset"] = Group.objects.all().order_by(Lower('name'))
|
||||
groups_qs = Group.objects.filter(authgroup__states__isnull=True)
|
||||
obj_state = self.obj.profile.state
|
||||
if obj_state:
|
||||
matching_groups_qs = Group.objects.filter(authgroup__states=obj_state)
|
||||
groups_qs = groups_qs | matching_groups_qs
|
||||
kwargs["queryset"] = groups_qs.order_by(Lower("name"))
|
||||
return super().formfield_for_manytomany(db_field, request, **kwargs)
|
||||
|
||||
def get_readonly_fields(self, request, obj=None):
|
||||
if obj and not request.user.is_superuser:
|
||||
return self.readonly_fields + (
|
||||
"is_staff", "is_superuser", "user_permissions"
|
||||
)
|
||||
return self.readonly_fields
|
||||
|
||||
|
||||
@admin.register(State)
|
||||
class StateAdmin(admin.ModelAdmin):
|
||||
@@ -441,10 +455,9 @@ class StateAdmin(admin.ModelAdmin):
|
||||
qs = super().get_queryset(request)
|
||||
return qs.annotate(user_count=Count("userprofile__id"))
|
||||
|
||||
@admin.display(description="Users", ordering="user_count")
|
||||
def _user_count(self, obj):
|
||||
return obj.user_count
|
||||
_user_count.short_description = 'Users'
|
||||
_user_count.admin_order_field = 'user_count'
|
||||
|
||||
fieldsets = (
|
||||
(None, {
|
||||
@@ -500,13 +513,13 @@ class StateAdmin(admin.ModelAdmin):
|
||||
)
|
||||
return super().get_fieldsets(request, obj=obj)
|
||||
|
||||
def get_readonly_fields(self, request, obj=None):
|
||||
if not request.user.is_superuser:
|
||||
return self.readonly_fields + ("permissions",)
|
||||
return self.readonly_fields
|
||||
|
||||
|
||||
class BaseOwnershipAdmin(admin.ModelAdmin):
|
||||
class Media:
|
||||
css = {
|
||||
"all": ("authentication/css/admin.css",)
|
||||
}
|
||||
|
||||
list_select_related = (
|
||||
'user__profile__state', 'user__profile__main_character', 'character')
|
||||
list_display = (
|
||||
@@ -527,6 +540,11 @@ class BaseOwnershipAdmin(admin.ModelAdmin):
|
||||
MainAllianceFilter,
|
||||
)
|
||||
|
||||
class Media:
|
||||
css = {
|
||||
"all": ("authentication/css/admin.css",)
|
||||
}
|
||||
|
||||
def get_readonly_fields(self, request, obj=None):
|
||||
if obj and obj.pk:
|
||||
return 'owner_hash', 'character'
|
||||
|
||||
@@ -3,10 +3,14 @@ from django.core.checks import register, Tags
|
||||
|
||||
|
||||
class AuthenticationConfig(AppConfig):
|
||||
name = 'allianceauth.authentication'
|
||||
label = 'authentication'
|
||||
name = "allianceauth.authentication"
|
||||
label = "authentication"
|
||||
|
||||
def ready(self):
|
||||
super().ready()
|
||||
from allianceauth.authentication import checks, signals
|
||||
from allianceauth.authentication import checks, signals # noqa: F401
|
||||
from allianceauth.authentication.task_statistics import (
|
||||
signals as celery_signals,
|
||||
)
|
||||
|
||||
register(Tags.security)(checks.check_login_scopes_setting)
|
||||
celery_signals.reset_counters()
|
||||
|
||||
@@ -1,8 +1,66 @@
|
||||
from django import forms
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from django.contrib.auth.forms import UserChangeForm as BaseUserChangeForm
|
||||
from django.contrib.auth.models import Group
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.forms import ModelForm
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from allianceauth.authentication.models import User
|
||||
|
||||
|
||||
class RegistrationForm(forms.Form):
|
||||
email = forms.EmailField(label=_('Email'), max_length=254, required=True)
|
||||
|
||||
class _meta:
|
||||
model = User
|
||||
|
||||
|
||||
class UserProfileForm(ModelForm):
|
||||
"""Allows specifying FK querysets through kwarg"""
|
||||
|
||||
def __init__(self, querysets=None, *args, **kwargs):
|
||||
querysets = querysets or {}
|
||||
super().__init__(*args, **kwargs)
|
||||
for field, qs in querysets.items():
|
||||
self.fields[field].queryset = qs
|
||||
|
||||
|
||||
class UserChangeForm(BaseUserChangeForm):
|
||||
"""Add custom cleaning to UserChangeForm"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.request = kwargs.pop("request") # Inject current request into form object
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
if not self.request.user.is_superuser:
|
||||
if self.instance:
|
||||
current_restricted = set(
|
||||
self.instance.groups.filter(
|
||||
authgroup__restricted=True
|
||||
).values_list("pk", flat=True)
|
||||
)
|
||||
else:
|
||||
current_restricted = set()
|
||||
new_restricted = set(
|
||||
cleaned_data["groups"].filter(
|
||||
authgroup__restricted=True
|
||||
).values_list("pk", flat=True)
|
||||
)
|
||||
if current_restricted != new_restricted:
|
||||
restricted_removed = current_restricted - new_restricted
|
||||
restricted_added = new_restricted - current_restricted
|
||||
restricted_changed = restricted_removed | restricted_added
|
||||
restricted_names_qs = Group.objects.filter(
|
||||
pk__in=restricted_changed
|
||||
).values_list("name", flat=True)
|
||||
restricted_names = ",".join(list(restricted_names_qs))
|
||||
raise ValidationError(
|
||||
{
|
||||
"groups": _(
|
||||
"You are not allowed to add or remove these "
|
||||
"restricted groups: %s" % restricted_names
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
40
allianceauth/authentication/task_statistics/counters.py
Normal file
40
allianceauth/authentication/task_statistics/counters.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from collections import namedtuple
|
||||
import datetime as dt
|
||||
|
||||
from .event_series import EventSeries
|
||||
|
||||
|
||||
"""Global series for counting task events."""
|
||||
succeeded_tasks = EventSeries("SUCCEEDED_TASKS")
|
||||
retried_tasks = EventSeries("RETRIED_TASKS")
|
||||
failed_tasks = EventSeries("FAILED_TASKS")
|
||||
|
||||
|
||||
_TaskCounts = namedtuple(
|
||||
"_TaskCounts", ["succeeded", "retried", "failed", "total", "earliest_task", "hours"]
|
||||
)
|
||||
|
||||
|
||||
def dashboard_results(hours: int) -> _TaskCounts:
|
||||
"""Counts of all task events within the given timeframe."""
|
||||
|
||||
def earliest_if_exists(events: EventSeries, earliest: dt.datetime) -> list:
|
||||
my_earliest = events.first_event(earliest=earliest)
|
||||
return [my_earliest] if my_earliest else []
|
||||
|
||||
earliest = dt.datetime.utcnow() - dt.timedelta(hours=hours)
|
||||
earliest_events = list()
|
||||
succeeded_count = succeeded_tasks.count(earliest=earliest)
|
||||
earliest_events += earliest_if_exists(succeeded_tasks, earliest)
|
||||
retried_count = retried_tasks.count(earliest=earliest)
|
||||
earliest_events += earliest_if_exists(retried_tasks, earliest)
|
||||
failed_count = failed_tasks.count(earliest=earliest)
|
||||
earliest_events += earliest_if_exists(failed_tasks, earliest)
|
||||
return _TaskCounts(
|
||||
succeeded=succeeded_count,
|
||||
retried=retried_count,
|
||||
failed=failed_count,
|
||||
total=succeeded_count + retried_count + failed_count,
|
||||
earliest_task=min(earliest_events) if earliest_events else None,
|
||||
hours=hours,
|
||||
)
|
||||
130
allianceauth/authentication/task_statistics/event_series.py
Normal file
130
allianceauth/authentication/task_statistics/event_series.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import datetime as dt
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from pytz import utc
|
||||
from redis import Redis, RedisError
|
||||
|
||||
from allianceauth.utils.cache import get_redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RedisStub:
|
||||
"""Stub of a Redis client.
|
||||
|
||||
It's purpose is to prevent EventSeries objects from trying to access Redis
|
||||
when it is not available. e.g. when the Sphinx docs are rendered by readthedocs.org.
|
||||
"""
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def incr(self, *args, **kwargs):
|
||||
return 0
|
||||
|
||||
def zadd(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def zcount(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def zrangebyscore(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class EventSeries:
|
||||
"""API for recording and analyzing a series of events."""
|
||||
|
||||
_ROOT_KEY = "ALLIANCEAUTH_EVENT_SERIES"
|
||||
|
||||
def __init__(self, key_id: str, redis: Redis = None) -> None:
|
||||
self._redis = get_redis_client() if not redis else redis
|
||||
try:
|
||||
if not self._redis.ping():
|
||||
raise RuntimeError()
|
||||
except (AttributeError, RedisError, RuntimeError):
|
||||
logger.exception(
|
||||
"Failed to establish a connection with Redis. "
|
||||
"This EventSeries object is disabled.",
|
||||
)
|
||||
self._redis = _RedisStub()
|
||||
self._key_id = str(key_id)
|
||||
self.clear()
|
||||
|
||||
@property
|
||||
def is_disabled(self):
|
||||
"""True when this object is disabled, e.g. Redis was not available at startup."""
|
||||
return isinstance(self._redis, _RedisStub)
|
||||
|
||||
@property
|
||||
def _key_counter(self):
|
||||
return f"{self._ROOT_KEY}_{self._key_id}_COUNTER"
|
||||
|
||||
@property
|
||||
def _key_sorted_set(self):
|
||||
return f"{self._ROOT_KEY}_{self._key_id}_SORTED_SET"
|
||||
|
||||
def add(self, event_time: dt.datetime = None) -> None:
|
||||
"""Add event.
|
||||
|
||||
Args:
|
||||
- event_time: timestamp of event. Will use current time if not specified.
|
||||
"""
|
||||
if not event_time:
|
||||
event_time = dt.datetime.utcnow()
|
||||
id = self._redis.incr(self._key_counter)
|
||||
self._redis.zadd(self._key_sorted_set, {id: event_time.timestamp()})
|
||||
|
||||
def all(self) -> List[dt.datetime]:
|
||||
"""List of all known events."""
|
||||
return [
|
||||
event[1]
|
||||
for event in self._redis.zrangebyscore(
|
||||
self._key_sorted_set,
|
||||
"-inf",
|
||||
"+inf",
|
||||
withscores=True,
|
||||
score_cast_func=self._cast_scores_to_dt,
|
||||
)
|
||||
]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all events."""
|
||||
self._redis.delete(self._key_sorted_set)
|
||||
self._redis.delete(self._key_counter)
|
||||
|
||||
def count(self, earliest: dt.datetime = None, latest: dt.datetime = None) -> int:
|
||||
"""Count of events, can be restricted to given timeframe.
|
||||
|
||||
Args:
|
||||
- earliest: Date of first events to count(inclusive), or -infinite if not specified
|
||||
- latest: Date of last events to count(inclusive), or +infinite if not specified
|
||||
"""
|
||||
min = "-inf" if not earliest else earliest.timestamp()
|
||||
max = "+inf" if not latest else latest.timestamp()
|
||||
return self._redis.zcount(self._key_sorted_set, min=min, max=max)
|
||||
|
||||
def first_event(self, earliest: dt.datetime = None) -> Optional[dt.datetime]:
|
||||
"""Date/Time of first event. Returns `None` if series has no events.
|
||||
|
||||
Args:
|
||||
- earliest: Date of first events to count(inclusive), or any if not specified
|
||||
"""
|
||||
min = "-inf" if not earliest else earliest.timestamp()
|
||||
event = self._redis.zrangebyscore(
|
||||
self._key_sorted_set,
|
||||
min,
|
||||
"+inf",
|
||||
withscores=True,
|
||||
start=0,
|
||||
num=1,
|
||||
score_cast_func=self._cast_scores_to_dt,
|
||||
)
|
||||
if not event:
|
||||
return None
|
||||
return event[0][1]
|
||||
|
||||
@staticmethod
|
||||
def _cast_scores_to_dt(score) -> dt.datetime:
|
||||
return dt.datetime.fromtimestamp(float(score), tz=utc)
|
||||
54
allianceauth/authentication/task_statistics/signals.py
Normal file
54
allianceauth/authentication/task_statistics/signals.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from celery.signals import (
|
||||
task_failure,
|
||||
task_internal_error,
|
||||
task_retry,
|
||||
task_success,
|
||||
worker_ready
|
||||
)
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from .counters import failed_tasks, retried_tasks, succeeded_tasks
|
||||
|
||||
|
||||
def reset_counters():
|
||||
"""Reset all counters for the celery status."""
|
||||
succeeded_tasks.clear()
|
||||
failed_tasks.clear()
|
||||
retried_tasks.clear()
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return not bool(
|
||||
getattr(settings, "ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED", False)
|
||||
)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def reset_counters_when_celery_restarted(*args, **kwargs):
|
||||
if is_enabled():
|
||||
reset_counters()
|
||||
|
||||
|
||||
@task_success.connect
|
||||
def record_task_succeeded(*args, **kwargs):
|
||||
if is_enabled():
|
||||
succeeded_tasks.add()
|
||||
|
||||
|
||||
@task_retry.connect
|
||||
def record_task_retried(*args, **kwargs):
|
||||
if is_enabled():
|
||||
retried_tasks.add()
|
||||
|
||||
|
||||
@task_failure.connect
|
||||
def record_task_failed(*args, **kwargs):
|
||||
if is_enabled():
|
||||
failed_tasks.add()
|
||||
|
||||
|
||||
@task_internal_error.connect
|
||||
def record_task_internal_error(*args, **kwargs):
|
||||
if is_enabled():
|
||||
failed_tasks.add()
|
||||
@@ -0,0 +1,51 @@
|
||||
import datetime as dt
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils.timezone import now
|
||||
|
||||
from allianceauth.authentication.task_statistics.counters import (
|
||||
dashboard_results,
|
||||
succeeded_tasks,
|
||||
retried_tasks,
|
||||
failed_tasks,
|
||||
)
|
||||
|
||||
|
||||
class TestDashboardResults(TestCase):
|
||||
def test_should_return_counts_for_given_timeframe_only(self):
|
||||
# given
|
||||
earliest_task = now() - dt.timedelta(minutes=15)
|
||||
succeeded_tasks.clear()
|
||||
succeeded_tasks.add(now() - dt.timedelta(hours=1, seconds=1))
|
||||
succeeded_tasks.add(earliest_task)
|
||||
succeeded_tasks.add()
|
||||
succeeded_tasks.add()
|
||||
retried_tasks.clear()
|
||||
retried_tasks.add(now() - dt.timedelta(hours=1, seconds=1))
|
||||
retried_tasks.add(now() - dt.timedelta(seconds=30))
|
||||
retried_tasks.add()
|
||||
failed_tasks.clear()
|
||||
failed_tasks.add(now() - dt.timedelta(hours=1, seconds=1))
|
||||
failed_tasks.add()
|
||||
# when
|
||||
results = dashboard_results(hours=1)
|
||||
# then
|
||||
self.assertEqual(results.succeeded, 3)
|
||||
self.assertEqual(results.retried, 2)
|
||||
self.assertEqual(results.failed, 1)
|
||||
self.assertEqual(results.total, 6)
|
||||
self.assertEqual(results.earliest_task, earliest_task)
|
||||
|
||||
def test_should_work_with_no_data(self):
|
||||
# given
|
||||
succeeded_tasks.clear()
|
||||
retried_tasks.clear()
|
||||
failed_tasks.clear()
|
||||
# when
|
||||
results = dashboard_results(hours=1)
|
||||
# then
|
||||
self.assertEqual(results.succeeded, 0)
|
||||
self.assertEqual(results.retried, 0)
|
||||
self.assertEqual(results.failed, 0)
|
||||
self.assertEqual(results.total, 0)
|
||||
self.assertIsNone(results.earliest_task)
|
||||
@@ -0,0 +1,168 @@
|
||||
import datetime as dt
|
||||
from unittest.mock import patch
|
||||
|
||||
from pytz import utc
|
||||
from redis import RedisError
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils.timezone import now
|
||||
|
||||
from allianceauth.authentication.task_statistics.event_series import (
|
||||
EventSeries,
|
||||
_RedisStub,
|
||||
)
|
||||
|
||||
MODULE_PATH = "allianceauth.authentication.task_statistics.event_series"
|
||||
|
||||
|
||||
class TestEventSeries(TestCase):
|
||||
def test_should_abort_without_redis_client(self):
|
||||
# when
|
||||
with patch(MODULE_PATH + ".get_redis_client") as mock:
|
||||
mock.return_value = None
|
||||
events = EventSeries("dummy")
|
||||
# then
|
||||
self.assertTrue(events._redis, _RedisStub)
|
||||
self.assertTrue(events.is_disabled)
|
||||
|
||||
def test_should_disable_itself_if_redis_not_available_1(self):
|
||||
# when
|
||||
with patch(MODULE_PATH + ".get_redis_client") as mock_get_master_client:
|
||||
mock_get_master_client.return_value.ping.side_effect = RedisError
|
||||
events = EventSeries("dummy")
|
||||
# then
|
||||
self.assertIsInstance(events._redis, _RedisStub)
|
||||
self.assertTrue(events.is_disabled)
|
||||
|
||||
def test_should_disable_itself_if_redis_not_available_2(self):
|
||||
# when
|
||||
with patch(MODULE_PATH + ".get_redis_client") as mock_get_master_client:
|
||||
mock_get_master_client.return_value.ping.return_value = False
|
||||
events = EventSeries("dummy")
|
||||
# then
|
||||
self.assertIsInstance(events._redis, _RedisStub)
|
||||
self.assertTrue(events.is_disabled)
|
||||
|
||||
def test_should_add_event(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
# when
|
||||
events.add()
|
||||
# then
|
||||
result = events.all()
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertAlmostEqual(result[0], now(), delta=dt.timedelta(seconds=30))
|
||||
|
||||
def test_should_add_event_with_specified_time(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
my_time = dt.datetime(2021, 11, 1, 12, 15, tzinfo=utc)
|
||||
# when
|
||||
events.add(my_time)
|
||||
# then
|
||||
result = events.all()
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertAlmostEqual(result[0], my_time, delta=dt.timedelta(seconds=30))
|
||||
|
||||
def test_should_count_events(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add()
|
||||
events.add()
|
||||
# when
|
||||
result = events.count()
|
||||
# then
|
||||
self.assertEqual(result, 2)
|
||||
|
||||
def test_should_count_zero(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
# when
|
||||
result = events.count()
|
||||
# then
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
def test_should_count_events_within_timeframe_1(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
|
||||
# when
|
||||
result = events.count(
|
||||
earliest=dt.datetime(2021, 12, 1, 12, 8, tzinfo=utc),
|
||||
latest=dt.datetime(2021, 12, 1, 12, 17, tzinfo=utc),
|
||||
)
|
||||
# then
|
||||
self.assertEqual(result, 2)
|
||||
|
||||
def test_should_count_events_within_timeframe_2(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
|
||||
# when
|
||||
result = events.count(earliest=dt.datetime(2021, 12, 1, 12, 8))
|
||||
# then
|
||||
self.assertEqual(result, 3)
|
||||
|
||||
def test_should_count_events_within_timeframe_3(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
|
||||
# when
|
||||
result = events.count(latest=dt.datetime(2021, 12, 1, 12, 12))
|
||||
# then
|
||||
self.assertEqual(result, 2)
|
||||
|
||||
def test_should_clear_events(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add()
|
||||
events.add()
|
||||
# when
|
||||
events.clear()
|
||||
# then
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
def test_should_return_date_of_first_event(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
|
||||
# when
|
||||
result = events.first_event()
|
||||
# then
|
||||
self.assertEqual(result, dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
|
||||
|
||||
def test_should_return_date_of_first_event_with_range(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
|
||||
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
|
||||
# when
|
||||
result = events.first_event(
|
||||
earliest=dt.datetime(2021, 12, 1, 12, 8, tzinfo=utc)
|
||||
)
|
||||
# then
|
||||
self.assertEqual(result, dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
|
||||
|
||||
def test_should_return_all_events(self):
|
||||
# given
|
||||
events = EventSeries("dummy")
|
||||
events.add()
|
||||
events.add()
|
||||
# when
|
||||
results = events.all()
|
||||
# then
|
||||
self.assertEqual(len(results), 2)
|
||||
@@ -0,0 +1,93 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from celery.exceptions import Retry
|
||||
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from allianceauth.authentication.task_statistics.counters import (
|
||||
failed_tasks,
|
||||
retried_tasks,
|
||||
succeeded_tasks,
|
||||
)
|
||||
from allianceauth.authentication.task_statistics.signals import (
|
||||
reset_counters,
|
||||
is_enabled,
|
||||
)
|
||||
from allianceauth.eveonline.tasks import update_character
|
||||
|
||||
|
||||
@override_settings(
|
||||
CELERY_ALWAYS_EAGER=True,ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED=False
|
||||
)
|
||||
class TestTaskSignals(TestCase):
|
||||
fixtures = ["disable_analytics"]
|
||||
|
||||
def test_should_record_successful_task(self):
|
||||
# given
|
||||
succeeded_tasks.clear()
|
||||
retried_tasks.clear()
|
||||
failed_tasks.clear()
|
||||
# when
|
||||
with patch(
|
||||
"allianceauth.eveonline.tasks.EveCharacter.objects.update_character"
|
||||
) as mock_update:
|
||||
mock_update.return_value = None
|
||||
update_character.delay(1)
|
||||
# then
|
||||
self.assertEqual(succeeded_tasks.count(), 1)
|
||||
self.assertEqual(retried_tasks.count(), 0)
|
||||
self.assertEqual(failed_tasks.count(), 0)
|
||||
|
||||
def test_should_record_retried_task(self):
|
||||
# given
|
||||
succeeded_tasks.clear()
|
||||
retried_tasks.clear()
|
||||
failed_tasks.clear()
|
||||
# when
|
||||
with patch(
|
||||
"allianceauth.eveonline.tasks.EveCharacter.objects.update_character"
|
||||
) as mock_update:
|
||||
mock_update.side_effect = Retry
|
||||
update_character.delay(1)
|
||||
# then
|
||||
self.assertEqual(succeeded_tasks.count(), 0)
|
||||
self.assertEqual(failed_tasks.count(), 0)
|
||||
self.assertEqual(retried_tasks.count(), 1)
|
||||
|
||||
def test_should_record_failed_task(self):
|
||||
# given
|
||||
succeeded_tasks.clear()
|
||||
retried_tasks.clear()
|
||||
failed_tasks.clear()
|
||||
# when
|
||||
with patch(
|
||||
"allianceauth.eveonline.tasks.EveCharacter.objects.update_character"
|
||||
) as mock_update:
|
||||
mock_update.side_effect = RuntimeError
|
||||
update_character.delay(1)
|
||||
# then
|
||||
self.assertEqual(succeeded_tasks.count(), 0)
|
||||
self.assertEqual(retried_tasks.count(), 0)
|
||||
self.assertEqual(failed_tasks.count(), 1)
|
||||
|
||||
def test_should_reset_counters(self):
|
||||
# given
|
||||
succeeded_tasks.add()
|
||||
retried_tasks.add()
|
||||
failed_tasks.add()
|
||||
# when
|
||||
reset_counters()
|
||||
# then
|
||||
self.assertEqual(succeeded_tasks.count(), 0)
|
||||
self.assertEqual(retried_tasks.count(), 0)
|
||||
self.assertEqual(failed_tasks.count(), 0)
|
||||
|
||||
|
||||
class TestIsEnabled(TestCase):
|
||||
@override_settings(ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED=False)
|
||||
def test_enabled(self):
|
||||
self.assertTrue(is_enabled())
|
||||
|
||||
@override_settings(ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED=True)
|
||||
def test_disabled(self):
|
||||
self.assertFalse(is_enabled())
|
||||
@@ -1,6 +1,9 @@
|
||||
from bs4 import BeautifulSoup
|
||||
from urllib.parse import quote
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from django_webtest import WebTest
|
||||
|
||||
from django.contrib.admin.sites import AdminSite
|
||||
from django.contrib.auth.models import Group
|
||||
from django.test import TestCase, RequestFactory, Client
|
||||
@@ -188,7 +191,7 @@ class TestCaseWithTestData(TestCase):
|
||||
corporation_id=5432,
|
||||
corporation_name="Xavier's School for Gifted Youngsters",
|
||||
corporation_ticker='MUTNT',
|
||||
alliance_id = None,
|
||||
alliance_id=None,
|
||||
faction_id=999,
|
||||
faction_name='The X-Men',
|
||||
)
|
||||
@@ -206,6 +209,7 @@ class TestCaseWithTestData(TestCase):
|
||||
cls.user_4.profile.save()
|
||||
EveFactionInfo.objects.create(faction_id=999, faction_name='The X-Men')
|
||||
|
||||
|
||||
def make_generic_search_request(ModelClass: type, search_term: str):
|
||||
User.objects.create_superuser(
|
||||
username='superuser', password='secret', email='admin@example.com'
|
||||
@@ -218,6 +222,7 @@ def make_generic_search_request(ModelClass: type, search_term: str):
|
||||
|
||||
|
||||
class TestCharacterOwnershipAdmin(TestCaseWithTestData):
|
||||
fixtures = ["disable_analytics"]
|
||||
|
||||
def setUp(self):
|
||||
self.modeladmin = CharacterOwnershipAdmin(
|
||||
@@ -244,6 +249,7 @@ class TestCharacterOwnershipAdmin(TestCaseWithTestData):
|
||||
|
||||
|
||||
class TestOwnershipRecordAdmin(TestCaseWithTestData):
|
||||
fixtures = ["disable_analytics"]
|
||||
|
||||
def setUp(self):
|
||||
self.modeladmin = OwnershipRecordAdmin(
|
||||
@@ -270,11 +276,12 @@ class TestOwnershipRecordAdmin(TestCaseWithTestData):
|
||||
|
||||
|
||||
class TestStateAdmin(TestCaseWithTestData):
|
||||
fixtures = ["disable_analytics"]
|
||||
|
||||
def setUp(self):
|
||||
self.modeladmin = StateAdmin(
|
||||
model=User, admin_site=AdminSite()
|
||||
)
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
cls.modeladmin = StateAdmin(model=User, admin_site=AdminSite())
|
||||
|
||||
def test_change_view_loads_normally(self):
|
||||
User.objects.create_superuser(
|
||||
@@ -299,6 +306,7 @@ class TestStateAdmin(TestCaseWithTestData):
|
||||
|
||||
|
||||
class TestUserAdmin(TestCaseWithTestData):
|
||||
fixtures = ["disable_analytics"]
|
||||
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
@@ -344,7 +352,7 @@ class TestUserAdmin(TestCaseWithTestData):
|
||||
self.assertEqual(user_main_organization(self.user_3), expected)
|
||||
|
||||
def test_user_main_organization_u4(self):
|
||||
expected="Xavier's School for Gifted Youngsters<br>The X-Men"
|
||||
expected = "Xavier's School for Gifted Youngsters<br>The X-Men"
|
||||
self.assertEqual(user_main_organization(self.user_4), expected)
|
||||
|
||||
def test_characters_u1(self):
|
||||
@@ -537,6 +545,229 @@ class TestUserAdmin(TestCaseWithTestData):
|
||||
self.assertEqual(response.status_code, expected)
|
||||
|
||||
|
||||
class TestStateAdminChangeFormSuperuserExclusiveEdits(WebTest):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
cls.super_admin = User.objects.create_superuser("super_admin")
|
||||
cls.staff_admin = User.objects.create_user("staff_admin")
|
||||
cls.staff_admin.is_staff = True
|
||||
cls.staff_admin.save()
|
||||
cls.staff_admin = AuthUtils.add_permissions_to_user_by_name(
|
||||
[
|
||||
"authentication.add_state",
|
||||
"authentication.change_state",
|
||||
"authentication.view_state",
|
||||
],
|
||||
cls.staff_admin
|
||||
)
|
||||
cls.superuser_exclusive_fields = ["permissions",]
|
||||
|
||||
def test_should_show_all_fields_to_superuser_for_add(self):
|
||||
# given
|
||||
self.app.set_user(self.super_admin)
|
||||
page = self.app.get("/admin/authentication/state/add/")
|
||||
# when
|
||||
form = page.forms["state_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertIn(field, form.fields)
|
||||
|
||||
def test_should_not_show_all_fields_to_staff_admins_for_add(self):
|
||||
# given
|
||||
self.app.set_user(self.staff_admin)
|
||||
page = self.app.get("/admin/authentication/state/add/")
|
||||
# when
|
||||
form = page.forms["state_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertNotIn(field, form.fields)
|
||||
|
||||
def test_should_show_all_fields_to_superuser_for_change(self):
|
||||
# given
|
||||
self.app.set_user(self.super_admin)
|
||||
state = AuthUtils.get_member_state()
|
||||
page = self.app.get(f"/admin/authentication/state/{state.pk}/change/")
|
||||
# when
|
||||
form = page.forms["state_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertIn(field, form.fields)
|
||||
|
||||
def test_should_not_show_all_fields_to_staff_admin_for_change(self):
|
||||
# given
|
||||
self.app.set_user(self.staff_admin)
|
||||
state = AuthUtils.get_member_state()
|
||||
page = self.app.get(f"/admin/authentication/state/{state.pk}/change/")
|
||||
# when
|
||||
form = page.forms["state_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertNotIn(field, form.fields)
|
||||
|
||||
|
||||
class TestUserAdminChangeForm(TestCase):
|
||||
fixtures = ["disable_analytics"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
cls.modeladmin = UserAdmin(model=User, admin_site=AdminSite())
|
||||
|
||||
def test_should_show_groups_available_to_user_with_blue_state_only(self):
|
||||
# given
|
||||
superuser = User.objects.create_superuser("Super")
|
||||
user = AuthUtils.create_user("bruce_wayne")
|
||||
character = AuthUtils.add_main_character_2(
|
||||
user,
|
||||
name="Bruce Wayne",
|
||||
character_id=1001,
|
||||
corp_id=2001,
|
||||
corp_name="Wayne Technologies"
|
||||
)
|
||||
blue_state = State.objects.get(name="Blue")
|
||||
blue_state.member_characters.add(character)
|
||||
member_state = AuthUtils.get_member_state()
|
||||
group_1 = Group.objects.create(name="Group 1")
|
||||
group_2 = Group.objects.create(name="Group 2")
|
||||
group_2.authgroup.states.add(blue_state)
|
||||
group_3 = Group.objects.create(name="Group 3")
|
||||
group_3.authgroup.states.add(member_state)
|
||||
self.client.force_login(superuser)
|
||||
# when
|
||||
response = self.client.get(f"/admin/authentication/user/{user.pk}/change/")
|
||||
# then
|
||||
self.assertEqual(response.status_code, 200)
|
||||
soup = BeautifulSoup(response.rendered_content, features="html.parser")
|
||||
groups_select = soup.find("select", {"id": "id_groups"}).find_all('option')
|
||||
group_ids = {int(option["value"]) for option in groups_select}
|
||||
self.assertSetEqual(group_ids, {group_1.pk, group_2.pk})
|
||||
|
||||
|
||||
class TestUserAdminChangeFormSuperuserExclusiveEdits(WebTest):
|
||||
fixtures = ["disable_analytics"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
cls.super_admin = User.objects.create_superuser("super_admin")
|
||||
cls.staff_admin = User.objects.create_user("staff_admin")
|
||||
cls.staff_admin.is_staff = True
|
||||
cls.staff_admin.save()
|
||||
cls.staff_admin = AuthUtils.add_permissions_to_user_by_name(
|
||||
[
|
||||
"auth.change_user",
|
||||
"auth.view_user",
|
||||
"authentication.change_user",
|
||||
"authentication.change_userprofile",
|
||||
"authentication.view_user"
|
||||
],
|
||||
cls.staff_admin
|
||||
)
|
||||
cls.superuser_exclusive_fields = [
|
||||
"is_staff", "is_superuser", "user_permissions"
|
||||
]
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = AuthUtils.create_user("bruce_wayne")
|
||||
|
||||
def test_should_show_all_fields_to_superuser_for_change(self):
|
||||
# given
|
||||
self.app.set_user(self.super_admin)
|
||||
|
||||
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
|
||||
# when
|
||||
form = page.forms["user_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertIn(field, form.fields)
|
||||
|
||||
def test_should_not_show_all_fields_to_staff_admin_for_change(self):
|
||||
# given
|
||||
self.app.set_user(self.staff_admin)
|
||||
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
|
||||
# when
|
||||
form = page.forms["user_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertNotIn(field, form.fields)
|
||||
|
||||
def test_should_allow_super_admin_to_add_restricted_group_to_user(self):
|
||||
# given
|
||||
self.app.set_user(self.super_admin)
|
||||
group_restricted = Group.objects.create(name="restricted group")
|
||||
group_restricted.authgroup.restricted = True
|
||||
group_restricted.authgroup.save()
|
||||
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
|
||||
form = page.forms["user_form"]
|
||||
# when
|
||||
form["groups"].select_multiple(texts=["restricted group"])
|
||||
response = form.submit("save")
|
||||
# then
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.user.refresh_from_db()
|
||||
self.assertIn(
|
||||
"restricted group", self.user.groups.values_list("name", flat=True)
|
||||
)
|
||||
|
||||
def test_should_not_allow_staff_admin_to_add_restricted_group_to_user(self):
|
||||
# given
|
||||
self.app.set_user(self.staff_admin)
|
||||
group_restricted = Group.objects.create(name="restricted group")
|
||||
group_restricted.authgroup.restricted = True
|
||||
group_restricted.authgroup.save()
|
||||
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
|
||||
form = page.forms["user_form"]
|
||||
# when
|
||||
form["groups"].select_multiple(texts=["restricted group"])
|
||||
response = form.submit("save")
|
||||
# then
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn(
|
||||
"You are not allowed to add or remove these restricted groups",
|
||||
response.text
|
||||
)
|
||||
|
||||
def test_should_not_allow_staff_admin_to_remove_restricted_group_from_user(self):
|
||||
# given
|
||||
self.app.set_user(self.staff_admin)
|
||||
group_restricted = Group.objects.create(name="restricted group")
|
||||
group_restricted.authgroup.restricted = True
|
||||
group_restricted.authgroup.save()
|
||||
self.user.groups.add(group_restricted)
|
||||
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
|
||||
form = page.forms["user_form"]
|
||||
# when
|
||||
form["groups"].select_multiple(texts=[])
|
||||
response = form.submit("save")
|
||||
# then
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn(
|
||||
"You are not allowed to add or remove these restricted groups",
|
||||
response.text
|
||||
)
|
||||
|
||||
def test_should_allow_staff_admin_to_add_normal_group_to_user(self):
|
||||
# given
|
||||
self.app.set_user(self.super_admin)
|
||||
Group.objects.create(name="normal group")
|
||||
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
|
||||
form = page.forms["user_form"]
|
||||
# when
|
||||
form["groups"].select_multiple(texts=["normal group"])
|
||||
response = form.submit("save")
|
||||
# then
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.user.refresh_from_db()
|
||||
self.assertIn("normal group", self.user.groups.values_list("name", flat=True))
|
||||
|
||||
|
||||
class TestMakeServicesHooksActions(TestCaseWithTestData):
|
||||
|
||||
class MyServicesHookTypeA(ServicesHook):
|
||||
|
||||
@@ -55,7 +55,6 @@ TEST_VERSION = '2.6.5'
|
||||
|
||||
|
||||
class TestStatusOverviewTag(TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.admin_status.__version__', TEST_VERSION)
|
||||
@patch(MODULE_PATH + '.admin_status._fetch_celery_queue_length')
|
||||
@patch(MODULE_PATH + '.admin_status._current_version_summary')
|
||||
@@ -66,6 +65,7 @@ class TestStatusOverviewTag(TestCase):
|
||||
mock_current_version_info,
|
||||
mock_fetch_celery_queue_length
|
||||
):
|
||||
# given
|
||||
notifications = {
|
||||
'notifications': GITHUB_NOTIFICATION_ISSUES[:5]
|
||||
}
|
||||
@@ -83,22 +83,20 @@ class TestStatusOverviewTag(TestCase):
|
||||
}
|
||||
mock_current_version_info.return_value = version_info
|
||||
mock_fetch_celery_queue_length.return_value = 3
|
||||
|
||||
# when
|
||||
result = status_overview()
|
||||
expected = {
|
||||
'notifications': GITHUB_NOTIFICATION_ISSUES[:5],
|
||||
'latest_major': True,
|
||||
'latest_minor': True,
|
||||
'latest_patch': True,
|
||||
'latest_beta': False,
|
||||
'current_version': TEST_VERSION,
|
||||
'latest_major_version': '2.4.5',
|
||||
'latest_minor_version': '2.4.0',
|
||||
'latest_patch_version': '2.4.5',
|
||||
'latest_beta_version': '2.4.4a1',
|
||||
'task_queue_length': 3,
|
||||
}
|
||||
self.assertEqual(result, expected)
|
||||
# then
|
||||
self.assertEqual(result["notifications"], GITHUB_NOTIFICATION_ISSUES[:5])
|
||||
self.assertTrue(result["latest_major"])
|
||||
self.assertTrue(result["latest_minor"])
|
||||
self.assertTrue(result["latest_patch"])
|
||||
self.assertFalse(result["latest_beta"])
|
||||
self.assertEqual(result["current_version"], TEST_VERSION)
|
||||
self.assertEqual(result["latest_major_version"], '2.4.5')
|
||||
self.assertEqual(result["latest_minor_version"], '2.4.0')
|
||||
self.assertEqual(result["latest_patch_version"], '2.4.5')
|
||||
self.assertEqual(result["latest_beta_version"], '2.4.4a1')
|
||||
self.assertEqual(result["task_queue_length"], 3)
|
||||
|
||||
|
||||
class TestNotifications(TestCase):
|
||||
|
||||
@@ -193,6 +193,8 @@
|
||||
"columnDefs": [
|
||||
{ "sortable": false, "targets": [1] },
|
||||
],
|
||||
"stateSave": true,
|
||||
"stateDuration": 0
|
||||
});
|
||||
$('#table-members').DataTable({
|
||||
"columnDefs": [
|
||||
@@ -200,6 +202,8 @@
|
||||
{ "sortable": false, "targets": [0, 2] },
|
||||
],
|
||||
"order": [[ 1, "asc" ]],
|
||||
"stateSave": true,
|
||||
"stateDuration": 0
|
||||
});
|
||||
$('#table-unregistered').DataTable({
|
||||
"columnDefs": [
|
||||
@@ -207,6 +211,8 @@
|
||||
{ "sortable": false, "targets": [0, 2] },
|
||||
],
|
||||
"order": [[ 1, "asc" ]],
|
||||
"stateSave": true,
|
||||
"stateDuration": 0
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
@@ -43,6 +43,9 @@
|
||||
{% endblock %}
|
||||
{% block extra_script %}
|
||||
$(document).ready(function(){
|
||||
$('#table-search').DataTable();
|
||||
$('#table-search').DataTable({
|
||||
"stateSave": true,
|
||||
"stateDuration": 0
|
||||
});
|
||||
});
|
||||
{% endblock %}
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
from django.db import models
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from .managers import EveCharacterManager, EveCharacterProviderManager
|
||||
from .managers import EveCorporationManager, EveCorporationProviderManager
|
||||
from .managers import EveAllianceManager, EveAllianceProviderManager
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import models
|
||||
from esi.models import Token
|
||||
|
||||
from allianceauth.notifications import notify
|
||||
|
||||
from . import providers
|
||||
from .evelinks import eveimageserver
|
||||
from .managers import (
|
||||
EveAllianceManager,
|
||||
EveAllianceProviderManager,
|
||||
EveCharacterManager,
|
||||
EveCharacterProviderManager,
|
||||
EveCorporationManager,
|
||||
EveCorporationProviderManager,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_IMAGE_SIZE = 32
|
||||
DOOMHEIM_CORPORATION_ID = 1000001
|
||||
|
||||
|
||||
class EveFactionInfo(models.Model):
|
||||
@@ -68,13 +82,12 @@ class EveAllianceInfo(models.Model):
|
||||
for corp_id in alliance.corp_ids:
|
||||
if not EveCorporationInfo.objects.filter(corporation_id=corp_id).exists():
|
||||
EveCorporationInfo.objects.create_corporation(corp_id)
|
||||
EveCorporationInfo.objects.filter(
|
||||
corporation_id__in=alliance.corp_ids).update(alliance=self
|
||||
EveCorporationInfo.objects.filter(corporation_id__in=alliance.corp_ids).update(
|
||||
alliance=self
|
||||
)
|
||||
EveCorporationInfo.objects\
|
||||
.filter(alliance=self)\
|
||||
.exclude(corporation_id__in=alliance.corp_ids)\
|
||||
.update(alliance=None)
|
||||
EveCorporationInfo.objects.filter(alliance=self).exclude(
|
||||
corporation_id__in=alliance.corp_ids
|
||||
).update(alliance=None)
|
||||
|
||||
def update_alliance(self, alliance: providers.Alliance = None):
|
||||
if alliance is None:
|
||||
@@ -182,6 +195,7 @@ class EveCorporationInfo(models.Model):
|
||||
|
||||
|
||||
class EveCharacter(models.Model):
|
||||
"""Character in Eve Online"""
|
||||
character_id = models.PositiveIntegerField(unique=True)
|
||||
character_name = models.CharField(max_length=254, unique=True)
|
||||
corporation_id = models.PositiveIntegerField()
|
||||
@@ -198,12 +212,20 @@ class EveCharacter(models.Model):
|
||||
|
||||
class Meta:
|
||||
indexes = [
|
||||
models.Index(fields=['corporation_id',]),
|
||||
models.Index(fields=['alliance_id',]),
|
||||
models.Index(fields=['corporation_name',]),
|
||||
models.Index(fields=['alliance_name',]),
|
||||
models.Index(fields=['faction_id',]),
|
||||
]
|
||||
models.Index(fields=['corporation_id',]),
|
||||
models.Index(fields=['alliance_id',]),
|
||||
models.Index(fields=['corporation_name',]),
|
||||
models.Index(fields=['alliance_name',]),
|
||||
models.Index(fields=['faction_id',]),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return self.character_name
|
||||
|
||||
@property
|
||||
def is_biomassed(self) -> bool:
|
||||
"""Whether this character is dead or not."""
|
||||
return self.corporation_id == DOOMHEIM_CORPORATION_ID
|
||||
|
||||
@property
|
||||
def alliance(self) -> Union[EveAllianceInfo, None]:
|
||||
@@ -249,10 +271,36 @@ class EveCharacter(models.Model):
|
||||
self.faction_id = character.faction.id
|
||||
self.faction_name = character.faction.name
|
||||
self.save()
|
||||
if self.is_biomassed:
|
||||
self._remove_tokens_of_biomassed_character()
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
return self.character_name
|
||||
def _remove_tokens_of_biomassed_character(self) -> None:
|
||||
"""Remove tokens of this biomassed character."""
|
||||
try:
|
||||
user = self.character_ownership.user
|
||||
except ObjectDoesNotExist:
|
||||
return
|
||||
tokens_to_delete = Token.objects.filter(character_id=self.character_id)
|
||||
tokens_count = tokens_to_delete.count()
|
||||
if not tokens_count:
|
||||
return
|
||||
tokens_to_delete.delete()
|
||||
logger.info(
|
||||
"%d tokens from user %s for biomassed character %s [id:%s] deleted.",
|
||||
tokens_count,
|
||||
user,
|
||||
self,
|
||||
self.character_id,
|
||||
)
|
||||
notify(
|
||||
user=user,
|
||||
title=f"Character {self} biomassed",
|
||||
message=(
|
||||
f"Your former character {self} has been biomassed "
|
||||
"and has been removed from the list of your alts."
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generic_portrait_url(
|
||||
@@ -336,7 +384,6 @@ class EveCharacter(models.Model):
|
||||
"""image URL for alliance of this character or empty string"""
|
||||
return self.alliance_logo_url(256)
|
||||
|
||||
|
||||
def faction_logo_url(self, size=_DEFAULT_IMAGE_SIZE) -> str:
|
||||
"""image URL for alliance of this character or empty string"""
|
||||
if self.faction_id:
|
||||
|
||||
@@ -170,7 +170,7 @@ class EveProvider:
|
||||
"""
|
||||
:return: an ItemType object for the given ID
|
||||
"""
|
||||
raise NotImplemented()
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class EveSwaggerProvider(EveProvider):
|
||||
@@ -207,7 +207,8 @@ class EveSwaggerProvider(EveProvider):
|
||||
def __str__(self):
|
||||
return 'esi'
|
||||
|
||||
def get_alliance(self, alliance_id):
|
||||
def get_alliance(self, alliance_id: int) -> Alliance:
|
||||
"""Fetch alliance from ESI."""
|
||||
try:
|
||||
data = self.client.Alliance.get_alliances_alliance_id(alliance_id=alliance_id).result()
|
||||
corps = self.client.Alliance.get_alliances_alliance_id_corporations(alliance_id=alliance_id).result()
|
||||
@@ -223,7 +224,8 @@ class EveSwaggerProvider(EveProvider):
|
||||
except HTTPNotFound:
|
||||
raise ObjectNotFound(alliance_id, 'alliance')
|
||||
|
||||
def get_corp(self, corp_id):
|
||||
def get_corp(self, corp_id: int) -> Corporation:
|
||||
"""Fetch corporation from ESI."""
|
||||
try:
|
||||
data = self.client.Corporation.get_corporations_corporation_id(corporation_id=corp_id).result()
|
||||
model = Corporation(
|
||||
@@ -239,29 +241,43 @@ class EveSwaggerProvider(EveProvider):
|
||||
except HTTPNotFound:
|
||||
raise ObjectNotFound(corp_id, 'corporation')
|
||||
|
||||
def get_character(self, character_id):
|
||||
def get_character(self, character_id: int) -> Character:
|
||||
"""Fetch character from ESI."""
|
||||
try:
|
||||
data = self.client.Character.get_characters_character_id(character_id=character_id).result()
|
||||
character_name = self._fetch_character_name(character_id)
|
||||
affiliation = self.client.Character.post_characters_affiliation(characters=[character_id]).result()[0]
|
||||
|
||||
model = Character(
|
||||
id=character_id,
|
||||
name=data['name'],
|
||||
name=character_name,
|
||||
corp_id=affiliation['corporation_id'],
|
||||
alliance_id=affiliation['alliance_id'] if 'alliance_id' in affiliation else None,
|
||||
faction_id=affiliation['faction_id'] if 'faction_id' in affiliation else None,
|
||||
)
|
||||
return model
|
||||
except (HTTPNotFound, HTTPUnprocessableEntity):
|
||||
except (HTTPNotFound, HTTPUnprocessableEntity, ObjectNotFound):
|
||||
raise ObjectNotFound(character_id, 'character')
|
||||
|
||||
def _fetch_character_name(self, character_id: int) -> str:
|
||||
"""Fetch character name from ESI."""
|
||||
data = self.client.Universe.post_universe_names(ids=[character_id]).result()
|
||||
character = data.pop() if data else None
|
||||
if (
|
||||
not character
|
||||
or character["category"] != "character"
|
||||
or character["id"] != character_id
|
||||
):
|
||||
raise ObjectNotFound(character_id, 'character')
|
||||
return character["name"]
|
||||
|
||||
def get_all_factions(self):
|
||||
"""Fetch all factions from ESI."""
|
||||
if not self._faction_list:
|
||||
self._faction_list = self.client.Universe.get_universe_factions().result()
|
||||
return self._faction_list
|
||||
|
||||
def get_faction(self, faction_id):
|
||||
faction_id=int(faction_id)
|
||||
def get_faction(self, faction_id: int):
|
||||
"""Fetch faction from ESI."""
|
||||
faction_id = int(faction_id)
|
||||
try:
|
||||
if not self._faction_list:
|
||||
_ = self.get_all_factions()
|
||||
@@ -273,7 +289,8 @@ class EveSwaggerProvider(EveProvider):
|
||||
except (HTTPNotFound, HTTPUnprocessableEntity, KeyError):
|
||||
raise ObjectNotFound(faction_id, 'faction')
|
||||
|
||||
def get_itemtype(self, type_id):
|
||||
def get_itemtype(self, type_id: int) -> ItemType:
|
||||
"""Fetch inventory item from ESI."""
|
||||
try:
|
||||
data = self.client.Universe.get_universe_types_type_id(type_id=type_id).result()
|
||||
return ItemType(id=type_id, name=data['name'])
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
from .models import EveAllianceInfo
|
||||
from .models import EveCharacter
|
||||
from .models import EveCorporationInfo
|
||||
|
||||
from .models import EveAllianceInfo, EveCharacter, EveCorporationInfo
|
||||
from . import providers
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TASK_PRIORITY = 7
|
||||
@@ -32,8 +31,8 @@ def update_alliance(alliance_id):
|
||||
|
||||
|
||||
@shared_task
|
||||
def update_character(character_id):
|
||||
"""Update given character from ESI"""
|
||||
def update_character(character_id: int) -> None:
|
||||
"""Update given character from ESI."""
|
||||
EveCharacter.objects.update_character(character_id)
|
||||
|
||||
|
||||
@@ -65,17 +64,17 @@ def update_character_chunk(character_ids_chunk: list):
|
||||
.post_characters_affiliation(characters=character_ids_chunk).result()
|
||||
character_names = providers.provider.client.Universe\
|
||||
.post_universe_names(ids=character_ids_chunk).result()
|
||||
except:
|
||||
except OSError:
|
||||
logger.info("Failed to bulk update characters. Attempting single updates")
|
||||
for character_id in character_ids_chunk:
|
||||
update_character.apply_async(
|
||||
args=[character_id], priority=TASK_PRIORITY
|
||||
)
|
||||
args=[character_id], priority=TASK_PRIORITY
|
||||
)
|
||||
return
|
||||
|
||||
affiliations = {
|
||||
affiliation.get('character_id'): affiliation
|
||||
for affiliation in affiliations_raw
|
||||
affiliation.get('character_id'): affiliation
|
||||
for affiliation in affiliations_raw
|
||||
}
|
||||
# add character names to affiliations
|
||||
for character in character_names:
|
||||
@@ -108,5 +107,5 @@ def update_character_chunk(character_ids_chunk: list):
|
||||
|
||||
if corp_changed or alliance_changed or name_changed:
|
||||
update_character.apply_async(
|
||||
args=[character.get('character_id')], priority=TASK_PRIORITY
|
||||
)
|
||||
args=[character.get('character_id')], priority=TASK_PRIORITY
|
||||
)
|
||||
|
||||
168
allianceauth/eveonline/tests/esi_client_stub.py
Normal file
168
allianceauth/eveonline/tests/esi_client_stub.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from bravado.exception import HTTPNotFound
|
||||
|
||||
|
||||
class BravadoResponseStub:
|
||||
"""Stub for IncomingResponse in bravado, e.g. for HTTPError exceptions"""
|
||||
|
||||
def __init__(
|
||||
self, status_code, reason="", text="", headers=None, raw_bytes=None
|
||||
) -> None:
|
||||
self.reason = reason
|
||||
self.status_code = status_code
|
||||
self.text = text
|
||||
self.headers = headers if headers else dict()
|
||||
self.raw_bytes = raw_bytes
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.status_code} {self.reason}"
|
||||
|
||||
|
||||
class BravadoOperationStub:
|
||||
"""Stub to simulate the operation object return from bravado via django-esi"""
|
||||
|
||||
class RequestConfig:
|
||||
def __init__(self, also_return_response):
|
||||
self.also_return_response = also_return_response
|
||||
|
||||
class ResponseStub:
|
||||
def __init__(self, headers):
|
||||
self.headers = headers
|
||||
|
||||
def __init__(self, data, headers: dict = None, also_return_response: bool = False):
|
||||
self._data = data
|
||||
self._headers = headers if headers else {"x-pages": 1}
|
||||
self.request_config = BravadoOperationStub.RequestConfig(also_return_response)
|
||||
|
||||
def result(self, **kwargs):
|
||||
if self.request_config.also_return_response:
|
||||
return [self._data, self.ResponseStub(self._headers)]
|
||||
else:
|
||||
return self._data
|
||||
|
||||
def results(self, **kwargs):
|
||||
return self.result(**kwargs)
|
||||
|
||||
|
||||
class EsiClientStub:
|
||||
"""Stub for an ESI client."""
|
||||
class Alliance:
|
||||
@staticmethod
|
||||
def get_alliances_alliance_id(alliance_id):
|
||||
data = {
|
||||
3001: {
|
||||
"name": "Wayne Enterprises",
|
||||
"ticker": "WYE",
|
||||
"executor_corporation_id": 2001
|
||||
}
|
||||
}
|
||||
try:
|
||||
return BravadoOperationStub(data[int(alliance_id)])
|
||||
except KeyError:
|
||||
response = BravadoResponseStub(
|
||||
404, f"Alliance with ID {alliance_id} not found"
|
||||
)
|
||||
raise HTTPNotFound(response)
|
||||
|
||||
@staticmethod
|
||||
def get_alliances_alliance_id_corporations(alliance_id):
|
||||
data = [2001, 2002, 2003]
|
||||
return BravadoOperationStub(data)
|
||||
|
||||
class Character:
|
||||
@staticmethod
|
||||
def get_characters_character_id(character_id):
|
||||
data = {
|
||||
1001: {
|
||||
"corporation_id": 2001,
|
||||
"name": "Bruce Wayne",
|
||||
},
|
||||
1002: {
|
||||
"corporation_id": 2001,
|
||||
"name": "Peter Parker",
|
||||
},
|
||||
1011: {
|
||||
"corporation_id": 2011,
|
||||
"name": "Lex Luthor",
|
||||
}
|
||||
}
|
||||
try:
|
||||
return BravadoOperationStub(data[int(character_id)])
|
||||
except KeyError:
|
||||
response = BravadoResponseStub(
|
||||
404, f"Character with ID {character_id} not found"
|
||||
)
|
||||
raise HTTPNotFound(response)
|
||||
|
||||
@staticmethod
|
||||
def post_characters_affiliation(characters: list):
|
||||
data = [
|
||||
{'character_id': 1001, 'corporation_id': 2001, 'alliance_id': 3001},
|
||||
{'character_id': 1002, 'corporation_id': 2001, 'alliance_id': 3001},
|
||||
{'character_id': 1011, 'corporation_id': 2011},
|
||||
{'character_id': 1666, 'corporation_id': 1000001},
|
||||
]
|
||||
return BravadoOperationStub(
|
||||
[x for x in data if x['character_id'] in characters]
|
||||
)
|
||||
|
||||
class Corporation:
|
||||
@staticmethod
|
||||
def get_corporations_corporation_id(corporation_id):
|
||||
data = {
|
||||
2001: {
|
||||
"ceo_id": 1091,
|
||||
"member_count": 10,
|
||||
"name": "Wayne Technologies",
|
||||
"ticker": "WTE",
|
||||
"alliance_id": 3001
|
||||
},
|
||||
2002: {
|
||||
"ceo_id": 1092,
|
||||
"member_count": 10,
|
||||
"name": "Wayne Food",
|
||||
"ticker": "WFO",
|
||||
"alliance_id": 3001
|
||||
},
|
||||
2003: {
|
||||
"ceo_id": 1093,
|
||||
"member_count": 10,
|
||||
"name": "Wayne Energy",
|
||||
"ticker": "WEG",
|
||||
"alliance_id": 3001
|
||||
},
|
||||
2011: {
|
||||
"ceo_id": 1,
|
||||
"member_count": 3,
|
||||
"name": "LexCorp",
|
||||
"ticker": "LC",
|
||||
},
|
||||
1000001: {
|
||||
"ceo_id": 3000001,
|
||||
"creator_id": 1,
|
||||
"description": "The internal corporation used for characters in graveyard.",
|
||||
"member_count": 6329026,
|
||||
"name": "Doomheim",
|
||||
"ticker": "666",
|
||||
}
|
||||
}
|
||||
try:
|
||||
return BravadoOperationStub(data[int(corporation_id)])
|
||||
except KeyError:
|
||||
response = BravadoResponseStub(
|
||||
404, f"Corporation with ID {corporation_id} not found"
|
||||
)
|
||||
raise HTTPNotFound(response)
|
||||
|
||||
class Universe:
|
||||
@staticmethod
|
||||
def post_universe_names(ids: list):
|
||||
data = [
|
||||
{"category": "character", "id": 1001, "name": "Bruce Wayne"},
|
||||
{"category": "character", "id": 1002, "name": "Peter Parker"},
|
||||
{"category": "character", "id": 1011, "name": "Lex Luthor"},
|
||||
{"category": "character", "id": 1666, "name": "Hal Jordan"},
|
||||
{"category": "corporation", "id": 2001, "name": "Wayne Technologies"},
|
||||
{"category": "corporation","id": 2002, "name": "Wayne Food"},
|
||||
{"category": "corporation","id": 1000001, "name": "Doomheim"},
|
||||
]
|
||||
return BravadoOperationStub([x for x in data if x['id'] in ids])
|
||||
@@ -1,12 +1,15 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.test import TestCase
|
||||
from esi.models import Token
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from ..models import (
|
||||
EveCharacter, EveCorporationInfo, EveAllianceInfo, EveFactionInfo
|
||||
)
|
||||
from ..providers import Alliance, Corporation, Character
|
||||
from ..evelinks import eveimageserver
|
||||
from ..models import EveAllianceInfo, EveCharacter, EveCorporationInfo, EveFactionInfo
|
||||
from ..providers import Alliance, Character, Corporation
|
||||
from .esi_client_stub import EsiClientStub
|
||||
|
||||
|
||||
class EveCharacterTestCase(TestCase):
|
||||
@@ -402,8 +405,8 @@ class EveAllianceTestCase(TestCase):
|
||||
my_alliance.save()
|
||||
my_alliance.populate_alliance()
|
||||
|
||||
for corporation in EveCorporationInfo.objects\
|
||||
.filter(corporation_id__in=[2001, 2002]
|
||||
for corporation in (
|
||||
EveCorporationInfo.objects.filter(corporation_id__in=[2001, 2002])
|
||||
):
|
||||
self.assertEqual(corporation.alliance, my_alliance)
|
||||
|
||||
@@ -587,3 +590,98 @@ class EveCorporationTestCase(TestCase):
|
||||
self.my_corp.logo_url_256,
|
||||
'https://images.evetech.net/corporations/2001/logo?size=256'
|
||||
)
|
||||
|
||||
|
||||
@patch('allianceauth.eveonline.providers.esi_client_factory')
|
||||
@patch("allianceauth.eveonline.models.notify")
|
||||
class TestCharacterUpdate(TestCase):
|
||||
def test_should_update_normal_character(self, mock_notify, mock_esi_client_factory):
|
||||
# given
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
my_character = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="not my name",
|
||||
corporation_id=2002,
|
||||
corporation_name="Wayne Food",
|
||||
corporation_ticker="WYF",
|
||||
alliance_id=None
|
||||
)
|
||||
# when
|
||||
my_character.update_character()
|
||||
# then
|
||||
my_character.refresh_from_db()
|
||||
self.assertEqual(my_character.character_name, "Bruce Wayne")
|
||||
self.assertEqual(my_character.corporation_id, 2001)
|
||||
self.assertEqual(my_character.corporation_name, "Wayne Technologies")
|
||||
self.assertEqual(my_character.corporation_ticker, "WTE")
|
||||
self.assertEqual(my_character.alliance_id, 3001)
|
||||
self.assertEqual(my_character.alliance_name, "Wayne Enterprises")
|
||||
self.assertEqual(my_character.alliance_ticker, "WYE")
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
def test_should_update_dead_character_with_owner(
|
||||
self, mock_notify, mock_esi_client_factory
|
||||
):
|
||||
# given
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
character_1666 = EveCharacter.objects.create(
|
||||
character_id=1666,
|
||||
character_name="Hal Jordan",
|
||||
corporation_id=2002,
|
||||
corporation_name="Wayne Food",
|
||||
corporation_ticker="WYF",
|
||||
alliance_id=None
|
||||
)
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
token_1666 = Token.objects.create(
|
||||
user=user,
|
||||
character_id=character_1666.character_id,
|
||||
character_name=character_1666.character_name,
|
||||
character_owner_hash="ABC123-1666",
|
||||
)
|
||||
character_1001 = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Bruce Wayne",
|
||||
corporation_id=2001,
|
||||
corporation_name="Wayne Technologies",
|
||||
corporation_ticker="WYT",
|
||||
alliance_id=None
|
||||
)
|
||||
token_1001 = Token.objects.create(
|
||||
user=user,
|
||||
character_id=character_1001.character_id,
|
||||
character_name=character_1001.character_name,
|
||||
character_owner_hash="ABC123-1001",
|
||||
)
|
||||
# when
|
||||
character_1666.update_character()
|
||||
# then
|
||||
character_1666.refresh_from_db()
|
||||
self.assertTrue(character_1666.is_biomassed)
|
||||
self.assertNotIn(token_1666, user.token_set.all())
|
||||
self.assertIn(token_1001, user.token_set.all())
|
||||
with self.assertRaises(ObjectDoesNotExist):
|
||||
self.assertTrue(character_1666.character_ownership)
|
||||
user.profile.refresh_from_db()
|
||||
self.assertIsNone(user.profile.main_character)
|
||||
self.assertTrue(mock_notify.called)
|
||||
|
||||
def test_should_handle_dead_character_without_owner(
|
||||
self, mock_notify, mock_esi_client_factory
|
||||
):
|
||||
# given
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
character_1666 = EveCharacter.objects.create(
|
||||
character_id=1666,
|
||||
character_name="Hal Jordan",
|
||||
corporation_id=1011,
|
||||
corporation_name="LexCorp",
|
||||
corporation_ticker='LC',
|
||||
alliance_id=None
|
||||
)
|
||||
# when
|
||||
character_1666.update_character()
|
||||
# then
|
||||
character_1666.refresh_from_db()
|
||||
self.assertTrue(character_1666.is_biomassed)
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
@@ -7,6 +7,7 @@ from jsonschema.exceptions import RefResolutionError
|
||||
from django.test import TestCase
|
||||
|
||||
from . import set_logger
|
||||
from .esi_client_stub import EsiClientStub
|
||||
from ..providers import (
|
||||
ObjectNotFound,
|
||||
Entity,
|
||||
@@ -632,13 +633,7 @@ class TestEveSwaggerProvider(TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.esi_client_factory')
|
||||
def test_get_character(self, mock_esi_client_factory):
|
||||
mock_esi_client_factory.return_value \
|
||||
.Character.get_characters_character_id \
|
||||
= TestEveSwaggerProvider.esi_get_characters_character_id
|
||||
mock_esi_client_factory.return_value \
|
||||
.Character.post_characters_affiliation \
|
||||
= TestEveSwaggerProvider.esi_post_characters_affiliation
|
||||
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
my_provider = EveSwaggerProvider()
|
||||
|
||||
# character with alliance
|
||||
@@ -649,8 +644,8 @@ class TestEveSwaggerProvider(TestCase):
|
||||
self.assertEqual(my_character.alliance_id, 3001)
|
||||
|
||||
# character wo/ alliance
|
||||
my_character = my_provider.get_character(1002)
|
||||
self.assertEqual(my_character.id, 1002)
|
||||
my_character = my_provider.get_character(1011)
|
||||
self.assertEqual(my_character.id, 1011)
|
||||
self.assertEqual(my_character.alliance_id, None)
|
||||
|
||||
# character not found
|
||||
|
||||
@@ -1,245 +1,271 @@
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test import TestCase, TransactionTestCase, override_settings
|
||||
|
||||
from ..models import EveCharacter, EveCorporationInfo, EveAllianceInfo
|
||||
from ..models import EveAllianceInfo, EveCharacter, EveCorporationInfo
|
||||
from ..tasks import (
|
||||
run_model_update,
|
||||
update_alliance,
|
||||
update_corp,
|
||||
update_character,
|
||||
run_model_update
|
||||
update_character_chunk,
|
||||
update_corp,
|
||||
)
|
||||
from .esi_client_stub import EsiClientStub
|
||||
|
||||
|
||||
class TestTasks(TestCase):
|
||||
|
||||
@patch('allianceauth.eveonline.tasks.EveCorporationInfo')
|
||||
def test_update_corp(self, mock_EveCorporationInfo):
|
||||
update_corp(42)
|
||||
self.assertEqual(
|
||||
mock_EveCorporationInfo.objects.update_corporation.call_count, 1
|
||||
)
|
||||
self.assertEqual(
|
||||
mock_EveCorporationInfo.objects.update_corporation.call_args[0][0], 42
|
||||
@patch('allianceauth.eveonline.providers.esi_client_factory')
|
||||
class TestUpdateTasks(TestCase):
|
||||
def test_should_update_alliance(self, mock_esi_client_factory):
|
||||
# given
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
my_alliance = EveAllianceInfo.objects.create(
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Enterprises",
|
||||
alliance_ticker="WYE",
|
||||
executor_corp_id=2003
|
||||
)
|
||||
# when
|
||||
update_alliance(my_alliance.alliance_id)
|
||||
# then
|
||||
my_alliance.refresh_from_db()
|
||||
self.assertEqual(my_alliance.executor_corp_id, 2001)
|
||||
|
||||
@patch('allianceauth.eveonline.tasks.EveAllianceInfo')
|
||||
def test_update_alliance(self, mock_EveAllianceInfo):
|
||||
update_alliance(42)
|
||||
self.assertEqual(
|
||||
mock_EveAllianceInfo.objects.update_alliance.call_args[0][0], 42
|
||||
)
|
||||
self.assertEqual(
|
||||
mock_EveAllianceInfo.objects
|
||||
.update_alliance.return_value.populate_alliance.call_count, 1
|
||||
def test_should_update_character(self, mock_esi_client_factory):
|
||||
# given
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
my_character = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Bruce Wayne",
|
||||
corporation_id=2002,
|
||||
corporation_name="Wayne Food",
|
||||
corporation_ticker="WYF",
|
||||
alliance_id=None
|
||||
)
|
||||
# when
|
||||
update_character(my_character.character_id)
|
||||
# then
|
||||
my_character.refresh_from_db()
|
||||
self.assertEqual(my_character.corporation_id, 2001)
|
||||
|
||||
@patch('allianceauth.eveonline.tasks.EveCharacter')
|
||||
def test_update_character(self, mock_EveCharacter):
|
||||
update_character(42)
|
||||
self.assertEqual(
|
||||
mock_EveCharacter.objects.update_character.call_count, 1
|
||||
def test_should_update_corp(self, mock_esi_client_factory):
|
||||
# given
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
EveAllianceInfo.objects.create(
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Enterprises",
|
||||
alliance_ticker="WYE",
|
||||
executor_corp_id=2003
|
||||
)
|
||||
self.assertEqual(
|
||||
mock_EveCharacter.objects.update_character.call_args[0][0], 42
|
||||
my_corporation = EveCorporationInfo.objects.create(
|
||||
corporation_id=2003,
|
||||
corporation_name="Wayne Food",
|
||||
corporation_ticker="WFO",
|
||||
member_count=1,
|
||||
alliance=None,
|
||||
ceo_id=1999
|
||||
)
|
||||
# when
|
||||
update_corp(my_corporation.corporation_id)
|
||||
# then
|
||||
my_corporation.refresh_from_db()
|
||||
self.assertEqual(my_corporation.alliance.alliance_id, 3001)
|
||||
|
||||
# @patch('allianceauth.eveonline.tasks.EveCharacter')
|
||||
# def test_update_character(self, mock_EveCharacter):
|
||||
# update_character(42)
|
||||
# self.assertEqual(
|
||||
# mock_EveCharacter.objects.update_character.call_count, 1
|
||||
# )
|
||||
# self.assertEqual(
|
||||
# mock_EveCharacter.objects.update_character.call_args[0][0], 42
|
||||
# )
|
||||
|
||||
|
||||
@patch('allianceauth.eveonline.tasks.update_character')
|
||||
@patch('allianceauth.eveonline.tasks.update_alliance')
|
||||
@patch('allianceauth.eveonline.tasks.update_corp')
|
||||
@patch('allianceauth.eveonline.providers.provider')
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
@patch('allianceauth.eveonline.providers.esi_client_factory')
|
||||
@patch('allianceauth.eveonline.tasks.providers')
|
||||
@patch('allianceauth.eveonline.tasks.CHUNK_SIZE', 2)
|
||||
class TestRunModelUpdate(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
EveCorporationInfo.objects.all().delete()
|
||||
EveAllianceInfo.objects.all().delete()
|
||||
EveCharacter.objects.all().delete()
|
||||
|
||||
class TestRunModelUpdate(TransactionTestCase):
|
||||
def test_should_run_updates(self, mock_providers, mock_esi_client_factory):
|
||||
# given
|
||||
mock_providers.provider.client = EsiClientStub()
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
EveCorporationInfo.objects.create(
|
||||
corporation_id=2345,
|
||||
corporation_name='corp.name',
|
||||
corporation_ticker='c.c.t',
|
||||
corporation_id=2001,
|
||||
corporation_name="Wayne Technologies",
|
||||
corporation_ticker="WTE",
|
||||
member_count=10,
|
||||
alliance=None,
|
||||
)
|
||||
EveAllianceInfo.objects.create(
|
||||
alliance_id=3456,
|
||||
alliance_name='alliance.name',
|
||||
alliance_ticker='a.t',
|
||||
executor_corp_id=5,
|
||||
alliance_3001 = EveAllianceInfo.objects.create(
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Enterprises",
|
||||
alliance_ticker="WYE",
|
||||
executor_corp_id=2003
|
||||
)
|
||||
EveCharacter.objects.create(
|
||||
character_id=1,
|
||||
character_name='character.name1',
|
||||
corporation_id=2345,
|
||||
corporation_name='character.corp.name',
|
||||
corporation_ticker='c.c.t', # max 5 chars
|
||||
corporation_2003 = EveCorporationInfo.objects.create(
|
||||
corporation_id=2003,
|
||||
corporation_name="Wayne Energy",
|
||||
corporation_ticker="WEG",
|
||||
member_count=99,
|
||||
alliance=None,
|
||||
)
|
||||
character_1001 = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Bruce Wayne",
|
||||
corporation_id=2002,
|
||||
corporation_name="Wayne Food",
|
||||
corporation_ticker="WYF",
|
||||
alliance_id=None
|
||||
)
|
||||
EveCharacter.objects.create(
|
||||
character_id=2,
|
||||
character_name='character.name2',
|
||||
corporation_id=9876,
|
||||
corporation_name='character.corp.name',
|
||||
corporation_ticker='c.c.t', # max 5 chars
|
||||
alliance_id=3456,
|
||||
alliance_name='character.alliance.name',
|
||||
)
|
||||
EveCharacter.objects.create(
|
||||
character_id=3,
|
||||
character_name='character.name3',
|
||||
corporation_id=9876,
|
||||
corporation_name='character.corp.name',
|
||||
corporation_ticker='c.c.t', # max 5 chars
|
||||
alliance_id=3456,
|
||||
alliance_name='character.alliance.name',
|
||||
)
|
||||
EveCharacter.objects.create(
|
||||
character_id=4,
|
||||
character_name='character.name4',
|
||||
corporation_id=9876,
|
||||
corporation_name='character.corp.name',
|
||||
corporation_ticker='c.c.t', # max 5 chars
|
||||
alliance_id=3456,
|
||||
alliance_name='character.alliance.name',
|
||||
)
|
||||
"""
|
||||
EveCharacter.objects.create(
|
||||
character_id=5,
|
||||
character_name='character.name5',
|
||||
corporation_id=9876,
|
||||
corporation_name='character.corp.name',
|
||||
corporation_ticker='c.c.t', # max 5 chars
|
||||
alliance_id=3456,
|
||||
alliance_name='character.alliance.name',
|
||||
)
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.affiliations = [
|
||||
{'character_id': 1, 'corporation_id': 5},
|
||||
{'character_id': 2, 'corporation_id': 9876, 'alliance_id': 3456},
|
||||
{'character_id': 3, 'corporation_id': 9876, 'alliance_id': 7456},
|
||||
{'character_id': 4, 'corporation_id': 9876, 'alliance_id': 3456}
|
||||
]
|
||||
self.names = [
|
||||
{'id': 1, 'name': 'character.name1'},
|
||||
{'id': 2, 'name': 'character.name2'},
|
||||
{'id': 3, 'name': 'character.name3'},
|
||||
{'id': 4, 'name': 'character.name4_new'}
|
||||
]
|
||||
|
||||
def test_normal_run(
|
||||
self,
|
||||
mock_provider,
|
||||
mock_update_corp,
|
||||
mock_update_alliance,
|
||||
mock_update_character,
|
||||
):
|
||||
def get_affiliations(characters: list):
|
||||
response = [x for x in self.affiliations if x['character_id'] in characters]
|
||||
mock_operator = Mock(**{'result.return_value': response})
|
||||
return mock_operator
|
||||
|
||||
def get_names(ids: list):
|
||||
response = [x for x in self.names if x['id'] in ids]
|
||||
mock_operator = Mock(**{'result.return_value': response})
|
||||
return mock_operator
|
||||
|
||||
mock_provider.client.Character.post_characters_affiliation.side_effect \
|
||||
= get_affiliations
|
||||
|
||||
mock_provider.client.Universe.post_universe_names.side_effect = get_names
|
||||
|
||||
# when
|
||||
run_model_update()
|
||||
|
||||
# then
|
||||
character_1001.refresh_from_db()
|
||||
self.assertEqual(
|
||||
mock_provider.client.Character.post_characters_affiliation.call_count, 2
|
||||
character_1001.corporation_id, 2001 # char has new corp
|
||||
)
|
||||
corporation_2003.refresh_from_db()
|
||||
self.assertEqual(
|
||||
mock_provider.client.Universe.post_universe_names.call_count, 2
|
||||
corporation_2003.alliance.alliance_id, 3001 # corp has new alliance
|
||||
)
|
||||
alliance_3001.refresh_from_db()
|
||||
self.assertEqual(
|
||||
alliance_3001.executor_corp_id, 2001 # alliance has been updated
|
||||
)
|
||||
|
||||
# character 1 has changed corp
|
||||
# character 2 no change
|
||||
# character 3 has changed alliance
|
||||
# character 4 has changed name
|
||||
self.assertEqual(mock_update_corp.apply_async.call_count, 1)
|
||||
self.assertEqual(
|
||||
int(mock_update_corp.apply_async.call_args[1]['args'][0]), 2345
|
||||
)
|
||||
self.assertEqual(mock_update_alliance.apply_async.call_count, 1)
|
||||
self.assertEqual(
|
||||
int(mock_update_alliance.apply_async.call_args[1]['args'][0]), 3456
|
||||
)
|
||||
characters_updated = {
|
||||
x[1]['args'][0] for x in mock_update_character.apply_async.call_args_list
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
@patch('allianceauth.eveonline.tasks.update_character', wraps=update_character)
|
||||
@patch('allianceauth.eveonline.providers.esi_client_factory')
|
||||
@patch('allianceauth.eveonline.tasks.providers')
|
||||
@patch('allianceauth.eveonline.tasks.CHUNK_SIZE', 2)
|
||||
class TestUpdateCharacterChunk(TestCase):
|
||||
@staticmethod
|
||||
def _updated_character_ids(spy_update_character) -> set:
|
||||
"""Character IDs passed to update_character task for update."""
|
||||
return {
|
||||
x[1]["args"][0] for x in spy_update_character.apply_async.call_args_list
|
||||
}
|
||||
excepted = {1, 3, 4}
|
||||
self.assertSetEqual(characters_updated, excepted)
|
||||
|
||||
def test_ignore_character_not_in_affiliations(
|
||||
self,
|
||||
mock_provider,
|
||||
mock_update_corp,
|
||||
mock_update_alliance,
|
||||
mock_update_character,
|
||||
def test_should_update_corp_change(
|
||||
self, mock_providers, mock_esi_client_factory, spy_update_character
|
||||
):
|
||||
def get_affiliations(characters: list):
|
||||
response = [x for x in self.affiliations if x['character_id'] in characters]
|
||||
mock_operator = Mock(**{'result.return_value': response})
|
||||
return mock_operator
|
||||
# given
|
||||
mock_providers.provider.client = EsiClientStub()
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
character_1001 = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Bruce Wayne",
|
||||
corporation_id=2003,
|
||||
corporation_name="Wayne Energy",
|
||||
corporation_ticker="WEG",
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Enterprises",
|
||||
alliance_ticker="WYE",
|
||||
)
|
||||
character_1002 = EveCharacter.objects.create(
|
||||
character_id=1002,
|
||||
character_name="Peter Parker",
|
||||
corporation_id=2001,
|
||||
corporation_name="Wayne Technologies",
|
||||
corporation_ticker="WTE",
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Enterprises",
|
||||
alliance_ticker="WYE",
|
||||
)
|
||||
# when
|
||||
update_character_chunk([
|
||||
character_1001.character_id, character_1002.character_id
|
||||
])
|
||||
# then
|
||||
character_1001.refresh_from_db()
|
||||
self.assertEqual(character_1001.corporation_id, 2001)
|
||||
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})
|
||||
|
||||
def get_names(ids: list):
|
||||
response = [x for x in self.names if x['id'] in ids]
|
||||
mock_operator = Mock(**{'result.return_value': response})
|
||||
return mock_operator
|
||||
|
||||
del self.affiliations[0]
|
||||
|
||||
mock_provider.client.Character.post_characters_affiliation.side_effect \
|
||||
= get_affiliations
|
||||
|
||||
mock_provider.client.Universe.post_universe_names.side_effect = get_names
|
||||
|
||||
run_model_update()
|
||||
characters_updated = {
|
||||
x[1]['args'][0] for x in mock_update_character.apply_async.call_args_list
|
||||
}
|
||||
excepted = {3, 4}
|
||||
self.assertSetEqual(characters_updated, excepted)
|
||||
|
||||
def test_ignore_character_not_in_names(
|
||||
self,
|
||||
mock_provider,
|
||||
mock_update_corp,
|
||||
mock_update_alliance,
|
||||
mock_update_character,
|
||||
def test_should_update_name_change(
|
||||
self, mock_providers, mock_esi_client_factory, spy_update_character
|
||||
):
|
||||
def get_affiliations(characters: list):
|
||||
response = [x for x in self.affiliations if x['character_id'] in characters]
|
||||
mock_operator = Mock(**{'result.return_value': response})
|
||||
return mock_operator
|
||||
# given
|
||||
mock_providers.provider.client = EsiClientStub()
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
character_1001 = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Batman",
|
||||
corporation_id=2001,
|
||||
corporation_name="Wayne Technologies",
|
||||
corporation_ticker="WTE",
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Technologies",
|
||||
alliance_ticker="WYT",
|
||||
)
|
||||
# when
|
||||
update_character_chunk([character_1001.character_id])
|
||||
# then
|
||||
character_1001.refresh_from_db()
|
||||
self.assertEqual(character_1001.character_name, "Bruce Wayne")
|
||||
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})
|
||||
|
||||
def get_names(ids: list):
|
||||
response = [x for x in self.names if x['id'] in ids]
|
||||
mock_operator = Mock(**{'result.return_value': response})
|
||||
return mock_operator
|
||||
def test_should_update_alliance_change(
|
||||
self, mock_providers, mock_esi_client_factory, spy_update_character
|
||||
):
|
||||
# given
|
||||
mock_providers.provider.client = EsiClientStub()
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
character_1001 = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Bruce Wayne",
|
||||
corporation_id=2001,
|
||||
corporation_name="Wayne Technologies",
|
||||
corporation_ticker="WTE",
|
||||
alliance_id=None,
|
||||
)
|
||||
# when
|
||||
update_character_chunk([character_1001.character_id])
|
||||
# then
|
||||
character_1001.refresh_from_db()
|
||||
self.assertEqual(character_1001.alliance_id, 3001)
|
||||
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})
|
||||
|
||||
del self.names[3]
|
||||
def test_should_not_update_when_not_changed(
|
||||
self, mock_providers, mock_esi_client_factory, spy_update_character
|
||||
):
|
||||
# given
|
||||
mock_providers.provider.client = EsiClientStub()
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
character_1001 = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Bruce Wayne",
|
||||
corporation_id=2001,
|
||||
corporation_name="Wayne Technologies",
|
||||
corporation_ticker="WTE",
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Technologies",
|
||||
alliance_ticker="WYT",
|
||||
)
|
||||
# when
|
||||
update_character_chunk([character_1001.character_id])
|
||||
# then
|
||||
self.assertSetEqual(self._updated_character_ids(spy_update_character), set())
|
||||
|
||||
mock_provider.client.Character.post_characters_affiliation.side_effect \
|
||||
= get_affiliations
|
||||
|
||||
mock_provider.client.Universe.post_universe_names.side_effect = get_names
|
||||
|
||||
run_model_update()
|
||||
characters_updated = {
|
||||
x[1]['args'][0] for x in mock_update_character.apply_async.call_args_list
|
||||
}
|
||||
excepted = {1, 3}
|
||||
self.assertSetEqual(characters_updated, excepted)
|
||||
def test_should_fall_back_to_single_updates_when_bulk_update_failed(
|
||||
self, mock_providers, mock_esi_client_factory, spy_update_character
|
||||
):
|
||||
# given
|
||||
mock_providers.provider.client.Character.post_characters_affiliation\
|
||||
.side_effect = OSError
|
||||
mock_esi_client_factory.return_value = EsiClientStub()
|
||||
character_1001 = EveCharacter.objects.create(
|
||||
character_id=1001,
|
||||
character_name="Bruce Wayne",
|
||||
corporation_id=2001,
|
||||
corporation_name="Wayne Technologies",
|
||||
corporation_ticker="WTE",
|
||||
alliance_id=3001,
|
||||
alliance_name="Wayne Technologies",
|
||||
alliance_ticker="WYT",
|
||||
)
|
||||
# when
|
||||
update_character_chunk([character_1001.character_id])
|
||||
# then
|
||||
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})
|
||||
|
||||
@@ -212,7 +212,14 @@ def fatlink_monthly_personal_statistics_view(request, year, month, char_id=None)
|
||||
start_of_previous_month = first_day_of_previous_month(year, month)
|
||||
|
||||
if request.user.has_perm('auth.fleetactivitytracking_statistics') and char_id:
|
||||
user = EveCharacter.objects.get(character_id=char_id).user
|
||||
try:
|
||||
user = EveCharacter.objects.get(character_id=char_id).character_ownership.user
|
||||
except EveCharacter.DoesNotExist:
|
||||
messages.error(request, _('Character does not exist'))
|
||||
return redirect('fatlink:view')
|
||||
except AttributeError:
|
||||
messages.error(request, _('User does not exist'))
|
||||
return redirect('fatlink:view')
|
||||
else:
|
||||
user = request.user
|
||||
logger.debug(f"Personal monthly statistics view for user {user} called by {request.user}")
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
from django import forms
|
||||
from django.apps import apps
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib import admin
|
||||
from django.contrib.auth.models import Group as BaseGroup, User
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db.models import Count
|
||||
from django.db.models.functions import Lower
|
||||
from django.db.models.signals import pre_save, post_save, pre_delete, \
|
||||
post_delete, m2m_changed
|
||||
from django.dispatch import receiver
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from .models import AuthGroup, ReservedGroupName
|
||||
from .models import GroupRequest
|
||||
from django.contrib.auth.models import Group as BaseGroup, Permission, User
|
||||
from django.db.models import Count, Exists, OuterRef
|
||||
from django.db.models.functions import Lower
|
||||
from django.db.models.signals import (
|
||||
m2m_changed,
|
||||
post_delete,
|
||||
post_save,
|
||||
pre_delete,
|
||||
pre_save
|
||||
)
|
||||
from django.dispatch import receiver
|
||||
|
||||
from .forms import GroupAdminForm, ReservedGroupNameAdminForm
|
||||
from .models import AuthGroup, GroupRequest, ReservedGroupName
|
||||
from .tasks import remove_users_not_matching_states_from_group
|
||||
|
||||
if 'eve_autogroups' in apps.app_configs:
|
||||
_has_auto_groups = True
|
||||
@@ -28,10 +30,12 @@ class AuthGroupInlineAdmin(admin.StackedInline):
|
||||
'description',
|
||||
'group_leaders',
|
||||
'group_leader_groups',
|
||||
'states', 'internal',
|
||||
'states',
|
||||
'internal',
|
||||
'hidden',
|
||||
'open',
|
||||
'public'
|
||||
'public',
|
||||
'restricted',
|
||||
)
|
||||
verbose_name_plural = 'Auth Settings'
|
||||
verbose_name = ''
|
||||
@@ -50,6 +54,11 @@ class AuthGroupInlineAdmin(admin.StackedInline):
|
||||
def has_change_permission(self, request, obj=None):
|
||||
return request.user.has_perm('auth.change_group')
|
||||
|
||||
def get_readonly_fields(self, request, obj=None):
|
||||
if not request.user.is_superuser:
|
||||
return self.readonly_fields + ("restricted",)
|
||||
return self.readonly_fields
|
||||
|
||||
|
||||
if _has_auto_groups:
|
||||
class IsAutoGroupFilter(admin.SimpleListFilter):
|
||||
@@ -96,27 +105,15 @@ class HasLeaderFilter(admin.SimpleListFilter):
|
||||
return queryset
|
||||
|
||||
|
||||
class GroupAdminForm(forms.ModelForm):
|
||||
def clean_name(self):
|
||||
my_name = self.cleaned_data['name']
|
||||
if ReservedGroupName.objects.filter(name__iexact=my_name).exists():
|
||||
raise ValidationError(
|
||||
_("This name has been reserved and can not be used for groups."),
|
||||
code='reserved_name'
|
||||
)
|
||||
return my_name
|
||||
|
||||
|
||||
class GroupAdmin(admin.ModelAdmin):
|
||||
form = GroupAdminForm
|
||||
list_select_related = ('authgroup',)
|
||||
ordering = ('name',)
|
||||
list_display = (
|
||||
'name',
|
||||
'_description',
|
||||
'_properties',
|
||||
'_member_count',
|
||||
'has_leader'
|
||||
'has_leader',
|
||||
)
|
||||
list_filter = [
|
||||
'authgroup__internal',
|
||||
@@ -132,34 +129,51 @@ class GroupAdmin(admin.ModelAdmin):
|
||||
|
||||
def get_queryset(self, request):
|
||||
qs = super().get_queryset(request)
|
||||
if _has_auto_groups:
|
||||
qs = qs.prefetch_related('managedalliancegroup_set', 'managedcorpgroup_set')
|
||||
qs = qs.prefetch_related('authgroup__group_leaders').select_related('authgroup')
|
||||
qs = qs.annotate(
|
||||
member_count=Count('user', distinct=True),
|
||||
has_leader_qs = (
|
||||
AuthGroup.objects.filter(group=OuterRef('pk'), group_leaders__isnull=False)
|
||||
)
|
||||
has_leader_groups_qs = (
|
||||
AuthGroup.objects.filter(
|
||||
group=OuterRef('pk'), group_leader_groups__isnull=False
|
||||
)
|
||||
)
|
||||
qs = (
|
||||
qs.select_related('authgroup')
|
||||
.annotate(member_count=Count('user', distinct=True))
|
||||
.annotate(has_leader=Exists(has_leader_qs))
|
||||
.annotate(has_leader_groups=Exists(has_leader_groups_qs))
|
||||
)
|
||||
if _has_auto_groups:
|
||||
is_autogroup_corp = (
|
||||
Group.objects.filter(
|
||||
pk=OuterRef('pk'), managedcorpgroup__isnull=False
|
||||
)
|
||||
)
|
||||
is_autogroup_alliance = (
|
||||
Group.objects.filter(
|
||||
pk=OuterRef('pk'), managedalliancegroup__isnull=False
|
||||
)
|
||||
)
|
||||
qs = (
|
||||
qs.annotate(is_autogroup_corp=Exists(is_autogroup_corp))
|
||||
.annotate(is_autogroup_alliance=Exists(is_autogroup_alliance))
|
||||
)
|
||||
return qs
|
||||
|
||||
def _description(self, obj):
|
||||
return obj.authgroup.description
|
||||
|
||||
@admin.display(description='Members', ordering='member_count')
|
||||
def _member_count(self, obj):
|
||||
return obj.member_count
|
||||
|
||||
_member_count.short_description = 'Members'
|
||||
_member_count.admin_order_field = 'member_count'
|
||||
|
||||
@admin.display(boolean=True)
|
||||
def has_leader(self, obj):
|
||||
return obj.authgroup.group_leaders.exists() or obj.authgroup.group_leader_groups.exists()
|
||||
|
||||
has_leader.boolean = True
|
||||
return obj.has_leader or obj.has_leader_groups
|
||||
|
||||
def _properties(self, obj):
|
||||
properties = list()
|
||||
if _has_auto_groups and (
|
||||
obj.managedalliancegroup_set.exists()
|
||||
or obj.managedcorpgroup_set.exists()
|
||||
):
|
||||
if _has_auto_groups and (obj.is_autogroup_corp or obj.is_autogroup_alliance):
|
||||
properties.append('Auto Group')
|
||||
elif obj.authgroup.internal:
|
||||
properties.append('Internal')
|
||||
@@ -172,11 +186,10 @@ class GroupAdmin(admin.ModelAdmin):
|
||||
properties.append('Public')
|
||||
if not properties:
|
||||
properties.append('Default')
|
||||
|
||||
if obj.authgroup.restricted:
|
||||
properties.append('Restricted')
|
||||
return properties
|
||||
|
||||
_properties.short_description = "properties"
|
||||
|
||||
filter_horizontal = ('permissions',)
|
||||
inlines = (AuthGroupInlineAdmin,)
|
||||
|
||||
@@ -190,8 +203,15 @@ class GroupAdmin(admin.ModelAdmin):
|
||||
ag_instance = inline_form.save(commit=False)
|
||||
ag_instance.group = form.instance
|
||||
ag_instance.save()
|
||||
if ag_instance.states.exists():
|
||||
remove_users_not_matching_states_from_group.delay(ag_instance.group.pk)
|
||||
formset.save()
|
||||
|
||||
def get_readonly_fields(self, request, obj=None):
|
||||
if not request.user.is_superuser:
|
||||
return self.readonly_fields + ("permissions",)
|
||||
return self.readonly_fields
|
||||
|
||||
|
||||
class Group(BaseGroup):
|
||||
class Meta:
|
||||
@@ -216,33 +236,10 @@ class GroupRequestAdmin(admin.ModelAdmin):
|
||||
'leave_request',
|
||||
)
|
||||
|
||||
@admin.display(boolean=True, description="is leave request")
|
||||
def _leave_request(self, obj) -> True:
|
||||
return obj.leave_request
|
||||
|
||||
_leave_request.short_description = 'is leave request'
|
||||
_leave_request.boolean = True
|
||||
|
||||
|
||||
class ReservedGroupNameAdminForm(forms.ModelForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields['created_by'].initial = self.current_user.username
|
||||
self.fields['created_at'].initial = _("(auto)")
|
||||
|
||||
created_by = forms.CharField(disabled=True)
|
||||
created_at = forms.CharField(disabled=True)
|
||||
|
||||
def clean_name(self):
|
||||
my_name = self.cleaned_data['name'].lower()
|
||||
if Group.objects.filter(name__iexact=my_name).exists():
|
||||
raise ValidationError(
|
||||
_("There already exists a group with that name."), code='already_exists'
|
||||
)
|
||||
return my_name
|
||||
|
||||
def clean_created_at(self):
|
||||
return now()
|
||||
|
||||
|
||||
@admin.register(ReservedGroupName)
|
||||
class ReservedGroupNameAdmin(admin.ModelAdmin):
|
||||
|
||||
39
allianceauth/groupmanagement/forms.py
Normal file
39
allianceauth/groupmanagement/forms.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from django import forms
|
||||
from django.contrib.auth.models import Group
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from .models import ReservedGroupName
|
||||
|
||||
|
||||
class GroupAdminForm(forms.ModelForm):
|
||||
def clean_name(self):
|
||||
my_name = self.cleaned_data['name']
|
||||
if ReservedGroupName.objects.filter(name__iexact=my_name).exists():
|
||||
raise ValidationError(
|
||||
_("This name has been reserved and can not be used for groups."),
|
||||
code='reserved_name'
|
||||
)
|
||||
return my_name
|
||||
|
||||
|
||||
class ReservedGroupNameAdminForm(forms.ModelForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields['created_by'].initial = self.current_user.username
|
||||
self.fields['created_at'].initial = _("(auto)")
|
||||
|
||||
created_by = forms.CharField(disabled=True)
|
||||
created_at = forms.CharField(disabled=True)
|
||||
|
||||
def clean_name(self):
|
||||
my_name = self.cleaned_data['name'].lower()
|
||||
if Group.objects.filter(name__iexact=my_name).exists():
|
||||
raise ValidationError(
|
||||
_("There already exists a group with that name."), code='already_exists'
|
||||
)
|
||||
return my_name
|
||||
|
||||
def clean_created_at(self):
|
||||
return now()
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 3.2.10 on 2022-04-08 19:30
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('groupmanagement', '0018_reservedgroupname'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='authgroup',
|
||||
name='restricted',
|
||||
field=models.BooleanField(default=False, help_text='Group is restricted. This means that adding or removing users for this group requires a superuser admin.'),
|
||||
),
|
||||
]
|
||||
@@ -13,6 +13,7 @@ from allianceauth.notifications import notify
|
||||
|
||||
class GroupRequest(models.Model):
|
||||
"""Request from a user for joining or leaving a group."""
|
||||
|
||||
leave_request = models.BooleanField(default=0)
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
@@ -44,6 +45,7 @@ class GroupRequest(models.Model):
|
||||
|
||||
class RequestLog(models.Model):
|
||||
"""Log entry about who joined and left a group and who approved it."""
|
||||
|
||||
request_type = models.BooleanField(null=True)
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
request_info = models.CharField(max_length=254)
|
||||
@@ -95,6 +97,7 @@ class AuthGroup(models.Model):
|
||||
Open - Users are automatically accepted into the group
|
||||
Not Open - Users requests must be approved before they are added to the group
|
||||
"""
|
||||
|
||||
group = models.OneToOneField(Group, on_delete=models.CASCADE, primary_key=True)
|
||||
internal = models.BooleanField(
|
||||
default=True,
|
||||
@@ -126,6 +129,13 @@ class AuthGroup(models.Model):
|
||||
"are no longer authenticated."
|
||||
)
|
||||
)
|
||||
restricted = models.BooleanField(
|
||||
default=False,
|
||||
help_text=_(
|
||||
"Group is restricted. This means that adding or removing users "
|
||||
"for this group requires a superuser admin."
|
||||
)
|
||||
)
|
||||
group_leaders = models.ManyToManyField(
|
||||
User,
|
||||
related_name='leads_groups',
|
||||
@@ -179,12 +189,22 @@ class AuthGroup(models.Model):
|
||||
| User.objects.filter(groups__in=list(self.group_leader_groups.all()))
|
||||
)
|
||||
|
||||
def remove_users_not_matching_states(self):
|
||||
"""Remove users not matching defined states from related group."""
|
||||
states_qs = self.states.all()
|
||||
if states_qs.exists():
|
||||
states = list(states_qs)
|
||||
non_compliant_users = self.group.user_set.exclude(profile__state__in=states)
|
||||
for user in non_compliant_users:
|
||||
self.group.user_set.remove(user)
|
||||
|
||||
|
||||
class ReservedGroupName(models.Model):
|
||||
"""Name that can not be used for groups.
|
||||
|
||||
This enables AA to ignore groups on other services (e.g. Discord) with that name.
|
||||
"""
|
||||
|
||||
name = models.CharField(
|
||||
_('name'),
|
||||
max_length=150,
|
||||
|
||||
10
allianceauth/groupmanagement/tasks.py
Normal file
10
allianceauth/groupmanagement/tasks.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from celery import shared_task
|
||||
|
||||
from django.contrib.auth.models import Group
|
||||
|
||||
|
||||
@shared_task
|
||||
def remove_users_not_matching_states_from_group(group_pk: int) -> None:
|
||||
"""Remove users not matching defined states from related group."""
|
||||
group = Group.objects.get(pk=group_pk)
|
||||
group.authgroup.remove_users_not_matching_states()
|
||||
@@ -127,6 +127,8 @@
|
||||
],
|
||||
bootstrap: true
|
||||
},
|
||||
"stateSave": true,
|
||||
"stateDuration": 0
|
||||
});
|
||||
});
|
||||
{% endblock %}
|
||||
|
||||
@@ -104,7 +104,9 @@
|
||||
"sortable": false,
|
||||
"targets": [2]
|
||||
},
|
||||
]
|
||||
],
|
||||
"stateSave": true,
|
||||
"stateDuration": 0
|
||||
});
|
||||
});
|
||||
{% endblock %}
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django_webtest import WebTest
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib import admin
|
||||
from django.contrib.admin.sites import AdminSite
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import TestCase, RequestFactory, Client
|
||||
from django.test import TestCase, RequestFactory, Client, override_settings
|
||||
|
||||
from allianceauth.authentication.models import CharacterOwnership, State
|
||||
from allianceauth.eveonline.models import (
|
||||
EveCharacter, EveCorporationInfo, EveAllianceInfo
|
||||
)
|
||||
from ..admin import HasLeaderFilter, GroupAdmin, Group
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from . import get_admin_change_view_url
|
||||
from ..admin import HasLeaderFilter, GroupAdmin, Group
|
||||
from ..models import ReservedGroupName
|
||||
|
||||
|
||||
@@ -33,7 +37,6 @@ class MockRequest:
|
||||
|
||||
|
||||
class TestGroupAdmin(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
@@ -233,60 +236,104 @@ class TestGroupAdmin(TestCase):
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_member_count(self):
|
||||
expected = 1
|
||||
obj = self.modeladmin.get_queryset(MockRequest(user=self.user_1))\
|
||||
.get(pk=self.group_1.pk)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_1.pk)
|
||||
# when
|
||||
result = self.modeladmin._member_count(obj)
|
||||
self.assertEqual(result, expected)
|
||||
# then
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
def test_has_leader_user(self):
|
||||
result = self.modeladmin.has_leader(self.group_1)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_1.pk)
|
||||
# when
|
||||
result = self.modeladmin.has_leader(obj)
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_has_leader_group(self):
|
||||
result = self.modeladmin.has_leader(self.group_2)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_2.pk)
|
||||
# when
|
||||
result = self.modeladmin.has_leader(obj)
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_properties_1(self):
|
||||
expected = ['Default']
|
||||
result = self.modeladmin._properties(self.group_1)
|
||||
self.assertListEqual(result, expected)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_1.pk)
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
self.assertListEqual(result, ['Default'])
|
||||
|
||||
def test_properties_2(self):
|
||||
expected = ['Internal']
|
||||
result = self.modeladmin._properties(self.group_2)
|
||||
self.assertListEqual(result, expected)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_2.pk)
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
self.assertListEqual(result, ['Internal'])
|
||||
|
||||
def test_properties_3(self):
|
||||
expected = ['Hidden']
|
||||
result = self.modeladmin._properties(self.group_3)
|
||||
self.assertListEqual(result, expected)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_3.pk)
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
self.assertListEqual(result, ['Hidden'])
|
||||
|
||||
def test_properties_4(self):
|
||||
expected = ['Open']
|
||||
result = self.modeladmin._properties(self.group_4)
|
||||
self.assertListEqual(result, expected)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_4.pk)
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
self.assertListEqual(result, ['Open'])
|
||||
|
||||
def test_properties_5(self):
|
||||
expected = ['Public']
|
||||
result = self.modeladmin._properties(self.group_5)
|
||||
self.assertListEqual(result, expected)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_5.pk)
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
self.assertListEqual(result, ['Public'])
|
||||
|
||||
def test_properties_6(self):
|
||||
expected = ['Hidden', 'Open', 'Public']
|
||||
result = self.modeladmin._properties(self.group_6)
|
||||
self.assertListEqual(result, expected)
|
||||
# given
|
||||
request = MockRequest(user=self.user_1)
|
||||
obj = self.modeladmin.get_queryset(request).get(pk=self.group_6.pk)
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
self.assertListEqual(result, ['Hidden', 'Open', 'Public'])
|
||||
|
||||
if _has_auto_groups:
|
||||
@patch(MODULE_PATH + '._has_auto_groups', True)
|
||||
def test_properties_7(self):
|
||||
def test_should_show_autogroup_for_corporation(self):
|
||||
# given
|
||||
self._create_autogroups()
|
||||
expected = ['Auto Group']
|
||||
my_group = Group.objects\
|
||||
.filter(managedcorpgroup__isnull=False)\
|
||||
.first()
|
||||
result = self.modeladmin._properties(my_group)
|
||||
self.assertListEqual(result, expected)
|
||||
request = MockRequest(user=self.user_1)
|
||||
queryset = self.modeladmin.get_queryset(request)
|
||||
obj = queryset.filter(managedcorpgroup__isnull=False).first()
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
# then
|
||||
self.assertListEqual(result, ['Auto Group'])
|
||||
|
||||
@patch(MODULE_PATH + '._has_auto_groups', True)
|
||||
def test_should_show_autogroup_for_alliance(self):
|
||||
# given
|
||||
self._create_autogroups()
|
||||
request = MockRequest(user=self.user_1)
|
||||
queryset = self.modeladmin.get_queryset(request)
|
||||
obj = queryset.filter(managedalliancegroup__isnull=False).first()
|
||||
# when
|
||||
result = self.modeladmin._properties(obj)
|
||||
# then
|
||||
self.assertListEqual(result, ['Auto Group'])
|
||||
|
||||
# actions
|
||||
|
||||
@@ -468,6 +515,136 @@ class TestGroupAdmin(TestCase):
|
||||
self.assertFalse(Group.objects.filter(name="new group").exists())
|
||||
|
||||
|
||||
class TestGroupAdminChangeFormSuperuserExclusiveEdits(WebTest):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
cls.super_admin = User.objects.create_superuser("super_admin")
|
||||
cls.staff_admin = User.objects.create_user("staff_admin")
|
||||
cls.staff_admin.is_staff = True
|
||||
cls.staff_admin.save()
|
||||
cls.staff_admin = AuthUtils.add_permissions_to_user_by_name(
|
||||
[
|
||||
"auth.add_group",
|
||||
"auth.change_group",
|
||||
"auth.view_group",
|
||||
"groupmanagement.add_group",
|
||||
"groupmanagement.change_group",
|
||||
"groupmanagement.view_group",
|
||||
],
|
||||
cls.staff_admin
|
||||
)
|
||||
cls.superuser_exclusive_fields = ["permissions", "authgroup-0-restricted"]
|
||||
|
||||
def test_should_show_all_fields_to_superuser_for_add(self):
|
||||
# given
|
||||
self.app.set_user(self.super_admin)
|
||||
page = self.app.get("/admin/groupmanagement/group/add/")
|
||||
# when
|
||||
form = page.forms["group_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertIn(field, form.fields)
|
||||
|
||||
def test_should_not_show_all_fields_to_staff_admins_for_add(self):
|
||||
# given
|
||||
self.app.set_user(self.staff_admin)
|
||||
page = self.app.get("/admin/groupmanagement/group/add/")
|
||||
# when
|
||||
form = page.forms["group_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertNotIn(field, form.fields)
|
||||
|
||||
def test_should_show_all_fields_to_superuser_for_change(self):
|
||||
# given
|
||||
self.app.set_user(self.super_admin)
|
||||
group = Group.objects.create(name="Dummy group")
|
||||
page = self.app.get(f"/admin/groupmanagement/group/{group.pk}/change/")
|
||||
# when
|
||||
form = page.forms["group_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertIn(field, form.fields)
|
||||
|
||||
def test_should_not_show_all_fields_to_staff_admin_for_change(self):
|
||||
# given
|
||||
self.app.set_user(self.staff_admin)
|
||||
group = Group.objects.create(name="Dummy group")
|
||||
page = self.app.get(f"/admin/groupmanagement/group/{group.pk}/change/")
|
||||
# when
|
||||
form = page.forms["group_form"]
|
||||
# then
|
||||
for field in self.superuser_exclusive_fields:
|
||||
with self.subTest(field=field):
|
||||
self.assertNotIn(field, form.fields)
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
class TestGroupAdmin2(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.superuser = User.objects.create_superuser("super")
|
||||
|
||||
def test_should_remove_users_from_state_groups(self):
|
||||
# given
|
||||
user_member = AuthUtils.create_user("Bruce Wayne")
|
||||
character_member = AuthUtils.add_main_character_2(
|
||||
user_member,
|
||||
name="Bruce Wayne",
|
||||
character_id=1001,
|
||||
corp_id=2001,
|
||||
corp_name="Wayne Technologies",
|
||||
)
|
||||
user_guest = AuthUtils.create_user("Lex Luthor")
|
||||
AuthUtils.add_main_character_2(
|
||||
user_guest,
|
||||
name="Lex Luthor",
|
||||
character_id=1011,
|
||||
corp_id=2011,
|
||||
corp_name="Luthor Corp",
|
||||
)
|
||||
member_state = AuthUtils.get_member_state()
|
||||
member_state.member_characters.add(character_member)
|
||||
user_member.refresh_from_db()
|
||||
user_guest.refresh_from_db()
|
||||
group = Group.objects.create(name="dummy")
|
||||
user_member.groups.add(group)
|
||||
user_guest.groups.add(group)
|
||||
group.authgroup.states.add(member_state)
|
||||
self.client.force_login(self.superuser)
|
||||
# when
|
||||
response = self.client.post(
|
||||
f"/admin/groupmanagement/group/{group.pk}/change/",
|
||||
data={
|
||||
"name": f"{group.name}",
|
||||
"authgroup-TOTAL_FORMS": "1",
|
||||
"authgroup-INITIAL_FORMS": "1",
|
||||
"authgroup-MIN_NUM_FORMS": "0",
|
||||
"authgroup-MAX_NUM_FORMS": "1",
|
||||
"authgroup-0-description": "",
|
||||
"authgroup-0-states": f"{member_state.pk}",
|
||||
"authgroup-0-internal": "on",
|
||||
"authgroup-0-hidden": "on",
|
||||
"authgroup-0-group": f"{group.pk}",
|
||||
"authgroup-__prefix__-description": "",
|
||||
"authgroup-__prefix__-internal": "on",
|
||||
"authgroup-__prefix__-hidden": "on",
|
||||
"authgroup-__prefix__-group": f"{group.pk}",
|
||||
"_save": "Save"
|
||||
}
|
||||
)
|
||||
# then
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, "/admin/groupmanagement/group/")
|
||||
self.assertIn(group, user_member.groups.all())
|
||||
self.assertNotIn(group, user_guest.groups.all())
|
||||
|
||||
|
||||
class TestReservedGroupNameAdmin(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@@ -232,6 +232,38 @@ class TestAuthGroup(TestCase):
|
||||
expected = 'Superheros'
|
||||
self.assertEqual(str(group.authgroup), expected)
|
||||
|
||||
def test_should_remove_guests_from_group_when_restricted_to_members_only(self):
|
||||
# given
|
||||
user_member = AuthUtils.create_user("Bruce Wayne")
|
||||
character_member = AuthUtils.add_main_character_2(
|
||||
user_member,
|
||||
name="Bruce Wayne",
|
||||
character_id=1001,
|
||||
corp_id=2001,
|
||||
corp_name="Wayne Technologies",
|
||||
)
|
||||
user_guest = AuthUtils.create_user("Lex Luthor")
|
||||
AuthUtils.add_main_character_2(
|
||||
user_guest,
|
||||
name="Lex Luthor",
|
||||
character_id=1011,
|
||||
corp_id=2011,
|
||||
corp_name="Luthor Corp",
|
||||
)
|
||||
member_state = AuthUtils.get_member_state()
|
||||
member_state.member_characters.add(character_member)
|
||||
user_member.refresh_from_db()
|
||||
user_guest.refresh_from_db()
|
||||
group = Group.objects.create(name="dummy")
|
||||
user_member.groups.add(group)
|
||||
user_guest.groups.add(group)
|
||||
group.authgroup.states.add(member_state)
|
||||
# when
|
||||
group.authgroup.remove_users_not_matching_states()
|
||||
# then
|
||||
self.assertIn(group, user_member.groups.all())
|
||||
self.assertNotIn(group, user_guest.groups.all())
|
||||
|
||||
|
||||
class TestAuthGroupRequestApprovers(TestCase):
|
||||
def setUp(self) -> None:
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
from .core import notify # noqa: F401
|
||||
|
||||
default_app_config = 'allianceauth.notifications.apps.NotificationsConfig'
|
||||
|
||||
|
||||
def notify(
|
||||
user: object, title: str, message: str = None, level: str = 'info'
|
||||
) -> None:
|
||||
"""Sends a new notification to user. Convenience function to manager pendant."""
|
||||
from .models import Notification
|
||||
Notification.objects.notify_user(user, title, message, level)
|
||||
|
||||
33
allianceauth/notifications/core.py
Normal file
33
allianceauth/notifications/core.py
Normal file
@@ -0,0 +1,33 @@
|
||||
class NotifyApiWrapper:
|
||||
"""Wrapper to create notify API."""
|
||||
|
||||
def __call__(self, *args, **kwargs): # provide old API for backwards compatibility
|
||||
return self._add_notification(*args, **kwargs)
|
||||
|
||||
def danger(self, user: object, title: str, message: str = None) -> None:
|
||||
"""Add danger notification for user."""
|
||||
self._add_notification(user, title, message, level="danger")
|
||||
|
||||
def info(self, user: object, title: str, message: str = None) -> None:
|
||||
"""Add info notification for user."""
|
||||
self._add_notification(user=user, title=title, message=message, level="info")
|
||||
|
||||
def success(self, user: object, title: str, message: str = None) -> None:
|
||||
"""Add success notification for user."""
|
||||
self._add_notification(user, title, message, level="success")
|
||||
|
||||
def warning(self, user: object, title: str, message: str = None) -> None:
|
||||
"""Add warning notification for user."""
|
||||
self._add_notification(user, title, message, level="warning")
|
||||
|
||||
def _add_notification(
|
||||
self, user: object, title: str, message: str = None, level: str = "info"
|
||||
) -> None:
|
||||
from .models import Notification
|
||||
|
||||
Notification.objects.notify_user(
|
||||
user=user, title=title, message=message, level=level
|
||||
)
|
||||
|
||||
|
||||
notify = NotifyApiWrapper()
|
||||
@@ -5,91 +5,34 @@
|
||||
{% block page_title %}{% translate "Notifications" %}{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="col-lg-12">
|
||||
<h1 class="page-header text-center">{% translate "Notifications" %}</h1>
|
||||
<div class="col-lg-12 container" id="example">
|
||||
<div class="row">
|
||||
<div class="col-lg-12">
|
||||
<div class="panel panel-default">
|
||||
<div class="panel-heading">
|
||||
<ul class="nav nav-pills">
|
||||
<li class="active"><a data-toggle="pill" href="#unread">{% translate "Unread" %}
|
||||
<b>({{ unread|length }})</b></a></li>
|
||||
<li><a data-toggle="pill" href="#read">{% translate "Read" %} <b>({{ read|length }})</b></a>
|
||||
</li>
|
||||
<div class="pull-right">
|
||||
<a href="{% url 'notifications:mark_all_read' %}" class="btn btn-primary">{% translate "Mark All Read" %}</a>
|
||||
<a href="{% url 'notifications:delete_all_read' %}" class="btn btn-danger">{% translate "Delete All Read" %}</a>
|
||||
</div>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="panel-body">
|
||||
<div class="tab-content">
|
||||
<div id="unread" class="tab-pane fade in active">
|
||||
<div class="table-responsive">
|
||||
{% if unread %}
|
||||
<table class="table table-condensed table-hover table-striped">
|
||||
<tr>
|
||||
<th class="text-center">{% translate "Timestamp" %}</th>
|
||||
<th class="text-center">{% translate "Title" %}</th>
|
||||
<th class="text-center">{% translate "Action" %}</th>
|
||||
</tr>
|
||||
{% for notif in unread %}
|
||||
<tr class="{{ notif.level }}">
|
||||
<td class="text-center">{{ notif.timestamp }}</td>
|
||||
<td class="text-center">{{ notif.title }}</td>
|
||||
<td class="text-center">
|
||||
<a href="{% url 'notifications:view' notif.id %}" class="btn btn-success" title="View">
|
||||
<span class="glyphicon glyphicon-eye-open"></span>
|
||||
</a>
|
||||
<a href="{% url 'notifications:remove' notif.id %}" class="btn btn-danger" title="Remove">
|
||||
<span class="glyphicon glyphicon-remove"></span>
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</table>
|
||||
{% else %}
|
||||
<div class="alert alert-warning text-center">{% translate "No unread notifications." %}</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
<div id="read" class="tab-pane fade">
|
||||
<div class="panel-body">
|
||||
<div class="table-responsive">
|
||||
{% if read %}
|
||||
<table class="table table-condensed table-hover table-striped">
|
||||
<tr>
|
||||
<th class="text-center">{% translate "Timestamp" %}</th>
|
||||
<th class="text-center">{% translate "Title" %}</th>
|
||||
<th class="text-center">{% translate "Action" %}</th>
|
||||
</tr>
|
||||
{% for notif in read %}
|
||||
<tr class="{{ notif.level }}">
|
||||
<td class="text-center">{{ notif.timestamp }}</td>
|
||||
<td class="text-center">{{ notif.title }}</td>
|
||||
<td class="text-center">
|
||||
<a href="{% url 'notifications:view' notif.id %}" class="btn btn-success" title="View">
|
||||
<span class="glyphicon glyphicon-eye-open"></span>
|
||||
</a>
|
||||
<a href="{% url 'notifications:remove' notif.id %}" class="btn btn-danger" title="remove">
|
||||
<span class="glyphicon glyphicon-remove"></span>
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</table>
|
||||
{% else %}
|
||||
<div class="alert alert-warning text-center">{% translate "No read notifications." %}</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<h1 class="page-header text-center">{% translate "Notifications" %}</h1>
|
||||
|
||||
<div class="panel panel-default">
|
||||
|
||||
<div class="panel-heading">
|
||||
<ul class="nav nav-pills">
|
||||
<li class="active"><a data-toggle="tab" href="#unread">{% translate "Unread" %}<b>({{ unread|length }})</b></a></li>
|
||||
<li><a data-toggle="tab" href="#read">{% translate "Read" %} <b>({{ read|length }})</b></a></li>
|
||||
<div class="pull-right">
|
||||
<a href="{% url 'notifications:mark_all_read' %}" class="btn btn-warning">{% translate "Mark All Read" %}</a>
|
||||
<a href="{% url 'notifications:delete_all_read' %}" class="btn btn-danger">{% translate "Delete All Read" %}</a>
|
||||
</div>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="panel-body">
|
||||
<div class="tab-content">
|
||||
|
||||
<div id="unread" class="tab-pane fade in active">
|
||||
{% include "notifications/list_partial.html" with notifications=unread %}
|
||||
</div>
|
||||
|
||||
<div id="read" class="tab-pane fade">
|
||||
{% include "notifications/list_partial.html" with notifications=read %}
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{% endblock %}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
{% load i18n %}
|
||||
|
||||
{% if notifications %}
|
||||
<div class="table-responsive">
|
||||
<table class="table table-condensed table-hover table-striped">
|
||||
<tr>
|
||||
<th class="text-center">{% translate "Timestamp" %}</th>
|
||||
<th class="text-center">{% translate "Title" %}</th>
|
||||
<th class="text-center">{% translate "Action" %}</th>
|
||||
</tr>
|
||||
{% for notif in notifications %}
|
||||
<tr class="{{ notif.level }}">
|
||||
<td class="text-center">{{ notif.timestamp }}</td>
|
||||
<td class="text-center">{{ notif.title }}</td>
|
||||
<td class="text-center">
|
||||
<a href="{% url 'notifications:view' notif.id %}" class="btn btn-primary" title="View">
|
||||
<span class="glyphicon glyphicon-eye-open"></span>
|
||||
</a>
|
||||
<a href="{% url 'notifications:remove' notif.id %}" class="btn btn-danger" title="Remove">
|
||||
<span class="glyphicon glyphicon-remove"></span>
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</table>
|
||||
</div>
|
||||
{% else %}
|
||||
<div class="alert alert-default text-center">{% translate "No notifications." %}</div>
|
||||
{% endif %}
|
||||
@@ -5,25 +5,22 @@
|
||||
{% block page_title %}{% translate "View Notification" %}{% endblock page_title %}
|
||||
|
||||
{% block content %}
|
||||
<h1 class="page-header text-center">
|
||||
{% translate "View Notification" %}
|
||||
<div class="text-right">
|
||||
<a href="{% url 'notifications:list' %}" class="btn btn-primary btn-lg">
|
||||
<span class="glyphicon glyphicon-arrow-left"></span>
|
||||
</a>
|
||||
</div>
|
||||
</h1>
|
||||
|
||||
<div class="col-lg-12">
|
||||
<h1 class="page-header text-center">
|
||||
{% translate "View Notification" %}
|
||||
<div class="text-right">
|
||||
<a href="{% url 'notifications:list' %}" class="btn btn-primary btn-lg">
|
||||
<span class="glyphicon glyphicon-arrow-left"></span>
|
||||
</a>
|
||||
</div>
|
||||
</h1>
|
||||
<div class="col-lg-12 container">
|
||||
<div class="row">
|
||||
<div class="col-lg-12">
|
||||
<div class="panel panel-{{ notif.level }}">
|
||||
<div class="panel-heading">{{ notif.timestamp }} {{ notif.title }}</div>
|
||||
<div class="panel-body"><pre>{{ notif.message }}</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col-lg-12">
|
||||
<div class="panel panel-{{ notif.level }}">
|
||||
<div class="panel-heading">{{ notif.timestamp }} {{ notif.title }}</div>
|
||||
<div class="panel-body"><pre>{{ notif.message }}</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{% endblock %}
|
||||
|
||||
85
allianceauth/notifications/tests/test_core.py
Normal file
85
allianceauth/notifications/tests/test_core.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from django.test import TestCase
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from ..core import NotifyApiWrapper
|
||||
from ..models import Notification
|
||||
|
||||
|
||||
class TestUserNotificationCount(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_user("bruce_wayne")
|
||||
|
||||
def test_should_add_danger_notification(self):
|
||||
# given
|
||||
notify = NotifyApiWrapper()
|
||||
# when
|
||||
notify.danger(user=self.user, title="title", message="message")
|
||||
# then
|
||||
obj = Notification.objects.first()
|
||||
self.assertEqual(obj.user, self.user)
|
||||
self.assertEqual(obj.title, "title")
|
||||
self.assertEqual(obj.message, "message")
|
||||
self.assertEqual(obj.level, Notification.Level.DANGER)
|
||||
|
||||
def test_should_add_info_notification(self):
|
||||
# given
|
||||
notify = NotifyApiWrapper()
|
||||
# when
|
||||
notify.info(user=self.user, title="title", message="message")
|
||||
# then
|
||||
obj = Notification.objects.first()
|
||||
self.assertEqual(obj.user, self.user)
|
||||
self.assertEqual(obj.title, "title")
|
||||
self.assertEqual(obj.message, "message")
|
||||
self.assertEqual(obj.level, Notification.Level.INFO)
|
||||
|
||||
def test_should_add_success_notification(self):
|
||||
# given
|
||||
notify = NotifyApiWrapper()
|
||||
# when
|
||||
notify.success(user=self.user, title="title", message="message")
|
||||
# then
|
||||
obj = Notification.objects.first()
|
||||
self.assertEqual(obj.user, self.user)
|
||||
self.assertEqual(obj.title, "title")
|
||||
self.assertEqual(obj.message, "message")
|
||||
self.assertEqual(obj.level, Notification.Level.SUCCESS)
|
||||
|
||||
def test_should_add_warning_notification(self):
|
||||
# given
|
||||
notify = NotifyApiWrapper()
|
||||
# when
|
||||
notify.warning(user=self.user, title="title", message="message")
|
||||
# then
|
||||
obj = Notification.objects.first()
|
||||
self.assertEqual(obj.user, self.user)
|
||||
self.assertEqual(obj.title, "title")
|
||||
self.assertEqual(obj.message, "message")
|
||||
self.assertEqual(obj.level, Notification.Level.WARNING)
|
||||
|
||||
def test_should_add_info_notification_via_callable(self):
|
||||
# given
|
||||
notify = NotifyApiWrapper()
|
||||
# when
|
||||
notify(user=self.user, title="title", message="message")
|
||||
# then
|
||||
obj = Notification.objects.first()
|
||||
self.assertEqual(obj.user, self.user)
|
||||
self.assertEqual(obj.title, "title")
|
||||
self.assertEqual(obj.message, "message")
|
||||
self.assertEqual(obj.level, Notification.Level.INFO)
|
||||
|
||||
def test_should_add_danger_notification_via_callable(self):
|
||||
# given
|
||||
notify = NotifyApiWrapper()
|
||||
# when
|
||||
notify(user=self.user, title="title", message="message", level="danger")
|
||||
# then
|
||||
obj = Notification.objects.first()
|
||||
self.assertEqual(obj.user, self.user)
|
||||
self.assertEqual(obj.title, "title")
|
||||
self.assertEqual(obj.message, "message")
|
||||
self.assertEqual(obj.level, Notification.Level.DANGER)
|
||||
@@ -4,11 +4,8 @@ from allianceauth.tests.auth_utils import AuthUtils
|
||||
from .. import notify
|
||||
from ..models import Notification
|
||||
|
||||
MODULE_PATH = 'allianceauth.notifications'
|
||||
|
||||
|
||||
class TestUserNotificationCount(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.user = AuthUtils.create_user('magic_mike')
|
||||
@@ -23,6 +20,18 @@ class TestUserNotificationCount(TestCase):
|
||||
alliance_name='RIDERS'
|
||||
)
|
||||
|
||||
def test_can_notify(self):
|
||||
notify(self.user, 'dummy')
|
||||
def test_can_notify_short(self):
|
||||
# when
|
||||
notify(self.user, "dummy")
|
||||
# then
|
||||
self.assertEqual(Notification.objects.filter(user=self.user).count(), 1)
|
||||
|
||||
def test_can_notify_full(self):
|
||||
# when
|
||||
notify(user=self.user, title="title", message="message", level="danger")
|
||||
# then
|
||||
obj = Notification.objects.first()
|
||||
self.assertEqual(obj.user, self.user)
|
||||
self.assertEqual(obj.title, "title")
|
||||
self.assertEqual(obj.message, "message")
|
||||
self.assertEqual(obj.level, "danger")
|
||||
|
||||
@@ -73,6 +73,8 @@
|
||||
],
|
||||
bootstrap: true
|
||||
},
|
||||
"stateSave": true,
|
||||
"stateDuration": 0,
|
||||
drawCallback: function ( settings ) {
|
||||
let api = this.api();
|
||||
let rows = api.rows( {page:'current'} ).nodes();
|
||||
|
||||
@@ -106,8 +106,10 @@
|
||||
idx: 1
|
||||
}
|
||||
],
|
||||
bootstrap: true
|
||||
bootstrap: true,
|
||||
},
|
||||
"stateSave": true,
|
||||
"stateDuration": 0,
|
||||
drawCallback: function ( settings ) {
|
||||
let api = this.api();
|
||||
let rows = api.rows( {page:'current'} ).nodes();
|
||||
|
||||
@@ -3,11 +3,11 @@ from django.contrib import admin
|
||||
|
||||
from allianceauth import hooks
|
||||
from allianceauth.authentication.admin import (
|
||||
MainAllianceFilter,
|
||||
MainCorporationsFilter,
|
||||
user_main_organization,
|
||||
user_profile_pic,
|
||||
user_username,
|
||||
user_main_organization,
|
||||
MainCorporationsFilter,
|
||||
MainAllianceFilter
|
||||
)
|
||||
|
||||
from .models import NameFormatConfig
|
||||
@@ -36,19 +36,18 @@ class ServicesUserAdmin(admin.ModelAdmin):
|
||||
MainAllianceFilter,
|
||||
'user__date_joined',
|
||||
)
|
||||
list_select_related = (
|
||||
'user', 'user__profile__main_character', 'user__profile__state'
|
||||
)
|
||||
|
||||
@admin.display(ordering='user__profile__state__name')
|
||||
def _state(self, obj):
|
||||
return obj.user.profile.state.name
|
||||
|
||||
_state.short_description = 'state'
|
||||
_state.admin_order_field = 'user__profile__state__name'
|
||||
|
||||
@admin.display(ordering='user__date_joined')
|
||||
def _date_joined(self, obj):
|
||||
return obj.user.date_joined
|
||||
|
||||
_date_joined.short_description = 'date joined'
|
||||
_date_joined.admin_order_field = 'user__date_joined'
|
||||
|
||||
|
||||
class NameFormatConfigForm(forms.ModelForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -62,6 +61,7 @@ class NameFormatConfigForm(forms.ModelForm):
|
||||
self.fields['service_name'] = forms.ChoiceField(choices=SERVICE_CHOICES)
|
||||
|
||||
|
||||
@admin.register(NameFormatConfig)
|
||||
class NameFormatConfigAdmin(admin.ModelAdmin):
|
||||
form = NameFormatConfigForm
|
||||
list_display = ('service_name', 'get_state_display_string')
|
||||
@@ -69,6 +69,3 @@ class NameFormatConfigAdmin(admin.ModelAdmin):
|
||||
def get_state_display_string(self, obj):
|
||||
return ', '.join([state.name for state in obj.states.all()])
|
||||
get_state_display_string.short_description = 'States'
|
||||
|
||||
|
||||
admin.site.register(NameFormatConfig, NameFormatConfigAdmin)
|
||||
|
||||
@@ -2,12 +2,11 @@ import logging
|
||||
|
||||
from django.contrib import admin
|
||||
|
||||
from . import __title__
|
||||
from ...admin import ServicesUserAdmin
|
||||
from . import __title__
|
||||
from .models import DiscordUser
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
|
||||
@@ -18,21 +17,16 @@ class DiscordUserAdmin(ServicesUserAdmin):
|
||||
list_filter = ServicesUserAdmin.list_filter + ('activated',)
|
||||
ordering = ('-activated',)
|
||||
|
||||
def _uid(self, obj):
|
||||
return obj.uid
|
||||
|
||||
_uid.short_description = 'Discord ID (UID)'
|
||||
_uid.admin_order_field = 'uid'
|
||||
|
||||
def _username(self, obj):
|
||||
if obj.username and obj.discriminator:
|
||||
return f'{obj.username}#{obj.discriminator}'
|
||||
else:
|
||||
return ''
|
||||
|
||||
def delete_queryset(self, request, queryset):
|
||||
for user in queryset:
|
||||
user.delete_user()
|
||||
|
||||
_username.short_description = 'Discord Username'
|
||||
_username.admin_order_field = 'username'
|
||||
@admin.display(description='Discord ID (UID)', ordering='uid')
|
||||
def _uid(self, obj):
|
||||
return obj.uid
|
||||
|
||||
@admin.display(description='Discord Username', ordering='username')
|
||||
def _username(self, obj):
|
||||
if obj.username and obj.discriminator:
|
||||
return f'{obj.username}#{obj.discriminator}'
|
||||
return ''
|
||||
|
||||
37
allianceauth/services/modules/discord/api.py
Normal file
37
allianceauth/services/modules/discord/api.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Public interface for community apps who want to interact with the Discord server
|
||||
of the current Alliance Auth instance.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
Here is an example for using the api to fetch the current roles from the configured Discord server.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from allianceauth.services.modules.discord.api import create_bot_client, discord_guild_id
|
||||
|
||||
client = create_bot_client() # create a new Discord client
|
||||
guild_id = discord_guild_id() # get the ID of the configured Discord server
|
||||
roles = client.guild_roles(guild_id) # fetch the roles from our Discord server
|
||||
|
||||
.. seealso::
|
||||
The docs for the client class can be found here: :py:class:`~allianceauth.services.modules.discord.discord_client.client.DiscordClient`
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .app_settings import DISCORD_GUILD_ID
|
||||
from .core import create_bot_client, group_to_role, server_name # noqa
|
||||
from .discord_client.models import Role # noqa
|
||||
from .models import DiscordUser # noqa
|
||||
|
||||
__all__ = ["create_bot_client", "group_to_role", "server_name", "DiscordUser", "Role"]
|
||||
|
||||
|
||||
def discord_guild_id() -> Optional[int]:
|
||||
"""Guild ID of configured Discord server.
|
||||
|
||||
Returns:
|
||||
Guild ID or ``None`` if not configured
|
||||
"""
|
||||
return int(DISCORD_GUILD_ID) if DISCORD_GUILD_ID else None
|
||||
@@ -2,16 +2,25 @@ from .utils import clean_setting
|
||||
|
||||
|
||||
DISCORD_APP_ID = clean_setting('DISCORD_APP_ID', '')
|
||||
"""App ID for the AA bot on Discord. Needs to be set."""
|
||||
|
||||
DISCORD_APP_SECRET = clean_setting('DISCORD_APP_SECRET', '')
|
||||
"""App secret for the AA bot on Discord. Needs to be set."""
|
||||
|
||||
DISCORD_BOT_TOKEN = clean_setting('DISCORD_BOT_TOKEN', '')
|
||||
"""Token used by the AA bot on Discord. Needs to be set."""
|
||||
|
||||
DISCORD_CALLBACK_URL = clean_setting('DISCORD_CALLBACK_URL', '')
|
||||
"""Callback URL for OAuth with Discord. Needs to be set."""
|
||||
|
||||
DISCORD_GUILD_ID = clean_setting('DISCORD_GUILD_ID', '')
|
||||
"""ID of the Discord Server. Needs to be set."""
|
||||
|
||||
# max retries of tasks after an error occurred
|
||||
DISCORD_TASKS_MAX_RETRIES = clean_setting('DISCORD_TASKS_MAX_RETRIES', 3)
|
||||
"""Max retries of tasks after an error occurred."""
|
||||
|
||||
# Pause in seconds until next retry for tasks after the API returned an error
|
||||
DISCORD_TASKS_RETRY_PAUSE = clean_setting('DISCORD_TASKS_RETRY_PAUSE', 60)
|
||||
"""Pause in seconds until next retry for tasks after the API returned an error."""
|
||||
|
||||
# automatically sync Discord users names to user's main character name when created
|
||||
DISCORD_SYNC_NAMES = clean_setting('DISCORD_SYNC_NAMES', False)
|
||||
"""Automatically sync Discord users names to user's main character name when created."""
|
||||
|
||||
@@ -6,6 +6,7 @@ from django.template.loader import render_to_string
|
||||
from allianceauth import hooks
|
||||
from allianceauth.services.hooks import ServicesHook
|
||||
|
||||
from .core import server_name, user_formatted_nick
|
||||
from .models import DiscordUser
|
||||
from .urls import urlpatterns
|
||||
from .utils import LoggerAddTag
|
||||
@@ -53,7 +54,7 @@ class DiscordService(ServicesHook):
|
||||
return render_to_string(
|
||||
self.service_ctrl_template,
|
||||
{
|
||||
'server_name': DiscordUser.objects.server_name(),
|
||||
'server_name': server_name(),
|
||||
'user_has_account': user_has_account,
|
||||
'discord_username': discord_username
|
||||
},
|
||||
@@ -73,7 +74,7 @@ class DiscordService(ServicesHook):
|
||||
'user_pk': user.pk,
|
||||
# since the new nickname is not yet in the DB we need to
|
||||
# provide it manually to the task
|
||||
'nickname': DiscordUser.objects.user_formatted_nick(user)
|
||||
'nickname': user_formatted_nick(user)
|
||||
},
|
||||
priority=SINGLE_TASK_PRIORITY
|
||||
)
|
||||
|
||||
129
allianceauth/services/modules/discord/core.py
Normal file
129
allianceauth/services/modules/discord/core.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Core functionality of the Discord service not directly related to models."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.contrib.auth.models import Group, User
|
||||
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
from allianceauth.services.hooks import NameFormatter
|
||||
|
||||
from . import __title__
|
||||
from .app_settings import DISCORD_BOT_TOKEN, DISCORD_GUILD_ID
|
||||
from .discord_client import DiscordClient, RolesSet, Role
|
||||
from .discord_client.exceptions import DiscordClientException
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
|
||||
def create_bot_client(is_rate_limited: bool = True) -> DiscordClient:
|
||||
"""Create new bot client for accessing the configured Discord server.
|
||||
|
||||
Args:
|
||||
is_rate_limited: Set to False to turn off rate limiting (use with care).
|
||||
|
||||
Return:
|
||||
Discord client instance
|
||||
"""
|
||||
return DiscordClient(DISCORD_BOT_TOKEN, is_rate_limited=is_rate_limited)
|
||||
|
||||
|
||||
def calculate_roles_for_user(
|
||||
user: User,
|
||||
client: DiscordClient,
|
||||
discord_uid: int,
|
||||
state_name: str = None,
|
||||
) -> Tuple[RolesSet, Optional[bool]]:
|
||||
"""Calculate current Discord roles for an Auth user.
|
||||
|
||||
Takes into account reserved groups and existing managed roles (e.g. nitro).
|
||||
|
||||
Returns:
|
||||
- Discord roles, changed flag:
|
||||
- True when roles have changed,
|
||||
- False when they have not changed,
|
||||
- None if user is not a member of the guild
|
||||
"""
|
||||
roles_calculated = client.match_or_create_roles_from_names_2(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
role_names=_user_group_names(user=user, state_name=state_name),
|
||||
)
|
||||
logger.debug("Calculated roles for user %s: %s", user, roles_calculated.ids())
|
||||
roles_current = client.guild_member_roles(
|
||||
guild_id=DISCORD_GUILD_ID, user_id=discord_uid
|
||||
)
|
||||
if roles_current is None:
|
||||
logger.debug("User %s is not a member of the guild.", user)
|
||||
return roles_calculated, None
|
||||
logger.debug("Current roles user %s: %s", user, roles_current.ids())
|
||||
reserved_role_names = ReservedGroupName.objects.values_list("name", flat=True)
|
||||
roles_reserved = roles_current.subset(role_names=reserved_role_names)
|
||||
roles_managed = roles_current.subset(managed_only=True)
|
||||
roles_persistent = roles_managed.union(roles_reserved)
|
||||
if roles_calculated == roles_current.difference(roles_persistent):
|
||||
return roles_calculated, False
|
||||
return roles_calculated.union(roles_persistent), True
|
||||
|
||||
|
||||
def _user_group_names(user: User, state_name: str = None) -> List[str]:
|
||||
"""Names of groups and state the given user is a member of."""
|
||||
if not state_name:
|
||||
state_name = user.profile.state.name
|
||||
group_names = [group.name for group in user.groups.all()] + [state_name]
|
||||
logger.debug("Group names for roles updates of user %s are: %s", user, group_names)
|
||||
return group_names
|
||||
|
||||
|
||||
def user_formatted_nick(user: User) -> Optional[str]:
|
||||
"""Name of the given user's main character with name formatting applied.
|
||||
|
||||
Returns:
|
||||
Name or ``None`` if user has no main.
|
||||
"""
|
||||
from .auth_hooks import DiscordService
|
||||
|
||||
if user.profile.main_character:
|
||||
return NameFormatter(DiscordService(), user).format_name()
|
||||
return None
|
||||
|
||||
|
||||
def group_to_role(group: Group) -> Optional[Role]:
|
||||
"""Fetch the Discord role matching the given Django group by name.
|
||||
|
||||
Returns:
|
||||
Discord role or None if no matching role exist
|
||||
"""
|
||||
return default_bot_client.match_role_from_name(
|
||||
guild_id=DISCORD_GUILD_ID, role_name=group.name
|
||||
)
|
||||
|
||||
|
||||
def server_name(use_cache: bool = True) -> str:
|
||||
"""Fetches the name of the current Discord server.
|
||||
|
||||
Args:
|
||||
use_cache: When set False will force an API call to get the server name
|
||||
|
||||
Returns:
|
||||
Server name or an empty string if the name could not be retrieved
|
||||
"""
|
||||
try:
|
||||
server_name = default_bot_client.guild_name(
|
||||
guild_id=DISCORD_GUILD_ID, use_cache=use_cache
|
||||
)
|
||||
except (HTTPError, DiscordClientException):
|
||||
server_name = ""
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Unexpected error when trying to retrieve the server name from Discord",
|
||||
exc_info=True,
|
||||
)
|
||||
server_name = ""
|
||||
return server_name
|
||||
|
||||
|
||||
# Default bot client to be used by modules of this package
|
||||
default_bot_client = create_bot_client()
|
||||
@@ -1,3 +1,10 @@
|
||||
from .client import DiscordClient # noqa
|
||||
from .exceptions import DiscordApiBackoff # noqa
|
||||
from .helpers import DiscordRoles # noqa
|
||||
from .app_settings import DISCORD_OAUTH_BASE_URL, DISCORD_OAUTH_TOKEN_URL # noqa
|
||||
from .client import DiscordClient # noqa
|
||||
from .exceptions import ( # noqa
|
||||
DiscordApiBackoff,
|
||||
DiscordClientException,
|
||||
DiscordRateLimitExhausted,
|
||||
DiscordTooManyRequestsError,
|
||||
)
|
||||
from .helpers import RolesSet # noqa
|
||||
from .models import Guild, GuildMember, Role, User # noqa
|
||||
|
||||
@@ -1,45 +1,56 @@
|
||||
"""Settings for the Discord client.
|
||||
|
||||
To overwrite a default set the variable in your local Django settings, e.g:
|
||||
|
||||
.. code:: python
|
||||
|
||||
DISCORD_GUILD_NAME_CACHE_MAX_AGE = 7200
|
||||
"""
|
||||
|
||||
from ..utils import clean_setting
|
||||
|
||||
|
||||
# Base URL for all API calls. Must end with /.
|
||||
DISCORD_API_BASE_URL = clean_setting(
|
||||
'DISCORD_API_BASE_URL', 'https://discord.com/api/'
|
||||
)
|
||||
"""Base URL for all API calls. Must end with /."""
|
||||
|
||||
# Low level connecttimeout for requests to the Discord API in seconds
|
||||
DISCORD_API_TIMEOUT_CONNECT = clean_setting(
|
||||
'DISCORD_API_TIMEOUT', 5
|
||||
)
|
||||
"""Low level connect timeout for requests to the Discord API in seconds."""
|
||||
|
||||
# Low level read timeout for requests to the Discord API in seconds
|
||||
DISCORD_API_TIMEOUT_READ = clean_setting(
|
||||
'DISCORD_API_TIMEOUT', 30
|
||||
)
|
||||
"""Low level read timeout for requests to the Discord API in seconds."""
|
||||
|
||||
# Base authorization URL for Discord Oauth
|
||||
DISCORD_OAUTH_BASE_URL = clean_setting(
|
||||
'DISCORD_OAUTH_BASE_URL', 'https://discord.com/api/oauth2/authorize'
|
||||
)
|
||||
"""Base authorization URL for Discord Oauth."""
|
||||
|
||||
# Base authorization URL for Discord Oauth
|
||||
DISCORD_OAUTH_TOKEN_URL = clean_setting(
|
||||
'DISCORD_OAUTH_TOKEN_URL', 'https://discord.com/api/oauth2/token'
|
||||
)
|
||||
"""Base authorization URL for Discord Oauth."""
|
||||
|
||||
# How long the Discord guild names retrieved from the server are
|
||||
# caches locally in seconds.
|
||||
DISCORD_GUILD_NAME_CACHE_MAX_AGE = clean_setting(
|
||||
'DISCORD_GUILD_NAME_CACHE_MAX_AGE', 3600 * 24
|
||||
)
|
||||
"""How long the Discord guild names retrieved from the server
|
||||
are caches locally in seconds.
|
||||
"""
|
||||
|
||||
# How long Discord roles retrieved from the server are caches locally in seconds.
|
||||
DISCORD_ROLES_CACHE_MAX_AGE = clean_setting(
|
||||
'DISCORD_ROLES_CACHE_MAX_AGE', 3600 * 1
|
||||
)
|
||||
"""How long Discord roles retrieved from the server are caches locally in seconds."""
|
||||
|
||||
# Turns off creation of new roles. In case the rate limit for creating roles is
|
||||
# exhausted, this setting allows the Discord service to continue to function
|
||||
# and wait out the reset. Rate limit is about 250 per 48 hrs.
|
||||
DISCORD_DISABLE_ROLE_CREATION = clean_setting(
|
||||
'DISCORD_DISABLE_ROLE_CREATION', False
|
||||
)
|
||||
"""Turns off creation of new roles. In case the rate limit for creating roles is
|
||||
exhausted, this setting allows the Discord service to continue to function
|
||||
and wait out the reset. Rate limit is about 250 per 48 hrs.
|
||||
"""
|
||||
|
||||
@@ -1,32 +1,37 @@
|
||||
from hashlib import md5
|
||||
"""Client for interacting with the Discord API."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import IntEnum
|
||||
from hashlib import md5
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from urllib.parse import urljoin
|
||||
from uuid import uuid1
|
||||
|
||||
from redis import Redis
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
from redis import Redis
|
||||
|
||||
from django.core.cache import caches
|
||||
from allianceauth.utils.cache import get_redis_client
|
||||
|
||||
from allianceauth import __title__ as AUTH_TITLE, __url__, __version__
|
||||
from allianceauth import __title__ as AUTH_TITLE
|
||||
from allianceauth import __url__, __version__
|
||||
|
||||
from .. import __title__
|
||||
from ..utils import LoggerAddTag
|
||||
from .app_settings import (
|
||||
DISCORD_API_BASE_URL,
|
||||
DISCORD_API_TIMEOUT_CONNECT,
|
||||
DISCORD_API_TIMEOUT_READ,
|
||||
DISCORD_DISABLE_ROLE_CREATION,
|
||||
DISCORD_GUILD_NAME_CACHE_MAX_AGE,
|
||||
DISCORD_OAUTH_BASE_URL,
|
||||
DISCORD_OAUTH_TOKEN_URL,
|
||||
DISCORD_ROLES_CACHE_MAX_AGE,
|
||||
)
|
||||
from .exceptions import DiscordRateLimitExhausted, DiscordTooManyRequestsError
|
||||
from .helpers import DiscordRoles
|
||||
from ..utils import LoggerAddTag
|
||||
|
||||
from .helpers import RolesSet
|
||||
from .models import Guild, GuildMember, Role, User
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
@@ -58,8 +63,13 @@ MINIMUM_BLOCKING_WAIT = 50
|
||||
RATE_LIMIT_RETRIES = 1000
|
||||
|
||||
|
||||
class DiscordApiStatusCode(IntEnum):
|
||||
"""Status code returned from the Discord API."""
|
||||
UNKNOWN_MEMBER = 10007 #:
|
||||
|
||||
|
||||
class DiscordClient:
|
||||
"""This class provides a web client for interacting with the Discord API
|
||||
"""This class provides a web client for interacting with the Discord API.
|
||||
|
||||
The client has rate limiting that supports concurrency.
|
||||
This means it is able to ensure the API rate limit is not violated,
|
||||
@@ -67,24 +77,30 @@ class DiscordClient:
|
||||
|
||||
In addition the client support proper API backoff.
|
||||
|
||||
Synchronization of rate limit infos accross multiple processes
|
||||
Synchronization of rate limit infos across multiple processes
|
||||
is implemented with Redis and thus requires Redis as Django cache backend.
|
||||
|
||||
All durations are in milliseconds.
|
||||
"""
|
||||
OAUTH_BASE_URL = DISCORD_OAUTH_BASE_URL
|
||||
OAUTH_TOKEN_URL = DISCORD_OAUTH_TOKEN_URL
|
||||
The cache is shared across all clients and processes (also using Redis).
|
||||
|
||||
All durations are in milliseconds.
|
||||
|
||||
Most errors from the API will raise a requests.HTTPError.
|
||||
|
||||
Args:
|
||||
access_token: Discord access token used to authenticate all calls to the API
|
||||
redis: Redis instance to be used.
|
||||
is_rate_limited: Set to False to turn off rate limiting (use with care).
|
||||
If not specified will try to use the Redis instance
|
||||
from the default Django cache backend.
|
||||
|
||||
Raises:
|
||||
ValueError: No access token provided
|
||||
"""
|
||||
_KEY_GLOBAL_BACKOFF_UNTIL = 'DISCORD_GLOBAL_BACKOFF_UNTIL'
|
||||
_KEY_GLOBAL_RATE_LIMIT_REMAINING = 'DISCORD_GLOBAL_RATE_LIMIT_REMAINING'
|
||||
_KEYPREFIX_GUILD_NAME = 'DISCORD_GUILD_NAME'
|
||||
_KEYPREFIX_GUILD_ROLES = 'DISCORD_GUILD_ROLES'
|
||||
_KEYPREFIX_ROLE_NAME = 'DISCORD_ROLE_NAME'
|
||||
_NICK_MAX_CHARS = 32
|
||||
|
||||
_HTTP_STATUS_CODE_NOT_FOUND = 404
|
||||
_HTTP_STATUS_CODE_RATE_LIMITED = 429
|
||||
_DISCORD_STATUS_CODE_UNKNOWN_MEMBER = 10007
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -92,19 +108,12 @@ class DiscordClient:
|
||||
redis: Redis = None,
|
||||
is_rate_limited: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Params:
|
||||
- access_token: Discord access token used to authenticate all calls to the API
|
||||
- redis: Redis instance to be used.
|
||||
- is_rate_limited: Set to False to run of rate limiting (use with care)
|
||||
If not specified will try to use the Redis instance
|
||||
from the default Django cache backend.
|
||||
"""
|
||||
if not access_token:
|
||||
raise ValueError('You must provide an access token.')
|
||||
self._access_token = str(access_token)
|
||||
self._is_rate_limited = bool(is_rate_limited)
|
||||
if not redis:
|
||||
default_cache = caches['default']
|
||||
self._redis = default_cache.get_master_client()
|
||||
self._redis = get_redis_client()
|
||||
if not isinstance(self._redis, Redis):
|
||||
raise RuntimeError(
|
||||
'This class requires a Redis client, but none was provided '
|
||||
@@ -132,19 +141,20 @@ class DiscordClient:
|
||||
self.__redis_script_set_longer = self._redis.register_script(lua_2)
|
||||
|
||||
@property
|
||||
def access_token(self):
|
||||
def access_token(self) -> str:
|
||||
"""Discord access token."""
|
||||
return self._access_token
|
||||
|
||||
@property
|
||||
def is_rate_limited(self):
|
||||
def is_rate_limited(self) -> bool:
|
||||
"""Wether this instance is rate limited."""
|
||||
return self._is_rate_limited
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}(access_token=...{self.access_token[-5:]})'
|
||||
|
||||
def _redis_decr_or_set(self, name: str, value: str, px: int) -> bool:
|
||||
"""decreases the key value if it exists and returns the result
|
||||
else sets the key
|
||||
"""Decrease the key value if it exists and returns the result else set the key.
|
||||
|
||||
Implemented as Lua script to ensure atomicity.
|
||||
"""
|
||||
@@ -153,7 +163,7 @@ class DiscordClient:
|
||||
)
|
||||
|
||||
def _redis_set_if_longer(self, name: str, value: str, px: int) -> bool:
|
||||
"""like set, but only goes through if either key doesn't exist
|
||||
"""Like set, but only goes through if either key doesn't exist
|
||||
or px would be extended.
|
||||
|
||||
Implemented as Lua script to ensure atomicity.
|
||||
@@ -164,111 +174,134 @@ class DiscordClient:
|
||||
|
||||
# users
|
||||
|
||||
def current_user(self) -> dict:
|
||||
"""returns the user belonging to the current access_token"""
|
||||
def current_user(self) -> User:
|
||||
"""Fetch user belonging to the current access_token."""
|
||||
authorization = f'Bearer {self.access_token}'
|
||||
r = self._api_request(
|
||||
method='get', route='users/@me', authorization=authorization
|
||||
)
|
||||
return r.json()
|
||||
return User.from_dict(r.json())
|
||||
|
||||
# guild
|
||||
|
||||
def guild_infos(self, guild_id: int) -> dict:
|
||||
"""Returns all basic infos about this guild"""
|
||||
def guild_infos(self, guild_id: int) -> Guild:
|
||||
"""Fetch all basic infos about this guild.
|
||||
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
"""
|
||||
route = f"guilds/{guild_id}"
|
||||
r = self._api_request(method='get', route=route)
|
||||
return r.json()
|
||||
return Guild.from_dict(r.json())
|
||||
|
||||
def guild_name(self, guild_id: int, use_cache: bool = True) -> str:
|
||||
"""returns the name of this guild (cached)
|
||||
or an empty string if something went wrong
|
||||
"""Fetch the name of this guild (cached).
|
||||
|
||||
Params:
|
||||
- guild_id: ID of current guild
|
||||
- use_cache: When set to False will force an API call to get the server name
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
use_cache: When set to False will force an API call to get the server name
|
||||
|
||||
Returns:
|
||||
Name of the server or an empty string if something went wrong.
|
||||
"""
|
||||
key_name = self._guild_name_cache_key(guild_id)
|
||||
if use_cache:
|
||||
guild_name = self._redis_decode(self._redis.get(key_name))
|
||||
else:
|
||||
guild_name = None
|
||||
guild_name = ""
|
||||
if not guild_name:
|
||||
guild_infos = self.guild_infos(guild_id)
|
||||
if 'name' in guild_infos:
|
||||
guild_name = guild_infos['name']
|
||||
self._redis.set(
|
||||
name=key_name,
|
||||
value=guild_name,
|
||||
ex=DISCORD_GUILD_NAME_CACHE_MAX_AGE
|
||||
)
|
||||
try:
|
||||
guild = self.guild_infos(guild_id)
|
||||
except HTTPError:
|
||||
guild_name = ""
|
||||
else:
|
||||
guild_name = ''
|
||||
|
||||
guild_name = guild.name
|
||||
self._redis.set(
|
||||
name=key_name, value=guild_name, ex=DISCORD_GUILD_NAME_CACHE_MAX_AGE
|
||||
)
|
||||
return guild_name
|
||||
|
||||
@classmethod
|
||||
def _guild_name_cache_key(cls, guild_id: int) -> str:
|
||||
"""Returns key for accessing role given by name in the role cache"""
|
||||
"""Construct key for accessing role given by name in the role cache.
|
||||
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
"""
|
||||
gen_key = DiscordClient._generate_hash(f'{guild_id}')
|
||||
return f'{cls._KEYPREFIX_GUILD_NAME}__{gen_key}'
|
||||
|
||||
# guild roles
|
||||
|
||||
def guild_roles(self, guild_id: int, use_cache: bool = True) -> list:
|
||||
"""Returns the list of all roles for this guild
|
||||
def guild_roles(self, guild_id: int, use_cache: bool = True) -> Set[Role]:
|
||||
"""Fetch all roles for this guild.
|
||||
|
||||
If use_cache is set to False it will always hit the API to retrieve
|
||||
fresh data and update the cache
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
use_cache: If is set to False it will always hit the API to retrieve
|
||||
fresh data and update the cache.
|
||||
|
||||
Returns:
|
||||
"""
|
||||
cache_key = self._guild_roles_cache_key(guild_id)
|
||||
roles = None
|
||||
if use_cache:
|
||||
roles_raw = self._redis.get(name=cache_key)
|
||||
if roles_raw:
|
||||
logger.debug('Returning roles for guild %s from cache', guild_id)
|
||||
return json.loads(self._redis_decode(roles_raw))
|
||||
else:
|
||||
logger.debug('No roles for guild %s in cache', guild_id)
|
||||
|
||||
route = f"guilds/{guild_id}/roles"
|
||||
r = self._api_request(method='get', route=route)
|
||||
roles = r.json()
|
||||
if roles and isinstance(roles, list):
|
||||
roles = json.loads(self._redis_decode(roles_raw))
|
||||
logger.debug('No roles for guild %s in cache', guild_id)
|
||||
if roles is None:
|
||||
route = f"guilds/{guild_id}/roles"
|
||||
r = self._api_request(method='get', route=route)
|
||||
roles = r.json()
|
||||
if not roles or not isinstance(roles, list):
|
||||
raise RuntimeError(
|
||||
f"Unexpected response when fetching roles from API: {roles}"
|
||||
)
|
||||
self._redis.set(
|
||||
name=cache_key,
|
||||
value=json.dumps(roles),
|
||||
ex=DISCORD_ROLES_CACHE_MAX_AGE
|
||||
)
|
||||
return roles
|
||||
return {Role.from_dict(role) for role in roles}
|
||||
|
||||
def create_guild_role(self, guild_id: int, role_name: str, **kwargs) -> dict:
|
||||
def create_guild_role(
|
||||
self, guild_id: int, role_name: str, **kwargs
|
||||
) -> Optional[Role]:
|
||||
"""Create a new guild role with the given name.
|
||||
|
||||
See official documentation for additional optional parameters.
|
||||
|
||||
Note that Discord allows the creation of multiple roles with the same name,
|
||||
so to avoid duplicates it's important to check existing roles
|
||||
before creating new one
|
||||
|
||||
returns a new role dict on success
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
role_name: Name of new role to create
|
||||
|
||||
Returns:
|
||||
new role on success
|
||||
"""
|
||||
route = f"guilds/{guild_id}/roles"
|
||||
data = {'name': DiscordRoles.sanitize_role_name(role_name)}
|
||||
data = {'name': Role.sanitize_name(role_name)}
|
||||
data.update(kwargs)
|
||||
r = self._api_request(method='post', route=route, data=data)
|
||||
role = r.json()
|
||||
if role:
|
||||
self._invalidate_guild_roles_cache(guild_id)
|
||||
return role
|
||||
return Role.from_dict(role)
|
||||
return None
|
||||
|
||||
def delete_guild_role(self, guild_id: int, role_id: int) -> bool:
|
||||
"""Deletes a guild role"""
|
||||
"""Delete a guild role."""
|
||||
route = f"guilds/{guild_id}/roles/{role_id}"
|
||||
r = self._api_request(method='delete', route=route)
|
||||
if r.status_code == 204:
|
||||
self._invalidate_guild_roles_cache(guild_id)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def _invalidate_guild_roles_cache(self, guild_id: int) -> None:
|
||||
cache_key = self._guild_roles_cache_key(guild_id)
|
||||
@@ -277,67 +310,79 @@ class DiscordClient:
|
||||
|
||||
@classmethod
|
||||
def _guild_roles_cache_key(cls, guild_id: int) -> str:
|
||||
"""Returns key for accessing cached roles for a guild"""
|
||||
"""Construct key for accessing cached roles for a guild.
|
||||
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
"""
|
||||
gen_key = cls._generate_hash(f'{guild_id}')
|
||||
return f'{cls._KEYPREFIX_GUILD_ROLES}__{gen_key}'
|
||||
|
||||
def match_role_from_name(self, guild_id: int, role_name: str) -> dict:
|
||||
"""returns Discord role matching the given name or an empty dict"""
|
||||
guild_roles = DiscordRoles(self.guild_roles(guild_id))
|
||||
def match_role_from_name(self, guild_id: int, role_name: str) -> Optional[Role]:
|
||||
"""Fetch Discord role matching the given name (cached).
|
||||
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
role_name: Name of role
|
||||
|
||||
Returns:
|
||||
Matching role or None if no match is found
|
||||
"""
|
||||
guild_roles = RolesSet(self.guild_roles(guild_id))
|
||||
return guild_roles.role_by_name(role_name)
|
||||
|
||||
def match_or_create_roles_from_names(self, guild_id: int, role_names: list) -> list:
|
||||
"""returns Discord roles matching the given names
|
||||
|
||||
Returns as list of tuple of role and created flag
|
||||
def match_or_create_roles_from_names(
|
||||
self, guild_id: int, role_names: Iterable[str]
|
||||
) -> List[Tuple[Role, bool]]:
|
||||
"""Fetch or create Discord roles matching the given names (cached).
|
||||
|
||||
Will try to match with existing roles names
|
||||
Non-existing roles will be created, then created flag will be True
|
||||
|
||||
Params:
|
||||
- guild_id: ID of guild
|
||||
- role_names: list of name strings each defining a role
|
||||
Args:
|
||||
guild_id: ID of guild
|
||||
role_names: list of name strings each defining a role
|
||||
|
||||
Returns:
|
||||
List of tuple of Role and created flag
|
||||
"""
|
||||
roles = list()
|
||||
guild_roles = DiscordRoles(self.guild_roles(guild_id))
|
||||
role_names_cleaned = {
|
||||
DiscordRoles.sanitize_role_name(name) for name in role_names
|
||||
}
|
||||
guild_roles = RolesSet(self.guild_roles(guild_id))
|
||||
role_names_cleaned = {Role.sanitize_name(name) for name in role_names}
|
||||
for role_name in role_names_cleaned:
|
||||
role, created = self.match_or_create_role_from_name(
|
||||
guild_id=guild_id,
|
||||
role_name=DiscordRoles.sanitize_role_name(role_name),
|
||||
guild_roles=guild_roles
|
||||
guild_id=guild_id, role_name=role_name, guild_roles=guild_roles
|
||||
)
|
||||
if role:
|
||||
roles.append((role, created))
|
||||
if created:
|
||||
guild_roles = guild_roles.union(DiscordRoles([role]))
|
||||
guild_roles = guild_roles.union(RolesSet([role]))
|
||||
return roles
|
||||
|
||||
def match_or_create_role_from_name(
|
||||
self, guild_id: int, role_name: str, guild_roles: DiscordRoles = None
|
||||
) -> tuple:
|
||||
"""returns Discord role matching the given name
|
||||
|
||||
Returns as tuple of role and created flag
|
||||
self, guild_id: int, role_name: str, guild_roles: RolesSet = None
|
||||
) -> Tuple[Role, bool]:
|
||||
"""Fetch or create Discord role matching the given name.
|
||||
|
||||
Will try to match with existing roles names
|
||||
Non-existing roles will be created, then created flag will be True
|
||||
|
||||
Params:
|
||||
- guild_id: ID of guild
|
||||
- role_name: strings defining name of a role
|
||||
- guild_roles: All known guild roles as DiscordRoles object.
|
||||
Helps to void redundant lookups of guild roles
|
||||
when this method is used multiple times.
|
||||
Args:
|
||||
guild_id: ID of guild
|
||||
role_name: strings defining name of a role
|
||||
guild_roles: All known guild roles as RolesSet object.
|
||||
Helps to void redundant lookups of guild roles
|
||||
when this method is used multiple times.
|
||||
|
||||
Returns:
|
||||
Tuple of Role and created flag
|
||||
"""
|
||||
if not isinstance(role_name, str):
|
||||
raise TypeError('role_name must be of type string')
|
||||
|
||||
created = False
|
||||
if guild_roles is None:
|
||||
guild_roles = DiscordRoles(self.guild_roles(guild_id))
|
||||
guild_roles = RolesSet(self.guild_roles(guild_id))
|
||||
role = guild_roles.role_by_name(role_name)
|
||||
if not role:
|
||||
if not DISCORD_DISABLE_ROLE_CREATION:
|
||||
@@ -346,9 +391,24 @@ class DiscordClient:
|
||||
created = True
|
||||
else:
|
||||
role = None
|
||||
|
||||
return role, created
|
||||
|
||||
def match_or_create_roles_from_names_2(
|
||||
self, guild_id: int, role_names: Iterable[str]
|
||||
) -> RolesSet:
|
||||
"""Fetch or create Discord role matching the given name.
|
||||
|
||||
Wrapper for ``match_or_create_role_from_name()``
|
||||
|
||||
Returns:
|
||||
Roles as RolesSet object.
|
||||
"""
|
||||
return RolesSet.create_from_matched_roles(
|
||||
self.match_or_create_roles_from_names(
|
||||
guild_id=guild_id, role_names=role_names
|
||||
)
|
||||
)
|
||||
|
||||
# guild members
|
||||
|
||||
def add_guild_member(
|
||||
@@ -358,13 +418,13 @@ class DiscordClient:
|
||||
access_token: str,
|
||||
role_ids: list = None,
|
||||
nick: str = None
|
||||
) -> bool:
|
||||
"""Adds a user to the guilds.
|
||||
) -> Optional[bool]:
|
||||
"""Adds a user to the guild.
|
||||
|
||||
Returns:
|
||||
- True when a new user was added
|
||||
- None if the user already existed
|
||||
- False when something went wrong or raises exception
|
||||
- True when a new user was added
|
||||
- None if the user already existed
|
||||
- False when something went wrong or raises exception
|
||||
"""
|
||||
route = f"guilds/{guild_id}/members/{user_id}"
|
||||
data = {
|
||||
@@ -372,42 +432,49 @@ class DiscordClient:
|
||||
}
|
||||
if role_ids:
|
||||
data['roles'] = self._sanitize_role_ids(role_ids)
|
||||
|
||||
if nick:
|
||||
data['nick'] = str(nick)[:self._NICK_MAX_CHARS]
|
||||
|
||||
data['nick'] = GuildMember.sanitize_nick(nick)
|
||||
r = self._api_request(method='put', route=route, data=data)
|
||||
r.raise_for_status()
|
||||
if r.status_code == 201:
|
||||
return True
|
||||
elif r.status_code == 204:
|
||||
return None
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def guild_member(self, guild_id: int, user_id: int) -> dict:
|
||||
"""returns the user info for a guild member
|
||||
def guild_member(self, guild_id: int, user_id: int) -> Optional[GuildMember]:
|
||||
"""Fetch info for a guild member.
|
||||
|
||||
or None if the user is not a member of the guild
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
user_id: Discord ID of the user
|
||||
|
||||
Returns:
|
||||
guild member or ``None`` if the user is not a member of the guild
|
||||
"""
|
||||
route = f'guilds/{guild_id}/members/{user_id}'
|
||||
r = self._api_request(method='get', route=route, raise_for_status=False)
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning("Discord user ID %s could not be found on server.", user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
r.raise_for_status()
|
||||
return GuildMember.from_dict(r.json())
|
||||
|
||||
def modify_guild_member(
|
||||
self, guild_id: int, user_id: int, role_ids: list = None, nick: str = None
|
||||
) -> bool:
|
||||
"""Modify attributes of a guild member.
|
||||
self, guild_id: int, user_id: int, role_ids: List[int] = None, nick: str = None
|
||||
) -> Optional[bool]:
|
||||
"""Set properties of a guild member.
|
||||
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
user_id: Discord ID of the user
|
||||
roles_id: New list of role IDs (if provided)
|
||||
nick: New nickname (if provided)
|
||||
|
||||
Returns
|
||||
- True when successful
|
||||
- None if user is not a member of this guild
|
||||
- False otherwise
|
||||
- True when successful
|
||||
- None if user is not a member of this guild
|
||||
- False otherwise
|
||||
"""
|
||||
if not role_ids and not nick:
|
||||
raise ValueError('Must specify role_ids or nick')
|
||||
@@ -420,7 +487,7 @@ class DiscordClient:
|
||||
data['roles'] = self._sanitize_role_ids(role_ids)
|
||||
|
||||
if nick:
|
||||
data['nick'] = self._sanitize_nick(nick)
|
||||
data['nick'] = GuildMember.sanitize_nick(nick)
|
||||
|
||||
route = f"guilds/{guild_id}/members/{user_id}"
|
||||
r = self._api_request(
|
||||
@@ -429,21 +496,22 @@ class DiscordClient:
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
r.raise_for_status()
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def remove_guild_member(self, guild_id: int, user_id: int) -> bool:
|
||||
"""Remove a member from a guild
|
||||
def remove_guild_member(self, guild_id: int, user_id: int) -> Optional[bool]:
|
||||
"""Remove a member from a guild.
|
||||
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
user_id: Discord ID of the user
|
||||
|
||||
Returns:
|
||||
- True when successful
|
||||
- None if member does not exist
|
||||
- False otherwise
|
||||
- True when successful
|
||||
- None if member does not exist
|
||||
- False otherwise
|
||||
"""
|
||||
route = f"guilds/{guild_id}/members/{user_id}"
|
||||
r = self._api_request(
|
||||
@@ -452,19 +520,16 @@ class DiscordClient:
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
r.raise_for_status()
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
# Guild member roles
|
||||
|
||||
def add_guild_member_role(
|
||||
self, guild_id: int, user_id: int, role_id: int
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""Adds a role to a guild member
|
||||
|
||||
Returns:
|
||||
@@ -477,43 +542,69 @@ class DiscordClient:
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
r.raise_for_status()
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def remove_guild_member_role(
|
||||
self, guild_id: int, user_id: int, role_id: int
|
||||
) -> bool:
|
||||
"""Removes a role to a guild member
|
||||
) -> Optional[bool]:
|
||||
"""Remove a role to a guild member
|
||||
|
||||
Args:
|
||||
guild_id: Discord ID of the guild
|
||||
user_id: Discord ID of the user
|
||||
role_id: Discord ID of role to be removed
|
||||
|
||||
Returns:
|
||||
- True when successful
|
||||
- None if member does not exist
|
||||
- False otherwise
|
||||
- True when successful
|
||||
- None if member does not exist
|
||||
- False otherwise
|
||||
"""
|
||||
route = f"guilds/{guild_id}/members/{user_id}/roles/{role_id}"
|
||||
r = self._api_request(method='delete', route=route, raise_for_status=False)
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
r.raise_for_status()
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def guild_member_roles(self, guild_id: int, user_id: int) -> Optional[RolesSet]:
|
||||
"""Fetch the current guild roles of a guild member.
|
||||
|
||||
Args:
|
||||
- guild_id: Discord guild ID
|
||||
- user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
- Member roles
|
||||
- None if user is not a member of the guild
|
||||
"""
|
||||
member_info = self.guild_member(guild_id=guild_id, user_id=user_id)
|
||||
if member_info is None:
|
||||
return None # User is no longer a member
|
||||
guild_roles = RolesSet(self.guild_roles(guild_id=guild_id))
|
||||
logger.debug('Current guild roles: %s', guild_roles.ids())
|
||||
if not guild_roles.has_roles(member_info.roles):
|
||||
guild_roles = RolesSet(
|
||||
self.guild_roles(guild_id=guild_id, use_cache=False)
|
||||
)
|
||||
if not guild_roles.has_roles(member_info.roles):
|
||||
role_ids = set(member_info.roles).difference(guild_roles.ids())
|
||||
raise RuntimeError(
|
||||
f'Discord user {user_id} has unknown roles: {role_ids}'
|
||||
)
|
||||
return guild_roles.subset(member_info.roles)
|
||||
|
||||
@classmethod
|
||||
def _is_member_unknown_error(cls, r: requests.Response) -> bool:
|
||||
try:
|
||||
result = (
|
||||
r.status_code == cls._HTTP_STATUS_CODE_NOT_FOUND
|
||||
and r.json()['code'] == cls._DISCORD_STATUS_CODE_UNKNOWN_MEMBER
|
||||
r.status_code == HTTPStatus.NOT_FOUND
|
||||
and r.json()['code'] == DiscordApiStatusCode.UNKNOWN_MEMBER
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
result = False
|
||||
@@ -530,7 +621,19 @@ class DiscordClient:
|
||||
authorization: str = None,
|
||||
raise_for_status: bool = True
|
||||
) -> requests.Response:
|
||||
"""Core method for performing all API calls"""
|
||||
"""Core method for performing all API calls.
|
||||
|
||||
Args:
|
||||
method: HTTP method of the request, e.g. "get"
|
||||
route: Route in the Discord API, e.g. "users/@me"
|
||||
data: Data to be send with the request
|
||||
authorization: The authorization string to be used.
|
||||
Will use the default bot token if not set.
|
||||
raise_for_status: Whether a requests exception is to be raised when not ok
|
||||
|
||||
Returns:
|
||||
The raw response from the API
|
||||
"""
|
||||
uid = uuid1().hex
|
||||
|
||||
if not hasattr(requests, method):
|
||||
@@ -578,7 +681,7 @@ class DiscordClient:
|
||||
r.text
|
||||
)
|
||||
|
||||
if r.status_code == self._HTTP_STATUS_CODE_RATE_LIMITED:
|
||||
if r.status_code == HTTPStatus.TOO_MANY_REQUESTS:
|
||||
self._handle_new_api_backoff(r, uid)
|
||||
|
||||
self._report_rate_limit_from_api(r, uid)
|
||||
@@ -589,9 +692,10 @@ class DiscordClient:
|
||||
return r
|
||||
|
||||
def _handle_ongoing_api_backoff(self, uid: str) -> None:
|
||||
"""checks if api is currently on backoff
|
||||
if on backoff: will do a blocking wait if it expires soon,
|
||||
else raises exception
|
||||
"""Check if api is currently on backoff.
|
||||
|
||||
If on backoff: will do a blocking wait if it expires soon,
|
||||
else raises exception.
|
||||
"""
|
||||
global_backoff_duration = self._redis.pttl(self._KEY_GLOBAL_BACKOFF_UNTIL)
|
||||
if global_backoff_duration > 0:
|
||||
@@ -611,8 +715,9 @@ class DiscordClient:
|
||||
raise DiscordTooManyRequestsError(retry_after=global_backoff_duration)
|
||||
|
||||
def _ensure_rate_limed_not_exhausted(self, uid: str) -> int:
|
||||
"""ensures that the rate limit is not exhausted
|
||||
if exhausted: will do a blocking wait if rate limit resets soon,
|
||||
"""Ensures that the rate limit is not exhausted.
|
||||
|
||||
If exhausted: will do a blocking wait if rate limit resets soon,
|
||||
else raises exception
|
||||
|
||||
returns requests remaining on success
|
||||
@@ -655,10 +760,10 @@ class DiscordClient:
|
||||
)
|
||||
raise DiscordRateLimitExhausted(resets_in)
|
||||
|
||||
raise RuntimeError('Failed to handle rate limit after after too tries.')
|
||||
raise RuntimeError('Failed to handle rate limit after after too many tries.')
|
||||
|
||||
def _handle_new_api_backoff(self, r: requests.Response, uid: str) -> None:
|
||||
"""raises exception for new API backoff error"""
|
||||
"""Raise exception for new API backoff error."""
|
||||
response = r.json()
|
||||
if 'retry_after' in response:
|
||||
try:
|
||||
@@ -680,8 +785,8 @@ class DiscordClient:
|
||||
)
|
||||
raise DiscordTooManyRequestsError(retry_after=retry_after)
|
||||
|
||||
def _report_rate_limit_from_api(self, r, uid):
|
||||
"""Tries to log the current rate limit reported from API"""
|
||||
def _report_rate_limit_from_api(self, r, uid) -> None:
|
||||
"""Try to log the current rate limit reported from API."""
|
||||
if (
|
||||
logger.getEffectiveLevel() <= logging.DEBUG
|
||||
and 'x-ratelimit-limit' in r.headers
|
||||
@@ -704,22 +809,17 @@ class DiscordClient:
|
||||
|
||||
@staticmethod
|
||||
def _redis_decode(value: str) -> str:
|
||||
"""Decodes a string from Redis and passes through None and Booleans"""
|
||||
"""Decode a string from Redis and passes through None and Booleans."""
|
||||
if value is not None and not isinstance(value, bool):
|
||||
return value.decode('utf-8')
|
||||
else:
|
||||
return value
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _generate_hash(key: str) -> str:
|
||||
"""Generate hash key for given string."""
|
||||
return md5(key.encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_role_ids(role_ids: list) -> list:
|
||||
"""make sure its a list of integers"""
|
||||
return [int(role_id) for role_id in list(role_ids)]
|
||||
|
||||
@classmethod
|
||||
def _sanitize_nick(cls, nick: str) -> str:
|
||||
"""shortens too long strings if necessary"""
|
||||
return str(nick)[:cls._NICK_MAX_CHARS]
|
||||
def _sanitize_role_ids(role_ids: Iterable[int]) -> List[int]:
|
||||
"""Sanitize a list of role IDs, i.e. make sure its a list of unique integers."""
|
||||
return [int(role_id) for role_id in set(role_ids)]
|
||||
|
||||
@@ -1,23 +1,26 @@
|
||||
"""Custom exceptions for the Discord Client package."""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
class DiscordClientException(Exception):
|
||||
"""Base Exception for the Discord client"""
|
||||
"""Base Exception for the Discord client."""
|
||||
|
||||
|
||||
class DiscordApiBackoff(DiscordClientException):
|
||||
"""Exception signaling we need to backoff from sending requests to the API for now
|
||||
"""Exception signaling we need to backoff from sending requests to the API for now.
|
||||
|
||||
Args:
|
||||
retry_after: time to retry after in milliseconds
|
||||
"""
|
||||
|
||||
def __init__(self, retry_after: int):
|
||||
"""
|
||||
:param retry_after: int time to retry after in milliseconds
|
||||
"""
|
||||
super().__init__()
|
||||
self.retry_after = int(retry_after)
|
||||
|
||||
@property
|
||||
def retry_after_seconds(self):
|
||||
"""Time to retry after in seconds."""
|
||||
return math.ceil(self.retry_after / 1000)
|
||||
|
||||
|
||||
|
||||
@@ -1,27 +1,37 @@
|
||||
from copy import copy
|
||||
from typing import Set, Iterable
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from .models import Role
|
||||
|
||||
|
||||
class DiscordRoles:
|
||||
"""Container class that helps dealing with Discord roles.
|
||||
class RolesSet:
|
||||
"""Container of Discord roles with added functionality.
|
||||
|
||||
Objects of this class are immutable and work in many ways like sets.
|
||||
|
||||
Ideally objects are initialized from raw API responses,
|
||||
e.g. from DiscordClient.guild.roles()
|
||||
"""
|
||||
_ROLE_NAME_MAX_CHARS = 100
|
||||
e.g. from DiscordClient.guild.roles().
|
||||
|
||||
def __init__(self, roles_lst: list) -> None:
|
||||
"""roles_lst must be a list of dict, each defining a role"""
|
||||
Args:
|
||||
roles_lst: List of dicts, each defining a role
|
||||
"""
|
||||
def __init__(self, roles_lst: Iterable[Role]) -> None:
|
||||
if not isinstance(roles_lst, (list, set, tuple)):
|
||||
raise TypeError('roles_lst must be of type list, set or tuple')
|
||||
self._roles = dict()
|
||||
self._roles_by_name = dict()
|
||||
for role in list(roles_lst):
|
||||
self._assert_valid_role(role)
|
||||
self._roles[int(role['id'])] = role
|
||||
self._roles_by_name[self.sanitize_role_name(role['name'])] = role
|
||||
if not isinstance(role, Role):
|
||||
raise TypeError('Roles must be of type Role: %s' % role)
|
||||
self._roles[role.id] = role
|
||||
self._roles_by_name[role.name] = role
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self._roles_by_name:
|
||||
roles = '"' + '", "'.join(sorted(list(self._roles_by_name.keys()))) + '"'
|
||||
else:
|
||||
roles = ""
|
||||
return f'{self.__class__.__name__}([{roles}])'
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, type(self)):
|
||||
@@ -41,15 +51,15 @@ class DiscordRoles:
|
||||
return len(self._roles.keys())
|
||||
|
||||
def has_roles(self, role_ids: Set[int]) -> bool:
|
||||
"""returns true if this objects contains all roles defined by given role_ids
|
||||
incl. managed roles
|
||||
"""True if this objects contains all roles defined by given role_ids
|
||||
incl. managed roles.
|
||||
"""
|
||||
role_ids = {int(id) for id in role_ids}
|
||||
all_role_ids = self._roles.keys()
|
||||
return role_ids.issubset(all_role_ids)
|
||||
|
||||
def ids(self) -> Set[int]:
|
||||
"""return a set of all role IDs"""
|
||||
"""Set of all role IDs."""
|
||||
return set(self._roles.keys())
|
||||
|
||||
def subset(
|
||||
@@ -57,13 +67,13 @@ class DiscordRoles:
|
||||
role_ids: Iterable[int] = None,
|
||||
managed_only: bool = False,
|
||||
role_names: Iterable[str] = None
|
||||
) -> "DiscordRoles":
|
||||
"""returns a new object containing the subset of roles
|
||||
) -> "RolesSet":
|
||||
"""Create instance containing the subset of roles
|
||||
|
||||
Args:
|
||||
- role_ids: role ids must be in the provided list
|
||||
- managed_only: roles must be managed
|
||||
- role_names: role names must match provided list (not case sensitive)
|
||||
role_ids: role ids must be in the provided list
|
||||
managed_only: roles must be managed
|
||||
role_names: role names must match provided list (not case sensitive)
|
||||
"""
|
||||
if role_ids is not None:
|
||||
role_ids = {int(id) for id in role_ids}
|
||||
@@ -75,72 +85,50 @@ class DiscordRoles:
|
||||
|
||||
elif role_ids is None and managed_only:
|
||||
return type(self)([
|
||||
role for _, role in self._roles.items() if role['managed']
|
||||
role for _, role in self._roles.items() if role.managed
|
||||
])
|
||||
|
||||
elif role_ids is not None and managed_only:
|
||||
return type(self)([
|
||||
role for role_id, role in self._roles.items()
|
||||
if role_id in role_ids and role['managed']
|
||||
if role_id in role_ids and role.managed
|
||||
])
|
||||
|
||||
elif role_ids is None and managed_only is False and role_names is not None:
|
||||
role_names = {self.sanitize_role_name(name).lower() for name in role_names}
|
||||
role_names = {Role.sanitize_name(name).lower() for name in role_names}
|
||||
return type(self)([
|
||||
role for role in self._roles.values()
|
||||
if role["name"].lower() in role_names
|
||||
if role.name.lower() in role_names
|
||||
])
|
||||
|
||||
return copy(self)
|
||||
|
||||
def union(self, other: object) -> "DiscordRoles":
|
||||
"""returns a new roles object that is the union of this roles object
|
||||
with other"""
|
||||
def union(self, other: object) -> "RolesSet":
|
||||
"""Create instance that is the union of this roles object with other."""
|
||||
return type(self)(list(self) + list(other))
|
||||
|
||||
def difference(self, other: object) -> "DiscordRoles":
|
||||
"""returns a new roles object that only contains the roles
|
||||
that exist in the current objects, but not in other
|
||||
def difference(self, other: object) -> "RolesSet":
|
||||
"""Create instance that only contains the roles
|
||||
that exist in the current objects, but not in other.
|
||||
"""
|
||||
new_ids = self.ids().difference(other.ids())
|
||||
return self.subset(role_ids=new_ids)
|
||||
|
||||
def role_by_name(self, role_name: str) -> dict:
|
||||
"""returns role if one with matching name is found else an empty dict"""
|
||||
role_name = self.sanitize_role_name(role_name)
|
||||
def role_by_name(self, role_name: str) -> Optional[Role]:
|
||||
"""Role if one with matching name is found else None."""
|
||||
role_name = Role.sanitize_name(role_name)
|
||||
if role_name in self._roles_by_name:
|
||||
return self._roles_by_name[role_name]
|
||||
return dict()
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def create_from_matched_roles(cls, matched_roles: list) -> "DiscordRoles":
|
||||
"""returns a new object created from the given list of matches roles
|
||||
def create_from_matched_roles(
|
||||
cls, matched_roles: List[Tuple[Role, bool]]
|
||||
) -> "RolesSet":
|
||||
"""Create new instance from the given list of matches roles.
|
||||
|
||||
matches_roles must be a list of tuples in the form: (role, created)
|
||||
Args:
|
||||
matches_roles: list of matches roles
|
||||
"""
|
||||
raw_roles = [x[0] for x in matched_roles]
|
||||
return cls(raw_roles)
|
||||
|
||||
@staticmethod
|
||||
def _assert_valid_role(role: dict) -> None:
|
||||
if not isinstance(role, dict):
|
||||
raise TypeError('Roles must be of type dict: %s' % role)
|
||||
|
||||
if 'id' not in role or 'name' not in role or 'managed' not in role:
|
||||
raise ValueError('This role is not valid: %s' % role)
|
||||
|
||||
@classmethod
|
||||
def sanitize_role_name(cls, role_name: str) -> str:
|
||||
"""shortens too long strings if necessary"""
|
||||
return str(role_name)[:cls._ROLE_NAME_MAX_CHARS]
|
||||
|
||||
|
||||
def match_or_create_roles_from_names(
|
||||
client: object, guild_id: int, role_names: list
|
||||
) -> DiscordRoles:
|
||||
"""Shortcut for getting the result of matching role names as DiscordRoles object"""
|
||||
return DiscordRoles.create_from_matched_roles(
|
||||
client.match_or_create_roles_from_names(
|
||||
guild_id=guild_id, role_names=role_names
|
||||
)
|
||||
)
|
||||
|
||||
125
allianceauth/services/modules/discord/discord_client/models.py
Normal file
125
allianceauth/services/modules/discord/discord_client/models.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Implementation of Discord objects used by this client.
|
||||
|
||||
Note that only those objects and properties are implemented, which are needed by AA.
|
||||
|
||||
Names and types are mirrored from the API whenever possible.
|
||||
Discord's snowflake type (used by Discord IDs) is implemented as int.
|
||||
"""
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import FrozenSet
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class User:
|
||||
"""A user on Discord."""
|
||||
|
||||
id: int
|
||||
username: str
|
||||
discriminator: str
|
||||
|
||||
def __post_init__(self):
|
||||
object.__setattr__(self, "id", int(self.id))
|
||||
object.__setattr__(self, "username", str(self.username))
|
||||
object.__setattr__(self, "discriminator", str(self.discriminator))
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "User":
|
||||
"""Create object from dictionary as received from the API."""
|
||||
return cls(
|
||||
id=int(data["id"]),
|
||||
username=data["username"],
|
||||
discriminator=data["discriminator"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Role:
|
||||
"""A role on Discord."""
|
||||
|
||||
_ROLE_NAME_MAX_CHARS = 100
|
||||
|
||||
id: int
|
||||
name: str
|
||||
managed: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
object.__setattr__(self, "id", int(self.id))
|
||||
object.__setattr__(self, "name", self.sanitize_name(self.name))
|
||||
object.__setattr__(self, "managed", bool(self.managed))
|
||||
|
||||
def asdict(self) -> dict:
|
||||
"""Convert object into a dictionary representation."""
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "Role":
|
||||
"""Create object from dictionary as received from the API."""
|
||||
return cls(id=int(data["id"]), name=data["name"], managed=data["managed"])
|
||||
|
||||
@classmethod
|
||||
def sanitize_name(cls, role_name: str) -> str:
|
||||
"""Shorten too long names if necessary."""
|
||||
return str(role_name)[: cls._ROLE_NAME_MAX_CHARS]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Guild:
|
||||
"""A guild on Discord."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
roles: FrozenSet[Role]
|
||||
|
||||
def __post_init__(self):
|
||||
object.__setattr__(self, "id", int(self.id))
|
||||
object.__setattr__(self, "name", str(self.name))
|
||||
for role in self.roles:
|
||||
if not isinstance(role, Role):
|
||||
raise TypeError("roles can only contain Role objects.")
|
||||
object.__setattr__(self, "roles", frozenset(self.roles))
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "Guild":
|
||||
"""Create object from dictionary as received from the API."""
|
||||
return cls(
|
||||
id=int(data["id"]),
|
||||
name=data["name"],
|
||||
roles=frozenset(Role.from_dict(obj) for obj in data["roles"]),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GuildMember:
|
||||
"""A member of a guild on Discord."""
|
||||
|
||||
_NICK_MAX_CHARS = 32
|
||||
|
||||
roles: FrozenSet[int]
|
||||
nick: str = None
|
||||
user: User = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.nick:
|
||||
object.__setattr__(self, "nick", self.sanitize_nick(self.nick))
|
||||
if self.user and not isinstance(self.user, User):
|
||||
raise TypeError("user must be of type User")
|
||||
for role in self.roles:
|
||||
if not isinstance(role, int):
|
||||
raise TypeError("roles can only contain ints")
|
||||
object.__setattr__(self, "roles", frozenset(self.roles))
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "GuildMember":
|
||||
"""Create object from dictionary as received from the API."""
|
||||
params = {"roles": {int(obj) for obj in data["roles"]}}
|
||||
if data.get("user"):
|
||||
params["user"] = User.from_dict(data["user"])
|
||||
if data.get("nick"):
|
||||
params["nick"] = data["nick"]
|
||||
return cls(**params)
|
||||
|
||||
@classmethod
|
||||
def sanitize_nick(cls, nick: str) -> str:
|
||||
"""Sanitize a nick, i.e. shorten too long strings if necessary."""
|
||||
return str(nick)[: cls._NICK_MAX_CHARS]
|
||||
@@ -1,40 +0,0 @@
|
||||
TEST_GUILD_ID = 123456789012345678
|
||||
TEST_USER_ID = 198765432012345678
|
||||
TEST_USER_NAME = 'Peter Parker'
|
||||
TEST_USER_DISCRIMINATOR = '1234'
|
||||
TEST_BOT_TOKEN = 'abcdefhijlkmnopqastzvwxyz1234567890ABCDEFGHOJKLMNOPQRSTUVWXY'
|
||||
TEST_ROLE_ID = 654321012345678912
|
||||
|
||||
|
||||
def create_role(id: int, name: str, managed=False) -> dict:
|
||||
return {
|
||||
'id': int(id),
|
||||
'name': str(name),
|
||||
'managed': bool(managed)
|
||||
}
|
||||
|
||||
|
||||
def create_matched_role(role, created=False) -> tuple:
|
||||
return role, created
|
||||
|
||||
|
||||
ROLE_ALPHA = create_role(1, 'alpha')
|
||||
ROLE_BRAVO = create_role(2, 'bravo')
|
||||
ROLE_CHARLIE = create_role(3, 'charlie')
|
||||
ROLE_CHARLIE_2 = create_role(4, 'Charlie') # Discord roles are case sensitive
|
||||
ROLE_MIKE = create_role(13, 'mike', True)
|
||||
|
||||
|
||||
ALL_ROLES = [ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE]
|
||||
|
||||
|
||||
def create_user_info(
|
||||
id: int = TEST_USER_ID,
|
||||
username: str = TEST_USER_NAME,
|
||||
discriminator: str = TEST_USER_DISCRIMINATOR
|
||||
):
|
||||
return {
|
||||
'id': str(id),
|
||||
'username': str(username[:32]),
|
||||
'discriminator': str(discriminator[:4])
|
||||
}
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
{
|
||||
"guilds": {
|
||||
"2909267986263572999": {
|
||||
"id": "2909267986263572999",
|
||||
"name": "Mason's Test Server",
|
||||
"icon": "389030ec9db118cb5b85a732333b7c98",
|
||||
"description": null,
|
||||
"splash": "75610b05a0dd09ec2c3c7df9f6975ea0",
|
||||
"discovery_splash": null,
|
||||
"approximate_member_count": 2,
|
||||
"approximate_presence_count": 2,
|
||||
"features": [
|
||||
"INVITE_SPLASH",
|
||||
"VANITY_URL",
|
||||
"COMMERCE",
|
||||
"BANNER",
|
||||
"NEWS",
|
||||
"VERIFIED",
|
||||
"VIP_REGIONS"
|
||||
],
|
||||
"emojis": [
|
||||
{
|
||||
"name": "ultrafastparrot",
|
||||
"roles": [],
|
||||
"id": "393564762228785161",
|
||||
"require_colons": true,
|
||||
"managed": false,
|
||||
"animated": true,
|
||||
"available": true
|
||||
}
|
||||
],
|
||||
"banner": "5c3cb8d1bc159937fffe7e641ec96ca7",
|
||||
"owner_id": "53908232506183680",
|
||||
"application_id": null,
|
||||
"region": null,
|
||||
"afk_channel_id": null,
|
||||
"afk_timeout": 300,
|
||||
"system_channel_id": null,
|
||||
"widget_enabled": true,
|
||||
"widget_channel_id": "639513352485470208",
|
||||
"verification_level": 0,
|
||||
"roles": [
|
||||
{
|
||||
"id": "2909267986263572999",
|
||||
"name": "@everyone",
|
||||
"permissions": "49794752",
|
||||
"position": 0,
|
||||
"color": 0,
|
||||
"hoist": false,
|
||||
"managed": false,
|
||||
"mentionable": false
|
||||
}
|
||||
],
|
||||
"default_message_notifications": 1,
|
||||
"mfa_level": 0,
|
||||
"explicit_content_filter": 0,
|
||||
"max_presences": null,
|
||||
"max_members": 250000,
|
||||
"max_video_channel_users": 25,
|
||||
"vanity_url_code": "no",
|
||||
"premium_tier": 0,
|
||||
"premium_subscription_count": 0,
|
||||
"system_channel_flags": 0,
|
||||
"preferred_locale": "en-US",
|
||||
"rules_channel_id": null,
|
||||
"public_updates_channel_id": null
|
||||
}
|
||||
},
|
||||
"guildMembers": {
|
||||
"1": {
|
||||
"user": {},
|
||||
"nick": null,
|
||||
"avatar": null,
|
||||
"roles": [],
|
||||
"joined_at": "2015-04-26T06:26:56.936000+00:00",
|
||||
"deaf": false,
|
||||
"mute": false
|
||||
},
|
||||
"2": {
|
||||
"user": {
|
||||
"id": "80351110224678912",
|
||||
"username": "Nelly",
|
||||
"discriminator": "1337",
|
||||
"avatar": "8342729096ea3675442027381ff50dfe",
|
||||
"verified": true,
|
||||
"email": "nelly@discord.com",
|
||||
"flags": 64,
|
||||
"banner": "06c16474723fe537c283b8efa61a30c8",
|
||||
"accent_color": 16711680,
|
||||
"premium_type": 1,
|
||||
"public_flags": 64
|
||||
},
|
||||
"nick": "Nelly the great",
|
||||
"avatar": null,
|
||||
"roles": [
|
||||
"197150972374548480",
|
||||
"41771983423143936"
|
||||
],
|
||||
"joined_at": "2015-04-26T06:26:56.936000+00:00",
|
||||
"deaf": false,
|
||||
"mute": false
|
||||
}
|
||||
},
|
||||
"roles": {
|
||||
"197150972374548480": {
|
||||
"id": "197150972374548480",
|
||||
"name": "My Managed Role",
|
||||
"color": 3447003,
|
||||
"hoist": false,
|
||||
"icon": "cf3ced8600b777c9486c6d8d84fb4327",
|
||||
"unicode_emoji": null,
|
||||
"position": 2,
|
||||
"permissions": "66321471",
|
||||
"managed": true,
|
||||
"mentionable": false
|
||||
},
|
||||
"2909267986263572999": {
|
||||
"id": "2909267986263572999",
|
||||
"name": "@everyone",
|
||||
"permissions": "49794752",
|
||||
"position": 0,
|
||||
"color": 0,
|
||||
"hoist": false,
|
||||
"managed": false,
|
||||
"mentionable": false
|
||||
},
|
||||
"41771983423143936": {
|
||||
"id": "41771983423143936",
|
||||
"name": "WE DEM BOYZZ!!!!!!",
|
||||
"color": 3447003,
|
||||
"hoist": true,
|
||||
"icon": "cf3ced8600b777c9486c6d8d84fb4327",
|
||||
"unicode_emoji": null,
|
||||
"position": 1,
|
||||
"permissions": "66321471",
|
||||
"managed": false,
|
||||
"mentionable": false
|
||||
}
|
||||
},
|
||||
"users": {
|
||||
"80351110224678912": {
|
||||
"id": "80351110224678912",
|
||||
"username": "Nelly",
|
||||
"discriminator": "1337",
|
||||
"avatar": "8342729096ea3675442027381ff50dfe",
|
||||
"verified": true,
|
||||
"email": "nelly@discord.com",
|
||||
"flags": 64,
|
||||
"banner": "06c16474723fe537c283b8efa61a30c8",
|
||||
"accent_color": 16711680,
|
||||
"premium_type": 1,
|
||||
"public_flags": 64
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
from itertools import count
|
||||
|
||||
from django.utils.timezone import now
|
||||
|
||||
from ..client import DiscordApiStatusCode
|
||||
from ..models import Guild, GuildMember, Role, User
|
||||
|
||||
TEST_GUILD_ID = 123456789012345678
|
||||
TEST_GUILD_NAME = "Test Guild"
|
||||
TEST_USER_ID = 198765432012345678
|
||||
TEST_USER_NAME = "Peter Parker"
|
||||
TEST_USER_DISCRIMINATOR = "1234"
|
||||
TEST_BOT_TOKEN = "abcdefhijlkmnopqastzvwxyz1234567890ABCDEFGHOJKLMNOPQRSTUVWXY"
|
||||
TEST_ROLE_ID = 654321012345678912
|
||||
|
||||
|
||||
def create_discord_role_object(id: int, name: str, managed: bool = False) -> dict:
|
||||
return {"id": str(int(id)), "name": str(name), "managed": bool(managed)}
|
||||
|
||||
|
||||
def create_matched_role(role, created=False) -> tuple:
|
||||
return role, created
|
||||
|
||||
|
||||
def create_discord_user_object(**kwargs):
|
||||
params = {
|
||||
"id": TEST_USER_ID,
|
||||
"username": TEST_USER_NAME,
|
||||
"discriminator": TEST_USER_DISCRIMINATOR,
|
||||
}
|
||||
params.update(kwargs)
|
||||
params["id"] = str(int(params["id"]))
|
||||
return params
|
||||
|
||||
|
||||
def create_discord_guild_member_object(user=None, **kwargs):
|
||||
user_params = {}
|
||||
if user:
|
||||
user_params["user"] = user
|
||||
params = {
|
||||
"user": create_discord_user_object(**user_params),
|
||||
"roles": [],
|
||||
"joined_at": now().isoformat(),
|
||||
"deaf": False,
|
||||
"mute": False,
|
||||
}
|
||||
params.update(kwargs)
|
||||
params["roles"] = [str(int(obj)) for obj in params["roles"]]
|
||||
return params
|
||||
|
||||
|
||||
def create_discord_error_response(code: int) -> dict:
|
||||
return {"code": int(code)}
|
||||
|
||||
|
||||
def create_discord_error_response_unknown_member() -> dict:
|
||||
return create_discord_error_response(DiscordApiStatusCode.UNKNOWN_MEMBER.value)
|
||||
|
||||
|
||||
def create_discord_guild_object(**kwargs):
|
||||
params = {"id": TEST_GUILD_ID, "name": TEST_GUILD_NAME, "roles": []}
|
||||
params.update(kwargs)
|
||||
params["id"] = str(int(params["id"]))
|
||||
return params
|
||||
|
||||
|
||||
def create_user(**kwargs):
|
||||
params = {
|
||||
"id": TEST_USER_ID,
|
||||
"username": TEST_USER_NAME,
|
||||
"discriminator": TEST_USER_DISCRIMINATOR,
|
||||
}
|
||||
params.update(kwargs)
|
||||
return User(**params)
|
||||
|
||||
|
||||
def create_guild(**kwargs):
|
||||
params = {"id": TEST_GUILD_ID, "name": TEST_GUILD_NAME, "roles": []}
|
||||
params.update(kwargs)
|
||||
return Guild(**params)
|
||||
|
||||
|
||||
def create_guild_member(**kwargs):
|
||||
params = {"user": create_user(), "roles": []}
|
||||
params.update(kwargs)
|
||||
return GuildMember(**params)
|
||||
|
||||
|
||||
def create_role(**kwargs) -> dict:
|
||||
params = {"managed": False}
|
||||
params.update(kwargs)
|
||||
if "id" not in params:
|
||||
params["id"] = next_number("role")
|
||||
if "name" not in params:
|
||||
params["name"] = f"Test Role #{params['id']}"
|
||||
return Role(**params)
|
||||
|
||||
|
||||
def next_number(key: str = None) -> int:
|
||||
"""Calculate the next number in a persistent sequence."""
|
||||
if key is None:
|
||||
key = "_general"
|
||||
try:
|
||||
return next_number._counter[key].__next__()
|
||||
except AttributeError:
|
||||
next_number._counter = dict()
|
||||
except KeyError:
|
||||
pass
|
||||
next_number._counter[key] = count(start=1)
|
||||
return next_number._counter[key].__next__()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,177 +1,201 @@
|
||||
from unittest import TestCase
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from . import (
|
||||
ROLE_ALPHA,
|
||||
ROLE_BRAVO,
|
||||
ROLE_CHARLIE,
|
||||
ROLE_CHARLIE_2,
|
||||
ROLE_MIKE,
|
||||
ALL_ROLES,
|
||||
create_role
|
||||
)
|
||||
from .. import DiscordRoles
|
||||
from ..helpers import RolesSet
|
||||
from .factories import create_matched_role, create_role
|
||||
|
||||
|
||||
MODULE_PATH = 'allianceauth.services.modules.discord.discord_client.client'
|
||||
|
||||
|
||||
class TestDiscordRoles(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.all_roles = DiscordRoles(ALL_ROLES)
|
||||
|
||||
class TestRolesSet(NoSocketsTestCase):
|
||||
def test_can_create_simple(self):
|
||||
roles_raw = [ROLE_ALPHA]
|
||||
roles = DiscordRoles(roles_raw)
|
||||
# given
|
||||
roles_raw = [create_role()]
|
||||
# when
|
||||
roles = RolesSet(roles_raw)
|
||||
# then
|
||||
self.assertListEqual(list(roles), roles_raw)
|
||||
|
||||
def test_can_create_empty(self):
|
||||
roles_raw = []
|
||||
roles = DiscordRoles(roles_raw)
|
||||
# when
|
||||
roles = RolesSet([])
|
||||
# then
|
||||
self.assertListEqual(list(roles), [])
|
||||
|
||||
def test_raises_exception_if_roles_raw_of_wrong_type(self):
|
||||
with self.assertRaises(TypeError):
|
||||
DiscordRoles({'id': 1})
|
||||
RolesSet({"id": 1})
|
||||
|
||||
def test_raises_exception_if_list_contains_non_dict(self):
|
||||
roles_raw = [ROLE_ALPHA, 'not_valid']
|
||||
# given
|
||||
roles_raw = [create_role(), "not_valid"]
|
||||
# when/then
|
||||
with self.assertRaises(TypeError):
|
||||
DiscordRoles(roles_raw)
|
||||
|
||||
def test_raises_exception_if_invalid_role_1(self):
|
||||
roles_raw = [{'name': 'alpha', 'managed': False}]
|
||||
with self.assertRaises(ValueError):
|
||||
DiscordRoles(roles_raw)
|
||||
|
||||
def test_raises_exception_if_invalid_role_2(self):
|
||||
roles_raw = [{'id': 1, 'managed': False}]
|
||||
with self.assertRaises(ValueError):
|
||||
DiscordRoles(roles_raw)
|
||||
|
||||
def test_raises_exception_if_invalid_role_3(self):
|
||||
roles_raw = [{'id': 1, 'name': 'alpha'}]
|
||||
with self.assertRaises(ValueError):
|
||||
DiscordRoles(roles_raw)
|
||||
RolesSet(roles_raw)
|
||||
|
||||
def test_roles_are_equal(self):
|
||||
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_b = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
roles_a = RolesSet([role_a, role_b])
|
||||
roles_b = RolesSet([role_a, role_b])
|
||||
# when/then
|
||||
self.assertEqual(roles_a, roles_b)
|
||||
|
||||
def test_roles_are_not_equal(self):
|
||||
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_b = DiscordRoles([ROLE_ALPHA])
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
roles_a = RolesSet([role_a, role_b])
|
||||
roles_b = RolesSet([role_a])
|
||||
# when/then
|
||||
self.assertNotEqual(roles_a, roles_b)
|
||||
|
||||
def test_different_objects_are_not_equal(self):
|
||||
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_a = RolesSet([])
|
||||
self.assertFalse(roles_a == "invalid")
|
||||
|
||||
def test_len(self):
|
||||
self.assertEqual(len(self.all_roles), 4)
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
roles = RolesSet([role_a, role_b])
|
||||
# when/then
|
||||
self.assertEqual(len(roles), 2)
|
||||
|
||||
def test_contains(self):
|
||||
self.assertTrue(1 in self.all_roles)
|
||||
self.assertFalse(99 in self.all_roles)
|
||||
|
||||
def test_sanitize_role_name(self):
|
||||
role_name_input = 'x' * 110
|
||||
role_name_expected = 'x' * 100
|
||||
result = DiscordRoles.sanitize_role_name(role_name_input)
|
||||
self.assertEqual(result, role_name_expected)
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
roles = RolesSet([role_a])
|
||||
# when/then
|
||||
self.assertTrue(1 in roles)
|
||||
self.assertFalse(99 in roles)
|
||||
|
||||
def test_objects_are_hashable(self):
|
||||
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_b = DiscordRoles([ROLE_BRAVO, ROLE_ALPHA])
|
||||
roles_c = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE])
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
role_c = create_role()
|
||||
roles_a = RolesSet([role_a, role_b])
|
||||
roles_b = RolesSet([role_b, role_a])
|
||||
roles_c = RolesSet([role_a, role_b, role_c])
|
||||
# when/then
|
||||
self.assertIsNotNone(hash(roles_a))
|
||||
self.assertEqual(hash(roles_a), hash(roles_b))
|
||||
self.assertNotEqual(hash(roles_a), hash(roles_c))
|
||||
|
||||
def test_create_from_matched_roles(self):
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
matched_roles = [
|
||||
(ROLE_ALPHA, True),
|
||||
(ROLE_BRAVO, False)
|
||||
create_matched_role(role_a, True),
|
||||
create_matched_role(role_b, False),
|
||||
]
|
||||
roles = DiscordRoles.create_from_matched_roles(matched_roles)
|
||||
self.assertSetEqual(roles.ids(), {1, 2})
|
||||
|
||||
|
||||
class TestIds(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.all_roles = DiscordRoles(ALL_ROLES)
|
||||
# when
|
||||
roles = RolesSet.create_from_matched_roles(matched_roles)
|
||||
# then
|
||||
self.assertEqual(roles, RolesSet([role_a, role_b]))
|
||||
|
||||
def test_return_role_ids_default(self):
|
||||
result = self.all_roles.ids()
|
||||
expected = {1, 2, 3, 13}
|
||||
self.assertSetEqual(result, expected)
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
roles = RolesSet([role_a, role_b])
|
||||
# when/then
|
||||
self.assertSetEqual(roles.ids(), {1, 2})
|
||||
|
||||
def test_return_role_ids_empty(self):
|
||||
roles = DiscordRoles([])
|
||||
# given
|
||||
roles = RolesSet([])
|
||||
# when/then
|
||||
self.assertSetEqual(roles.ids(), set())
|
||||
|
||||
|
||||
class TestSubset(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.all_roles = DiscordRoles(ALL_ROLES)
|
||||
|
||||
class TestRolesSetSubset(NoSocketsTestCase):
|
||||
def test_ids_only(self):
|
||||
role_ids = {1, 3}
|
||||
roles_subset = self.all_roles.subset(role_ids)
|
||||
expected = {1, 3}
|
||||
self.assertSetEqual(roles_subset.ids(), expected)
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
role_c = create_role(id=3)
|
||||
roles_all = RolesSet([role_a, role_b, role_c])
|
||||
# when
|
||||
roles_subset = roles_all.subset({1, 3})
|
||||
# then
|
||||
self.assertEqual(roles_subset, RolesSet([role_a, role_c]))
|
||||
|
||||
def test_ids_as_string_work_too(self):
|
||||
role_ids = {'1', '3'}
|
||||
roles_subset = self.all_roles.subset(role_ids)
|
||||
expected = {1, 3}
|
||||
self.assertSetEqual(roles_subset.ids(), expected)
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
role_c = create_role(id=3)
|
||||
roles_all = RolesSet([role_a, role_b, role_c])
|
||||
# when
|
||||
roles_subset = roles_all.subset({"1", "3"})
|
||||
# then
|
||||
self.assertEqual(roles_subset, RolesSet([role_a, role_c]))
|
||||
|
||||
def test_managed_only(self):
|
||||
roles = self.all_roles.subset(managed_only=True)
|
||||
expected = {13}
|
||||
self.assertSetEqual(roles.ids(), expected)
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_m = create_role(id=13, managed=True)
|
||||
roles_all = RolesSet([role_a, role_m])
|
||||
# when
|
||||
roles_subset = roles_all.subset(managed_only=True)
|
||||
# then
|
||||
self.assertEqual(roles_subset, RolesSet([role_m]))
|
||||
|
||||
def test_ids_and_managed_only(self):
|
||||
role_ids = {1, 3, 13}
|
||||
roles_subset = self.all_roles.subset(role_ids, managed_only=True)
|
||||
expected = {13}
|
||||
self.assertSetEqual(roles_subset.ids(), expected)
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
role_m = create_role(id=13, managed=True)
|
||||
roles_all = RolesSet([role_a, role_b, role_m])
|
||||
# when
|
||||
roles_subset = roles_all.subset({1, 13}, managed_only=True)
|
||||
# then
|
||||
self.assertEqual(roles_subset, RolesSet([role_m]))
|
||||
|
||||
def test_ids_are_empty(self):
|
||||
roles = self.all_roles.subset([])
|
||||
expected = set()
|
||||
self.assertSetEqual(roles.ids(), expected)
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
roles_all = RolesSet([role_a, role_b])
|
||||
roles_subset = roles_all.subset([])
|
||||
# then
|
||||
self.assertEqual(roles_subset, RolesSet([]))
|
||||
|
||||
def test_no_parameters(self):
|
||||
roles = self.all_roles.subset()
|
||||
expected = {1, 2, 3, 13}
|
||||
self.assertSetEqual(roles.ids(), expected)
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
roles_all = RolesSet([role_a, role_b])
|
||||
roles_subset = roles_all.subset()
|
||||
# then
|
||||
self.assertEqual(roles_subset, roles_all)
|
||||
|
||||
def test_should_return_role_names_only(self):
|
||||
# given
|
||||
all_roles = DiscordRoles([
|
||||
ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE, ROLE_CHARLIE_2
|
||||
])
|
||||
role_a = create_role(name="alpha")
|
||||
role_b = create_role(name="bravo")
|
||||
role_c1 = create_role(name="charlie")
|
||||
role_c2 = create_role(name="Charlie")
|
||||
roles_all = RolesSet([role_a, role_b, role_c1, role_c2])
|
||||
# when
|
||||
roles = all_roles.subset(role_names={"bravo", "charlie"})
|
||||
roles_subset = roles_all.subset(role_names={"bravo", "charlie"})
|
||||
# then
|
||||
self.assertSetEqual(roles.ids(), {2, 3, 4})
|
||||
self.assertSetEqual(roles_subset, RolesSet([role_b, role_c1, role_c2]))
|
||||
|
||||
|
||||
class TestHasRoles(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.all_roles = DiscordRoles(ALL_ROLES)
|
||||
class TestRolesSetHasRoles(NoSocketsTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
role_c = create_role(id=3)
|
||||
cls.all_roles = RolesSet([role_a, role_b, role_c])
|
||||
|
||||
def test_true_if_all_roles_exit(self):
|
||||
self.assertTrue(self.all_roles.has_roles([1, 2]))
|
||||
|
||||
def test_true_if_all_roles_exit_str(self):
|
||||
self.assertTrue(self.all_roles.has_roles(['1', '2']))
|
||||
self.assertTrue(self.all_roles.has_roles(["1", "2"]))
|
||||
|
||||
def test_false_if_role_does_not_exit(self):
|
||||
self.assertFalse(self.all_roles.has_roles([99]))
|
||||
@@ -183,74 +207,104 @@ class TestHasRoles(TestCase):
|
||||
self.assertTrue(self.all_roles.has_roles([]))
|
||||
|
||||
|
||||
class TestGetMatchingRolesByName(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.all_roles = DiscordRoles(ALL_ROLES)
|
||||
|
||||
class TestRolesSetGetMatchingRolesByName(NoSocketsTestCase):
|
||||
def test_return_role_if_matches(self):
|
||||
role_name = 'alpha'
|
||||
expected = ROLE_ALPHA
|
||||
result = self.all_roles.role_by_name(role_name)
|
||||
self.assertEqual(result, expected)
|
||||
# given
|
||||
role_a = create_role(name="alpha")
|
||||
role_b = create_role(name="bravo")
|
||||
roles = RolesSet([role_a, role_b])
|
||||
# when
|
||||
result = roles.role_by_name("alpha")
|
||||
# then
|
||||
self.assertEqual(result, role_a)
|
||||
|
||||
def test_return_role_if_matches_and_limit_max_length(self):
|
||||
role_name = 'x' * 120
|
||||
expected = create_role(77, 'x' * 100)
|
||||
roles = DiscordRoles([expected])
|
||||
# given
|
||||
role_name = "x" * 120
|
||||
role = create_role(name="x" * 100)
|
||||
roles = RolesSet([role])
|
||||
# when
|
||||
result = roles.role_by_name(role_name)
|
||||
self.assertEqual(result, expected)
|
||||
# then
|
||||
self.assertEqual(result, role)
|
||||
|
||||
def test_return_empty_if_not_matches(self):
|
||||
role_name = 'lima'
|
||||
expected = {}
|
||||
result = self.all_roles.role_by_name(role_name)
|
||||
self.assertEqual(result, expected)
|
||||
# given
|
||||
role_a = create_role(name="alpha")
|
||||
role_b = create_role(name="bravo")
|
||||
roles = RolesSet([role_a, role_b])
|
||||
# when
|
||||
result = roles.role_by_name("unknown")
|
||||
# then
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
class TestUnion(TestCase):
|
||||
|
||||
class TestRolesSetUnion(NoSocketsTestCase):
|
||||
def test_distinct_sets(self):
|
||||
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_2 = DiscordRoles([ROLE_CHARLIE, ROLE_MIKE])
|
||||
roles_3 = roles_1.union(roles_2)
|
||||
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE])
|
||||
self.assertEqual(roles_3, expected)
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
roles_1 = RolesSet([role_a])
|
||||
roles_2 = RolesSet([role_b])
|
||||
# when
|
||||
result = roles_1.union(roles_2)
|
||||
# then
|
||||
self.assertEqual(result, RolesSet([role_a, role_b]))
|
||||
|
||||
def test_overlapping_sets(self):
|
||||
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_2 = DiscordRoles([ROLE_BRAVO, ROLE_MIKE])
|
||||
roles_3 = roles_1.union(roles_2)
|
||||
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE])
|
||||
self.assertEqual(roles_3, expected)
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
role_c = create_role()
|
||||
roles_1 = RolesSet([role_a, role_b])
|
||||
roles_2 = RolesSet([role_b, role_c])
|
||||
# when
|
||||
result = roles_1.union(roles_2)
|
||||
self.assertEqual(result, RolesSet([role_a, role_b, role_c]))
|
||||
|
||||
def test_identical_sets(self):
|
||||
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_2 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_3 = roles_1.union(roles_2)
|
||||
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
self.assertEqual(roles_3, expected)
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
roles_1 = RolesSet([role_a, role_b])
|
||||
roles_2 = RolesSet([role_a, role_b])
|
||||
# when
|
||||
result = roles_1.union(roles_2)
|
||||
self.assertEqual(result, RolesSet([role_a, role_b]))
|
||||
|
||||
|
||||
class TestDifference(TestCase):
|
||||
|
||||
class TestRolesSetDifference(NoSocketsTestCase):
|
||||
def test_distinct_sets(self):
|
||||
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_2 = DiscordRoles([ROLE_CHARLIE, ROLE_MIKE])
|
||||
roles_3 = roles_1.difference(roles_2)
|
||||
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
self.assertEqual(roles_3, expected)
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
role_c = create_role()
|
||||
role_d = create_role()
|
||||
roles_1 = RolesSet([role_a, role_b])
|
||||
roles_2 = RolesSet([role_c, role_d])
|
||||
# when
|
||||
result = roles_1.difference(roles_2)
|
||||
# then
|
||||
self.assertEqual(result, RolesSet([role_a, role_b]))
|
||||
|
||||
def test_overlapping_sets(self):
|
||||
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_2 = DiscordRoles([ROLE_BRAVO, ROLE_MIKE])
|
||||
roles_3 = roles_1.difference(roles_2)
|
||||
expected = DiscordRoles([ROLE_ALPHA])
|
||||
self.assertEqual(roles_3, expected)
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
role_c = create_role()
|
||||
roles_1 = RolesSet([role_a, role_b])
|
||||
roles_2 = RolesSet([role_b, role_c])
|
||||
# when
|
||||
result = roles_1.difference(roles_2)
|
||||
# then
|
||||
self.assertEqual(result, RolesSet([role_a]))
|
||||
|
||||
def test_identical_sets(self):
|
||||
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_2 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
|
||||
roles_3 = roles_1.difference(roles_2)
|
||||
expected = DiscordRoles([])
|
||||
self.assertEqual(roles_3, expected)
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
roles_1 = RolesSet([role_a, role_b])
|
||||
roles_2 = RolesSet([role_a, role_b])
|
||||
# when
|
||||
result = roles_1.difference(roles_2)
|
||||
# then
|
||||
self.assertEqual(result, RolesSet([]))
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest import TestCase
|
||||
|
||||
from ..models import Guild, GuildMember, Role, User
|
||||
from .factories import create_guild, create_guild_member, create_role, create_user
|
||||
|
||||
|
||||
def _fetch_example_objects() -> dict:
|
||||
path = Path(__file__).parent / "example_objects.json"
|
||||
with path.open("r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
class TestUser(TestCase):
|
||||
def test_should_create_new_object(self):
|
||||
# when
|
||||
obj = User(id="42", username=123, discriminator=456)
|
||||
# then
|
||||
self.assertEqual(obj.id, 42)
|
||||
self.assertEqual(obj.username, "123")
|
||||
self.assertTrue(obj.discriminator, "456")
|
||||
|
||||
def test_should_create_from_dict(self):
|
||||
# given
|
||||
data = example_objects["users"]["80351110224678912"]
|
||||
# when
|
||||
obj = User.from_dict(data)
|
||||
# then
|
||||
self.assertEqual(obj.id, 80351110224678912)
|
||||
self.assertEqual(obj.username, "Nelly")
|
||||
self.assertEqual(obj.discriminator, "1337")
|
||||
|
||||
|
||||
class TestRole(TestCase):
|
||||
def test_should_create_new_object_with_defaults(self):
|
||||
# when
|
||||
obj = Role(id="42", name="x" * 110)
|
||||
# then
|
||||
self.assertEqual(obj.id, 42)
|
||||
self.assertEqual(obj.name, "x" * 100)
|
||||
self.assertFalse(obj.managed)
|
||||
|
||||
def test_should_create_new_object(self):
|
||||
# when
|
||||
obj = Role(id=42, name="name", managed=1)
|
||||
# then
|
||||
self.assertEqual(obj.id, 42)
|
||||
self.assertEqual(obj.name, "name")
|
||||
self.assertTrue(obj.managed)
|
||||
|
||||
def test_should_create_from_dict(self):
|
||||
# given
|
||||
data = example_objects["roles"]["41771983423143936"]
|
||||
# when
|
||||
obj = Role.from_dict(data)
|
||||
# then
|
||||
self.assertEqual(obj.id, 41771983423143936)
|
||||
self.assertEqual(obj.name, "WE DEM BOYZZ!!!!!!")
|
||||
self.assertFalse(obj.managed)
|
||||
|
||||
def test_should_convert_to_dict(self):
|
||||
# given
|
||||
role = create_role(id=42, name="Special Name", managed=True)
|
||||
# when/then
|
||||
self.assertDictEqual(
|
||||
role.asdict(), {"id": 42, "name": "Special Name", "managed": True}
|
||||
)
|
||||
|
||||
def test_sanitize_role_name(self):
|
||||
# given
|
||||
role_name_input = "x" * 110
|
||||
role_name_expected = "x" * 100
|
||||
# when
|
||||
result = Role.sanitize_name(role_name_input)
|
||||
# then
|
||||
self.assertEqual(result, role_name_expected)
|
||||
|
||||
|
||||
class TestGuild(TestCase):
|
||||
def test_should_create_new_object(self):
|
||||
# given
|
||||
role_a = create_role()
|
||||
# when
|
||||
obj = Guild(id="42", name=123, roles=[role_a])
|
||||
# then
|
||||
self.assertEqual(obj.id, 42)
|
||||
self.assertEqual(obj.name, "123")
|
||||
self.assertEqual(obj.roles, frozenset([role_a]))
|
||||
|
||||
def test_should_create_from_dict(self):
|
||||
# given
|
||||
data = example_objects["guilds"]["2909267986263572999"]
|
||||
# when
|
||||
obj = Guild.from_dict(data)
|
||||
# then
|
||||
self.assertEqual(obj.id, 2909267986263572999)
|
||||
self.assertEqual(obj.name, "Mason's Test Server")
|
||||
(first_role,) = obj.roles
|
||||
self.assertEqual(first_role.id, 2909267986263572999)
|
||||
|
||||
def test_should_raise_error_when_role_type_is_wrong(self):
|
||||
with self.assertRaises(TypeError):
|
||||
create_guild(roles=[create_role(), "invalid"])
|
||||
|
||||
|
||||
class TestGuildMember(TestCase):
|
||||
def test_should_create_new_object(self):
|
||||
# given
|
||||
user = create_user()
|
||||
# when
|
||||
obj = GuildMember(user=user, nick="x" * 40, roles=[1, 2])
|
||||
# then
|
||||
self.assertEqual(obj.user, user)
|
||||
self.assertEqual(obj.nick, "x" * 32)
|
||||
self.assertEqual(obj.roles, frozenset([1, 2]))
|
||||
|
||||
def test_should_create_from_dict_empty(self):
|
||||
# given
|
||||
data = example_objects["guildMembers"]["1"]
|
||||
# when
|
||||
obj = GuildMember.from_dict(data)
|
||||
# then
|
||||
self.assertIsNone(obj.user)
|
||||
self.assertSetEqual(obj.roles, set())
|
||||
self.assertIsNone(obj.nick)
|
||||
|
||||
def test_should_create_from_dict_full(self):
|
||||
# given
|
||||
data = example_objects["guildMembers"]["2"]
|
||||
# when
|
||||
obj = GuildMember.from_dict(data)
|
||||
# then
|
||||
self.assertEqual(obj.user.username, "Nelly")
|
||||
self.assertSetEqual(obj.roles, {197150972374548480, 41771983423143936})
|
||||
self.assertEqual(obj.nick, "Nelly the great")
|
||||
|
||||
def test_should_raise_error_when_user_type_is_wrong(self):
|
||||
with self.assertRaises(TypeError):
|
||||
create_guild_member(user="invalid")
|
||||
|
||||
def test_should_raise_error_when_role_type_is_wrong(self):
|
||||
with self.assertRaises(TypeError):
|
||||
GuildMember(roles=[1, 2, "invalid"])
|
||||
|
||||
def test_sanitize_nick(self):
|
||||
# given
|
||||
nick_input = "x" * 40
|
||||
nick_expected = "x" * 32
|
||||
# when
|
||||
result = GuildMember.sanitize_nick(nick_input)
|
||||
# then
|
||||
self.assertEqual(result, nick_expected)
|
||||
|
||||
|
||||
example_objects = _fetch_example_objects()
|
||||
@@ -1,30 +1,33 @@
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from requests_oauthlib import OAuth2Session
|
||||
from requests.exceptions import HTTPError
|
||||
from requests_oauthlib import OAuth2Session
|
||||
|
||||
from django.contrib.auth.models import User, Group
|
||||
from django.contrib.auth.models import Group, User
|
||||
from django.db import models
|
||||
from django.utils.timezone import now
|
||||
|
||||
from allianceauth.services.hooks import NameFormatter
|
||||
|
||||
from . import __title__
|
||||
from .app_settings import (
|
||||
DISCORD_APP_ID,
|
||||
DISCORD_APP_SECRET,
|
||||
DISCORD_BOT_TOKEN,
|
||||
DISCORD_CALLBACK_URL,
|
||||
DISCORD_GUILD_ID,
|
||||
DISCORD_SYNC_NAMES
|
||||
DISCORD_SYNC_NAMES,
|
||||
)
|
||||
from .core import calculate_roles_for_user, create_bot_client
|
||||
from .core import group_to_role as core_group_to_role
|
||||
from .core import server_name as core_server_name
|
||||
from .core import user_formatted_nick
|
||||
from .discord_client import (
|
||||
DISCORD_OAUTH_BASE_URL,
|
||||
DISCORD_OAUTH_TOKEN_URL,
|
||||
DiscordApiBackoff,
|
||||
DiscordClient,
|
||||
)
|
||||
from .discord_client import DiscordClient
|
||||
from .discord_client.exceptions import DiscordClientException, DiscordApiBackoff
|
||||
from .discord_client.helpers import match_or_create_roles_from_names
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
|
||||
@@ -56,79 +59,68 @@ class DiscordUserManager(models.Manager):
|
||||
Returns: True on success, else False or raises exception
|
||||
"""
|
||||
try:
|
||||
nickname = self.user_formatted_nick(user) if DISCORD_SYNC_NAMES else None
|
||||
group_names = self.user_group_names(user)
|
||||
nickname = user_formatted_nick(user) if DISCORD_SYNC_NAMES else None
|
||||
access_token = self._exchange_auth_code_for_token(authorization_code)
|
||||
user_client = DiscordClient(access_token, is_rate_limited=is_rate_limited)
|
||||
discord_user = user_client.current_user()
|
||||
user_id = discord_user['id']
|
||||
bot_client = self._bot_client(is_rate_limited=is_rate_limited)
|
||||
|
||||
if group_names:
|
||||
role_ids = match_or_create_roles_from_names(
|
||||
client=bot_client,
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
role_names=group_names
|
||||
).ids()
|
||||
else:
|
||||
role_ids = None
|
||||
|
||||
created = bot_client.add_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
role_ids=role_ids,
|
||||
nick=nickname
|
||||
bot_client = create_bot_client(is_rate_limited=is_rate_limited)
|
||||
roles, changed = calculate_roles_for_user(
|
||||
user=user, client=bot_client, discord_uid=discord_user.id
|
||||
)
|
||||
if created is not False:
|
||||
if created is None:
|
||||
logger.debug(
|
||||
"User %s with Discord ID %s is already a member. Forcing a Refresh",
|
||||
if changed is None:
|
||||
# Handle new member
|
||||
created = bot_client.add_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=discord_user.id,
|
||||
access_token=access_token,
|
||||
role_ids=list(roles.ids()),
|
||||
nick=nickname
|
||||
)
|
||||
if not created:
|
||||
logger.warning(
|
||||
"Failed to add user %s with Discord ID %s to Discord server",
|
||||
user,
|
||||
user_id,
|
||||
discord_user.id,
|
||||
)
|
||||
|
||||
# Force an update cause the discord API won't do it for us.
|
||||
if role_ids:
|
||||
role_ids = list(role_ids)
|
||||
|
||||
updated = bot_client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=user_id,
|
||||
role_ids=role_ids,
|
||||
nick=nickname
|
||||
)
|
||||
|
||||
if not updated:
|
||||
# Could not update the new user so fail.
|
||||
logger.warning(
|
||||
"Failed to add user %s with Discord ID %s to Discord server",
|
||||
user,
|
||||
user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
self.update_or_create(
|
||||
user=user,
|
||||
defaults={
|
||||
'uid': user_id,
|
||||
'username': discord_user['username'][:32],
|
||||
'discriminator': discord_user['discriminator'][:4],
|
||||
'activated': now()
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
"Added user %s with Discord ID %s to Discord server", user, user_id
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to add user %s with Discord ID %s to Discord server",
|
||||
# Handle existing member
|
||||
logger.debug(
|
||||
"User %s with Discord ID %s is already a member. Forcing a Refresh",
|
||||
user,
|
||||
user_id,
|
||||
discord_user.id,
|
||||
)
|
||||
return False
|
||||
# Force an update cause the discord API won't do it for us.
|
||||
updated = bot_client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=discord_user.id,
|
||||
role_ids=list(roles.ids()),
|
||||
nick=nickname
|
||||
)
|
||||
if not updated:
|
||||
# Could not update the new user so fail.
|
||||
logger.warning(
|
||||
"Failed to add user %s with Discord ID %s to Discord server",
|
||||
user,
|
||||
discord_user.id,
|
||||
)
|
||||
return False
|
||||
|
||||
self.update_or_create(
|
||||
user=user,
|
||||
defaults={
|
||||
'uid': discord_user.id,
|
||||
'username': discord_user.username[:32],
|
||||
'discriminator': discord_user.discriminator[:4],
|
||||
'activated': now()
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
"Added user %s with Discord ID %s to Discord server",
|
||||
user,
|
||||
discord_user.id
|
||||
)
|
||||
return True
|
||||
|
||||
except (HTTPError, ConnectionError, DiscordApiBackoff) as ex:
|
||||
logger.exception(
|
||||
@@ -136,31 +128,6 @@ class DiscordUserManager(models.Manager):
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def user_formatted_nick(user: User) -> str:
|
||||
"""returns the name of the given users main character with name formatting
|
||||
or None if user has no main
|
||||
"""
|
||||
from .auth_hooks import DiscordService
|
||||
|
||||
if user.profile.main_character:
|
||||
return NameFormatter(DiscordService(), user).format_name()
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def user_group_names(user: User, state_name: str = None) -> list:
|
||||
"""returns list of group names plus state the given user is a member of"""
|
||||
if not state_name:
|
||||
state_name = user.profile.state.name
|
||||
group_names = (
|
||||
[group.name for group in user.groups.all()] + [state_name]
|
||||
)
|
||||
logger.debug(
|
||||
"Group names for roles updates of user %s are: %s", user, group_names
|
||||
)
|
||||
return group_names
|
||||
|
||||
def user_has_account(self, user: User) -> bool:
|
||||
"""Returns True if the user has an Discord account, else False
|
||||
|
||||
@@ -178,60 +145,41 @@ class DiscordUserManager(models.Manager):
|
||||
'permissions': str(cls.BOT_PERMISSIONS)
|
||||
|
||||
})
|
||||
return f'{DiscordClient.OAUTH_BASE_URL}?{params}'
|
||||
return f'{DISCORD_OAUTH_BASE_URL}?{params}'
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_redirect_url(cls) -> str:
|
||||
oauth = OAuth2Session(
|
||||
DISCORD_APP_ID, redirect_uri=DISCORD_CALLBACK_URL, scope=cls.SCOPES
|
||||
)
|
||||
url, state = oauth.authorization_url(DiscordClient.OAUTH_BASE_URL)
|
||||
url, _ = oauth.authorization_url(DISCORD_OAUTH_BASE_URL)
|
||||
return url
|
||||
|
||||
@staticmethod
|
||||
def _exchange_auth_code_for_token(authorization_code: str) -> str:
|
||||
oauth = OAuth2Session(DISCORD_APP_ID, redirect_uri=DISCORD_CALLBACK_URL)
|
||||
token = oauth.fetch_token(
|
||||
DiscordClient.OAUTH_TOKEN_URL,
|
||||
DISCORD_OAUTH_TOKEN_URL,
|
||||
client_secret=DISCORD_APP_SECRET,
|
||||
code=authorization_code
|
||||
)
|
||||
logger.debug("Received token from OAuth")
|
||||
return token['access_token']
|
||||
|
||||
@classmethod
|
||||
def server_name(cls, use_cache: bool = True) -> str:
|
||||
"""returns the name of the current Discord server
|
||||
or an empty string if the name could not be retrieved
|
||||
@staticmethod
|
||||
def group_to_role(group: Group) -> dict:
|
||||
"""Fetch the Discord role matching the given Django group by name.
|
||||
|
||||
Params:
|
||||
- use_cache: When set False will force an API call to get the server name
|
||||
Returns:
|
||||
- Discord role as dict
|
||||
- empty dict if no matching role found
|
||||
"""
|
||||
try:
|
||||
server_name = cls._bot_client().guild_name(
|
||||
guild_id=DISCORD_GUILD_ID, use_cache=use_cache
|
||||
)
|
||||
except (HTTPError, DiscordClientException):
|
||||
server_name = ""
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Unexpected error when trying to retrieve the server name from Discord",
|
||||
exc_info=True
|
||||
)
|
||||
server_name = ""
|
||||
|
||||
return server_name
|
||||
|
||||
@classmethod
|
||||
def group_to_role(cls, group: Group) -> dict:
|
||||
"""returns the Discord role matching the given Django group by name
|
||||
or an empty dict() if no matching role exist
|
||||
"""
|
||||
return cls._bot_client().match_role_from_name(
|
||||
guild_id=DISCORD_GUILD_ID, role_name=group.name
|
||||
)
|
||||
role = core_group_to_role(group)
|
||||
return role.asdict() if role else dict()
|
||||
|
||||
@staticmethod
|
||||
def _bot_client(is_rate_limited: bool = True) -> DiscordClient:
|
||||
"""returns a bot client for access to the Discord API"""
|
||||
return DiscordClient(DISCORD_BOT_TOKEN, is_rate_limited=is_rate_limited)
|
||||
def server_name(use_cache: bool = True) -> str:
|
||||
"""Fetches the name of the current Discord server.
|
||||
This method is kept to ensure backwards compatibility of this API.
|
||||
"""
|
||||
return core_server_name(use_cache)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -6,13 +7,17 @@ from django.contrib.auth.models import User
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy
|
||||
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
from allianceauth.notifications import notify
|
||||
|
||||
from . import __title__
|
||||
from .app_settings import DISCORD_GUILD_ID
|
||||
from .discord_client import DiscordApiBackoff, DiscordClient, DiscordRoles
|
||||
from .discord_client.helpers import match_or_create_roles_from_names
|
||||
from .core import (
|
||||
create_bot_client,
|
||||
default_bot_client,
|
||||
calculate_roles_for_user,
|
||||
user_formatted_nick
|
||||
)
|
||||
from .discord_client import DiscordApiBackoff
|
||||
from .managers import DiscordUserManager
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
@@ -21,14 +26,13 @@ logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
|
||||
class DiscordUser(models.Model):
|
||||
|
||||
USER_RELATED_NAME = 'discord'
|
||||
"""The Discord user account of an Auth user."""
|
||||
|
||||
user = models.OneToOneField(
|
||||
User,
|
||||
primary_key=True,
|
||||
on_delete=models.CASCADE,
|
||||
related_name=USER_RELATED_NAME,
|
||||
related_name='discord',
|
||||
help_text='Auth user owning this Discord account'
|
||||
)
|
||||
uid = models.BigIntegerField(
|
||||
@@ -80,24 +84,21 @@ class DiscordUser(models.Model):
|
||||
- False on error or raises exception
|
||||
"""
|
||||
if not nickname:
|
||||
nickname = DiscordUser.objects.user_formatted_nick(self.user)
|
||||
if nickname:
|
||||
client = DiscordUser.objects._bot_client()
|
||||
success = client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=self.uid,
|
||||
nick=nickname
|
||||
)
|
||||
if success:
|
||||
logger.info('Nickname for %s has been updated', self.user)
|
||||
else:
|
||||
logger.warning('Failed to update nickname for %s', self.user)
|
||||
return success
|
||||
|
||||
else:
|
||||
nickname = user_formatted_nick(self.user)
|
||||
if not nickname:
|
||||
return False
|
||||
success = default_bot_client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=self.uid,
|
||||
nick=nickname
|
||||
)
|
||||
if success:
|
||||
logger.info('Nickname for %s has been updated', self.user)
|
||||
else:
|
||||
logger.warning('Failed to update nickname for %s', self.user)
|
||||
return success
|
||||
|
||||
def update_groups(self, state_name: str = None) -> bool:
|
||||
def update_groups(self, state_name: str = None) -> Optional[bool]:
|
||||
"""update groups for a user based on his current group memberships.
|
||||
Will add or remove roles of a user as needed.
|
||||
|
||||
@@ -109,57 +110,18 @@ class DiscordUser(models.Model):
|
||||
- None if user is no longer a member of the Discord server
|
||||
- False on error or raises exception
|
||||
"""
|
||||
client = DiscordUser.objects._bot_client()
|
||||
member_roles = self._determine_member_roles(client)
|
||||
if member_roles is None:
|
||||
new_roles, is_changed = calculate_roles_for_user(
|
||||
user=self.user,
|
||||
client=default_bot_client,
|
||||
discord_uid=self.uid,
|
||||
state_name=state_name
|
||||
)
|
||||
if is_changed is None:
|
||||
logger.debug('User is not a member of this guild %s', self.user)
|
||||
return None
|
||||
return self._update_roles_if_needed(client, state_name, member_roles)
|
||||
|
||||
def _determine_member_roles(self, client: DiscordClient) -> DiscordRoles:
|
||||
"""Determine the roles of the current member / user."""
|
||||
member_info = client.guild_member(guild_id=DISCORD_GUILD_ID, user_id=self.uid)
|
||||
if member_info is None:
|
||||
return None # User is no longer a member
|
||||
guild_roles = DiscordRoles(client.guild_roles(guild_id=DISCORD_GUILD_ID))
|
||||
logger.debug('Current guild roles: %s', guild_roles.ids())
|
||||
if 'roles' in member_info:
|
||||
if not guild_roles.has_roles(member_info['roles']):
|
||||
guild_roles = DiscordRoles(
|
||||
client.guild_roles(guild_id=DISCORD_GUILD_ID, use_cache=False)
|
||||
)
|
||||
if not guild_roles.has_roles(member_info['roles']):
|
||||
raise RuntimeError(
|
||||
'Member {} has unknown roles: {}'.format(
|
||||
self.user,
|
||||
set(member_info['roles']).difference(guild_roles.ids())
|
||||
)
|
||||
)
|
||||
return guild_roles.subset(member_info['roles'])
|
||||
raise RuntimeError('member_info from %s is not valid' % self.user)
|
||||
|
||||
def _update_roles_if_needed(
|
||||
self, client: DiscordClient, state_name: str, member_roles: DiscordRoles
|
||||
) -> bool:
|
||||
"""Update the roles of this member/user if needed."""
|
||||
requested_roles = match_or_create_roles_from_names(
|
||||
client=client,
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
role_names=DiscordUser.objects.user_group_names(
|
||||
user=self.user, state_name=state_name
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
'Requested roles for user %s: %s', self.user, requested_roles.ids()
|
||||
)
|
||||
logger.debug('Current roles user %s: %s', self.user, member_roles.ids())
|
||||
reserved_role_names = ReservedGroupName.objects.values_list("name", flat=True)
|
||||
member_roles_reserved = member_roles.subset(role_names=reserved_role_names)
|
||||
member_roles_managed = member_roles.subset(managed_only=True)
|
||||
member_roles_persistent = member_roles_managed.union(member_roles_reserved)
|
||||
if requested_roles != member_roles.difference(member_roles_persistent):
|
||||
if is_changed:
|
||||
logger.debug('Need to update roles for user %s', self.user)
|
||||
new_roles = requested_roles.union(member_roles_persistent)
|
||||
success = client.modify_guild_member(
|
||||
success = default_bot_client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=self.uid,
|
||||
role_ids=list(new_roles.ids())
|
||||
@@ -172,41 +134,32 @@ class DiscordUser(models.Model):
|
||||
logger.info('No need to update roles for user %s', self.user)
|
||||
return True
|
||||
|
||||
def update_username(self) -> bool:
|
||||
def update_username(self) -> Optional[bool]:
|
||||
"""Updates the username incl. the discriminator
|
||||
from the Discord server and saves it
|
||||
|
||||
Returns:
|
||||
- True on success
|
||||
- None if user is no longer a member of the Discord server
|
||||
- False on error or raises exception
|
||||
"""
|
||||
client = DiscordUser.objects._bot_client()
|
||||
user_info = client.guild_member(guild_id=DISCORD_GUILD_ID, user_id=self.uid)
|
||||
if user_info is None:
|
||||
success = None
|
||||
elif (
|
||||
user_info
|
||||
and 'user' in user_info
|
||||
and 'username' in user_info['user']
|
||||
and 'discriminator' in user_info['user']
|
||||
):
|
||||
self.username = user_info['user']['username']
|
||||
self.discriminator = user_info['user']['discriminator']
|
||||
self.save()
|
||||
logger.info('Username for %s has been updated', self.user)
|
||||
success = True
|
||||
else:
|
||||
logger.warning('Failed to update username for %s', self.user)
|
||||
success = False
|
||||
return success
|
||||
member_info = default_bot_client.guild_member(
|
||||
guild_id=DISCORD_GUILD_ID, user_id=self.uid
|
||||
)
|
||||
if not member_info:
|
||||
logger.warning('%s: User not a guild member', self.user)
|
||||
return None
|
||||
self.username = member_info.user.username
|
||||
self.discriminator = member_info.user.discriminator
|
||||
self.save()
|
||||
logger.info('%s: Username has been updated', self.user)
|
||||
return True
|
||||
|
||||
def delete_user(
|
||||
self,
|
||||
notify_user: bool = False,
|
||||
is_rate_limited: bool = True,
|
||||
handle_api_exceptions: bool = False
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""Deletes the Discount user both on the server and locally
|
||||
|
||||
Params:
|
||||
@@ -221,7 +174,7 @@ class DiscordUser(models.Model):
|
||||
"""
|
||||
try:
|
||||
_user = self.user
|
||||
client = DiscordUser.objects._bot_client(is_rate_limited=is_rate_limited)
|
||||
client = create_bot_client(is_rate_limited=is_rate_limited)
|
||||
success = client.remove_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID, user_id=self.uid
|
||||
)
|
||||
@@ -241,15 +194,13 @@ class DiscordUser(models.Model):
|
||||
)
|
||||
logger.info('Account for user %s was deleted.', _user)
|
||||
return True
|
||||
else:
|
||||
logger.debug('Account for user %s was already deleted.', _user)
|
||||
return None
|
||||
logger.debug('Account for user %s was already deleted.', _user)
|
||||
return None
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
'Failed to remove user %s from the Discord server', _user
|
||||
)
|
||||
return False
|
||||
logger.warning(
|
||||
'Failed to remove user %s from the Discord server', _user
|
||||
)
|
||||
return False
|
||||
|
||||
except (HTTPError, ConnectionError, DiscordApiBackoff) as ex:
|
||||
if handle_api_exceptions:
|
||||
@@ -257,5 +208,4 @@ class DiscordUser(models.Model):
|
||||
'Failed to remove user %s from Discord server: %s',self.user, ex
|
||||
)
|
||||
return False
|
||||
else:
|
||||
raise ex
|
||||
raise ex
|
||||
|
||||
@@ -1,19 +1,6 @@
|
||||
from django.contrib.auth.models import Group
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from ..discord_client.tests import ( # noqa
|
||||
TEST_GUILD_ID,
|
||||
TEST_USER_ID,
|
||||
TEST_USER_NAME,
|
||||
TEST_USER_DISCRIMINATOR,
|
||||
create_role,
|
||||
ROLE_ALPHA,
|
||||
ROLE_BRAVO,
|
||||
ROLE_CHARLIE,
|
||||
ROLE_CHARLIE_2,
|
||||
ROLE_MIKE,
|
||||
ALL_ROLES,
|
||||
create_user_info
|
||||
)
|
||||
|
||||
DEFAULT_AUTH_GROUP = 'Member'
|
||||
MODULE_PATH = 'allianceauth.services.modules.discord'
|
||||
|
||||
31
allianceauth/services/modules/discord/tests/factories.py
Normal file
31
allianceauth/services/modules/discord/tests/factories.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from django.utils.timezone import now
|
||||
|
||||
from allianceauth.authentication.backends import StateBackend
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from ..discord_client.tests.factories import (
|
||||
TEST_USER_DISCRIMINATOR,
|
||||
TEST_USER_ID,
|
||||
TEST_USER_NAME,
|
||||
)
|
||||
from ..models import DiscordUser
|
||||
|
||||
|
||||
def create_user(**kwargs):
|
||||
params = {"username": TEST_USER_NAME}
|
||||
params.update(kwargs)
|
||||
username = StateBackend.iterate_username(params["username"])
|
||||
user = AuthUtils.create_user(username)
|
||||
return AuthUtils.add_permission_to_user_by_name("discord.access_discord", user)
|
||||
|
||||
|
||||
def create_discord_user(user=None, **kwargs):
|
||||
params = {
|
||||
"user": user or create_user(),
|
||||
"uid": TEST_USER_ID,
|
||||
"username": TEST_USER_NAME,
|
||||
"discriminator": TEST_USER_DISCRIMINATOR,
|
||||
"activated": now(),
|
||||
}
|
||||
params.update(kwargs)
|
||||
return DiscordUser.objects.create(**params)
|
||||
@@ -35,17 +35,17 @@ import logging
|
||||
from uuid import uuid1
|
||||
import random
|
||||
|
||||
from django.core.cache import caches
|
||||
from django.contrib.auth.models import User, Group
|
||||
|
||||
from allianceauth.services.modules.discord.models import DiscordUser
|
||||
from allianceauth.utils.cache import get_redis_client
|
||||
|
||||
logger = logging.getLogger('allianceauth')
|
||||
MAX_RUNS = 3
|
||||
|
||||
|
||||
def clear_cache():
|
||||
default_cache = caches['default']
|
||||
redis = default_cache.get_master_client()
|
||||
redis = get_redis_client()
|
||||
redis.flushall()
|
||||
logger.info('Cache flushed')
|
||||
|
||||
|
||||
@@ -1,26 +1,32 @@
|
||||
from django.test import TestCase, RequestFactory
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.admin.sites import AdminSite
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import RequestFactory
|
||||
from django.utils.timezone import now
|
||||
|
||||
from allianceauth.authentication.models import CharacterOwnership
|
||||
from allianceauth.eveonline.models import (
|
||||
EveCharacter, EveCorporationInfo, EveAllianceInfo
|
||||
EveAllianceInfo,
|
||||
EveCharacter,
|
||||
EveCorporationInfo,
|
||||
)
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from ....admin import (
|
||||
MainAllianceFilter,
|
||||
MainCorporationsFilter,
|
||||
ServicesUserAdmin,
|
||||
user_main_organization,
|
||||
user_profile_pic,
|
||||
user_username,
|
||||
user_main_organization,
|
||||
ServicesUserAdmin,
|
||||
MainCorporationsFilter,
|
||||
MainAllianceFilter
|
||||
)
|
||||
from ..admin import DiscordUserAdmin
|
||||
from ..models import DiscordUser
|
||||
from . import MODULE_PATH
|
||||
|
||||
|
||||
class TestDataMixin(TestCase):
|
||||
class TestDataMixin(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -168,7 +174,7 @@ class TestDataMixin(TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestColumnRendering(TestDataMixin, TestCase):
|
||||
class TestColumnRendering(TestDataMixin, NoSocketsTestCase):
|
||||
|
||||
def test_user_profile_pic_u1(self):
|
||||
expected = (
|
||||
@@ -229,7 +235,7 @@ class TestColumnRendering(TestDataMixin, TestCase):
|
||||
# actions
|
||||
|
||||
|
||||
class TestFilters(TestDataMixin, TestCase):
|
||||
class TestFilters(TestDataMixin, NoSocketsTestCase):
|
||||
|
||||
def test_filter_main_corporations(self):
|
||||
|
||||
@@ -287,3 +293,16 @@ class TestFilters(TestDataMixin, TestCase):
|
||||
queryset = changelist.get_queryset(request)
|
||||
expected = [self.user_1.discord]
|
||||
self.assertSetEqual(set(queryset), set(expected))
|
||||
|
||||
|
||||
@patch(MODULE_PATH + ".admin.DiscordUser.delete_user")
|
||||
class TestDeleteQueryset(TestDataMixin, NoSocketsTestCase):
|
||||
def test_should_delete_all_objects(self, mock_delete_user):
|
||||
# given
|
||||
request = self.factory.get('/')
|
||||
request.user = self.user_1
|
||||
queryset = DiscordUser.objects.filter(user__in=[self.user_2, self.user_3])
|
||||
# when
|
||||
self.modeladmin.delete_queryset(request, queryset)
|
||||
# then
|
||||
self.assertEqual(mock_delete_user.call_count, 2)
|
||||
|
||||
16
allianceauth/services/modules/discord/tests/test_api.py
Normal file
16
allianceauth/services/modules/discord/tests/test_api.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from ..api import discord_guild_id
|
||||
from . import MODULE_PATH
|
||||
|
||||
|
||||
class TestDiscordGuildId(NoSocketsTestCase):
|
||||
@patch(MODULE_PATH + ".api.DISCORD_GUILD_ID", "123")
|
||||
def test_should_return_guild_id_when_configured(self):
|
||||
self.assertEqual(discord_guild_id(), 123)
|
||||
|
||||
@patch(MODULE_PATH + ".api.DISCORD_GUILD_ID", "")
|
||||
def test_should_return_none_when_not_configured(self):
|
||||
self.assertIsNone(discord_guild_id())
|
||||
@@ -1,23 +1,23 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.test import RequestFactory
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from allianceauth.notifications.models import Notification
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from . import TEST_USER_NAME, TEST_USER_ID, add_permissions_to_members, MODULE_PATH
|
||||
from ..auth_hooks import DiscordService
|
||||
from ..discord_client import DiscordClient
|
||||
from ..discord_client.tests.factories import TEST_USER_ID, TEST_USER_NAME
|
||||
from ..models import DiscordUser
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
from . import MODULE_PATH, add_permissions_to_members
|
||||
|
||||
logger = set_logger_to_file(MODULE_PATH + '.auth_hooks', __file__)
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
class TestDiscordService(TestCase):
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
class TestDiscordService(NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.member = AuthUtils.create_member(TEST_USER_NAME)
|
||||
@@ -64,11 +64,11 @@ class TestDiscordService(TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.models.notify')
|
||||
@patch(MODULE_PATH + '.tasks.DiscordUser')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
@patch(MODULE_PATH + '.models.create_bot_client')
|
||||
def test_validate_user(
|
||||
self, mock_DiscordClient, mock_DiscordUser, mock_notify
|
||||
self, mock_create_bot_client, mock_DiscordUser, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = True
|
||||
|
||||
# Test member is not deleted
|
||||
service = self.service()
|
||||
@@ -92,33 +92,38 @@ class TestDiscordService(TestCase):
|
||||
service.sync_nicknames_bulk([self.member])
|
||||
self.assertTrue(mock_update_nicknames_bulk.delay.called)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_delete_user_is_member(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
|
||||
@patch(MODULE_PATH + '.models.create_bot_client')
|
||||
def test_delete_user_is_member(self, mock_create_bot_client):
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = True
|
||||
service = self.service()
|
||||
# when
|
||||
service.delete_user(self.member, notify_user=True)
|
||||
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
# then
|
||||
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
|
||||
self.assertFalse(DiscordUser.objects.filter(user=self.member).exists())
|
||||
self.assertTrue(Notification.objects.filter(user=self.member).exists())
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_delete_user_is_not_member(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
|
||||
@patch(MODULE_PATH + '.models.create_bot_client')
|
||||
def test_delete_user_is_not_member(self, mock_create_bot_client):
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = True
|
||||
service = self.service()
|
||||
# when
|
||||
service.delete_user(self.none_member)
|
||||
# then
|
||||
self.assertFalse(mock_create_bot_client.return_value.remove_guild_member.called)
|
||||
|
||||
self.assertFalse(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_render_services_ctrl_with_username(self, mock_DiscordClient):
|
||||
@patch(MODULE_PATH + '.auth_hooks.server_name')
|
||||
def test_render_services_ctrl_with_username(self, mock_server_name):
|
||||
# given
|
||||
mock_server_name.return_value = "My server"
|
||||
service = self.service()
|
||||
request = self.factory.get('/services/')
|
||||
request.user = self.member
|
||||
|
||||
# when
|
||||
response = service.render_services_ctrl(request)
|
||||
# then
|
||||
self.assertTemplateUsed(service.service_ctrl_template)
|
||||
self.assertIn('/discord/reset/', response)
|
||||
self.assertIn('/discord/deactivate/', response)
|
||||
@@ -130,15 +135,18 @@ class TestDiscordService(TestCase):
|
||||
response = service.render_services_ctrl(request)
|
||||
self.assertIn('/discord/activate/', response)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_render_services_ctrl_wo_username(self, mock_DiscordClient):
|
||||
@patch(MODULE_PATH + '.auth_hooks.server_name')
|
||||
def test_render_services_ctrl_wo_username(self, mock_server_name):
|
||||
# given
|
||||
mock_server_name.return_value = "My server"
|
||||
my_member = AuthUtils.create_member('John Doe')
|
||||
DiscordUser.objects.create(user=my_member, uid=111222333)
|
||||
service = self.service()
|
||||
request = self.factory.get('/services/')
|
||||
request.user = my_member
|
||||
|
||||
# when
|
||||
response = service.render_services_ctrl(request)
|
||||
# then
|
||||
self.assertTemplateUsed(service.service_ctrl_template)
|
||||
self.assertIn('/discord/reset/', response)
|
||||
self.assertIn('/discord/deactivate/', response)
|
||||
|
||||
221
allianceauth/services/modules/discord/tests/test_core.py
Normal file
221
allianceauth/services/modules/discord/tests/test_core.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.contrib.auth.models import Group
|
||||
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from ..core import (
|
||||
_user_group_names,
|
||||
calculate_roles_for_user,
|
||||
group_to_role,
|
||||
server_name,
|
||||
user_formatted_nick,
|
||||
)
|
||||
from ..discord_client import DiscordApiBackoff, DiscordClient, RolesSet
|
||||
from ..discord_client.tests.factories import TEST_USER_NAME, create_role
|
||||
from . import MODULE_PATH, TEST_MAIN_ID, TEST_MAIN_NAME
|
||||
|
||||
|
||||
class TestUserGroupNames(NoSocketsTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.group_1 = Group.objects.create(name="Group 1")
|
||||
cls.group_2 = Group.objects.create(name="Group 2")
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_member(TEST_USER_NAME)
|
||||
|
||||
def test_return_groups_and_state_names_for_user(self):
|
||||
self.user.groups.add(self.group_1)
|
||||
result = _user_group_names(self.user)
|
||||
expected = ["Group 1", "Member"]
|
||||
self.assertSetEqual(set(result), set(expected))
|
||||
|
||||
def test_return_state_only_if_user_has_no_groups(self):
|
||||
result = _user_group_names(self.user)
|
||||
expected = ["Member"]
|
||||
self.assertSetEqual(set(result), set(expected))
|
||||
|
||||
|
||||
class TestUserFormattedNick(NoSocketsTestCase):
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
|
||||
def test_return_nick_when_user_has_main(self):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
result = user_formatted_nick(self.user)
|
||||
expected = TEST_MAIN_NAME
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_return_none_if_user_has_no_main(self):
|
||||
result = user_formatted_nick(self.user)
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + ".core.default_bot_client", spec=True)
|
||||
class TestRoleForGroup(NoSocketsTestCase):
|
||||
def test_return_role_if_found(self, mock_bot_client):
|
||||
# given
|
||||
role = create_role(name="alpha")
|
||||
mock_bot_client.match_role_from_name.side_effect = (
|
||||
lambda guild_id, role_name: role if role.name == role_name else None
|
||||
)
|
||||
group = Group.objects.create(name="alpha")
|
||||
# when/then
|
||||
self.assertEqual(group_to_role(group), role)
|
||||
|
||||
def test_return_empty_dict_if_not_found(self, mock_bot_client):
|
||||
# given
|
||||
role = create_role(name="alpha")
|
||||
mock_bot_client.match_role_from_name.side_effect = (
|
||||
lambda guild_id, role_name: role if role.name == role_name else None
|
||||
)
|
||||
group = Group.objects.create(name="unknown")
|
||||
# when/then
|
||||
self.assertIsNone(group_to_role(group))
|
||||
|
||||
|
||||
@patch(MODULE_PATH + ".core.default_bot_client", spec=True)
|
||||
@patch(MODULE_PATH + ".core.logger", spec=True)
|
||||
class TestServerName(NoSocketsTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
|
||||
def test_returns_name_when_api_returns_it(self, mock_logger, mock_bot_client):
|
||||
# given
|
||||
my_server_name = "El Dorado"
|
||||
mock_bot_client.guild_name.return_value = my_server_name
|
||||
# when
|
||||
self.assertEqual(server_name(), my_server_name)
|
||||
# then
|
||||
self.assertFalse(mock_logger.warning.called)
|
||||
|
||||
def test_returns_empty_string_when_api_throws_http_error(
|
||||
self, mock_logger, mock_bot_client
|
||||
):
|
||||
mock_exception = HTTPError("Test exception")
|
||||
mock_exception.response = Mock(**{"status_code": 440})
|
||||
mock_bot_client.guild_name.side_effect = mock_exception
|
||||
|
||||
self.assertEqual(server_name(), "")
|
||||
self.assertFalse(mock_logger.warning.called)
|
||||
|
||||
def test_returns_empty_string_when_api_throws_service_error(
|
||||
self, mock_logger, mock_bot_client
|
||||
):
|
||||
mock_bot_client.guild_name.side_effect = DiscordApiBackoff(1000)
|
||||
|
||||
self.assertEqual(server_name(), "")
|
||||
self.assertFalse(mock_logger.warning.called)
|
||||
|
||||
def test_returns_empty_string_when_api_throws_unexpected_error(
|
||||
self, mock_logger, mock_bot_client
|
||||
):
|
||||
mock_bot_client.guild_name.side_effect = RuntimeError
|
||||
|
||||
self.assertEqual(server_name(), "")
|
||||
self.assertTrue(mock_logger.warning.called)
|
||||
|
||||
|
||||
class TestCalculateRolesForUser(NoSocketsTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
|
||||
def test_should_return_roles_for_new_member(self):
|
||||
# given
|
||||
roles = RolesSet([create_role()])
|
||||
my_client = Mock(spec=DiscordClient)
|
||||
my_client.guild_member_roles.return_value = RolesSet([])
|
||||
my_client.match_or_create_roles_from_names_2.return_value = roles
|
||||
# when
|
||||
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
|
||||
# then
|
||||
self.assertTrue(changed)
|
||||
self.assertEqual(roles_calculated, roles)
|
||||
|
||||
def test_should_return_changed_roles_for_existing_member(self):
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
roles_current = RolesSet([role_a])
|
||||
roles_matching = RolesSet([role_a, role_b])
|
||||
my_client = Mock(spec=DiscordClient)
|
||||
my_client.guild_member_roles.return_value = roles_current
|
||||
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
|
||||
# when
|
||||
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
|
||||
# then
|
||||
self.assertTrue(changed)
|
||||
self.assertEqual(roles_calculated, roles_matching)
|
||||
|
||||
def test_should_indicate_when_roles_are_unchanged(self):
|
||||
# given
|
||||
role_a = create_role()
|
||||
roles_current = RolesSet([role_a])
|
||||
roles_matching = RolesSet([role_a])
|
||||
my_client = Mock(spec=DiscordClient)
|
||||
my_client.guild_member_roles.return_value = roles_current
|
||||
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
|
||||
# when
|
||||
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
|
||||
# then
|
||||
self.assertFalse(changed)
|
||||
self.assertEqual(roles_calculated, roles_matching)
|
||||
|
||||
def test_should_indicate_when_user_is_no_guild_member(self):
|
||||
# given
|
||||
role_a = create_role()
|
||||
roles_matching = RolesSet([role_a])
|
||||
my_client = Mock(spec=DiscordClient)
|
||||
my_client.guild_member_roles.return_value = None
|
||||
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
|
||||
# when
|
||||
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
|
||||
# then
|
||||
self.assertIsNone(changed)
|
||||
self.assertEqual(roles_calculated, roles_matching)
|
||||
|
||||
def test_should_preserve_managed_roles_for_existing_member(self):
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
role_m = create_role(managed=True)
|
||||
roles_current = RolesSet([role_a, role_m])
|
||||
roles_matching = RolesSet([role_b])
|
||||
my_client = Mock(spec=DiscordClient)
|
||||
my_client.guild_member_roles.return_value = roles_current
|
||||
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
|
||||
# when
|
||||
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
|
||||
# then
|
||||
self.assertTrue(changed)
|
||||
self.assertEqual(roles_calculated, RolesSet([role_b, role_m]))
|
||||
|
||||
def test_should_preserve_reserved_roles_for_existing_member(self):
|
||||
# given
|
||||
role_a = create_role()
|
||||
role_b = create_role()
|
||||
role_c1 = create_role(name="charlie")
|
||||
role_c2 = create_role(name="Charlie")
|
||||
roles_current = RolesSet([role_a, role_c1, role_c2])
|
||||
roles_matching = RolesSet([role_b])
|
||||
my_client = Mock(spec=DiscordClient)
|
||||
my_client.guild_member_roles.return_value = roles_current
|
||||
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
|
||||
ReservedGroupName.objects.create(
|
||||
name="charlie", reason="dummy", created_by="xyz"
|
||||
)
|
||||
# when
|
||||
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
|
||||
# then
|
||||
self.assertTrue(changed)
|
||||
self.assertEqual(roles_calculated, RolesSet([role_b, role_c1, role_c2]))
|
||||
@@ -4,54 +4,68 @@ Testing all components of the service, with the exception of the Discord API.
|
||||
|
||||
Please note that these tests require Redis and will flush it
|
||||
"""
|
||||
from collections import namedtuple
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid1
|
||||
|
||||
from django_webtest import WebTest
|
||||
from requests.exceptions import HTTPError
|
||||
import requests_mock
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.contrib.auth.models import Group, User
|
||||
from django.core.cache import caches
|
||||
from django.shortcuts import reverse
|
||||
from django.test import TransactionTestCase, TestCase
|
||||
from django.test.utils import override_settings
|
||||
from django.test import TransactionTestCase, override_settings
|
||||
from django.urls import reverse
|
||||
from django_webtest import WebTest
|
||||
|
||||
from allianceauth.authentication.models import State
|
||||
from allianceauth.eveonline.models import EveCharacter
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
from allianceauth.notifications.models import Notification
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.utils.cache import get_redis_client
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from . import (
|
||||
TEST_GUILD_ID,
|
||||
TEST_USER_NAME,
|
||||
TEST_USER_ID,
|
||||
TEST_USER_DISCRIMINATOR,
|
||||
TEST_MAIN_NAME,
|
||||
TEST_MAIN_ID,
|
||||
MODULE_PATH,
|
||||
add_permissions_to_members,
|
||||
ROLE_ALPHA,
|
||||
ROLE_BRAVO,
|
||||
ROLE_CHARLIE,
|
||||
ROLE_MIKE,
|
||||
create_role,
|
||||
create_user_info
|
||||
)
|
||||
from ..discord_client.app_settings import DISCORD_API_BASE_URL
|
||||
from ..discord_client.exceptions import DiscordApiBackoff
|
||||
from ..models import DiscordUser
|
||||
from .. import tasks
|
||||
from ..core import create_bot_client
|
||||
from ..discord_client import DiscordApiBackoff
|
||||
from ..discord_client.app_settings import DISCORD_API_BASE_URL
|
||||
from ..discord_client.tests.factories import (
|
||||
TEST_GUILD_ID,
|
||||
TEST_USER_ID,
|
||||
TEST_USER_NAME,
|
||||
create_discord_error_response_unknown_member,
|
||||
create_discord_guild_member_object,
|
||||
create_discord_guild_object,
|
||||
create_discord_role_object,
|
||||
create_discord_user_object,
|
||||
)
|
||||
from ..models import DiscordUser
|
||||
from . import MODULE_PATH, TEST_MAIN_ID, TEST_MAIN_NAME, add_permissions_to_members
|
||||
from .factories import create_discord_user, create_user
|
||||
|
||||
logger = logging.getLogger('allianceauth')
|
||||
|
||||
ROLE_MEMBER = create_role(99, 'Member')
|
||||
ROLE_BLUE = create_role(98, 'Blue')
|
||||
ROLE_ALPHA = create_discord_role_object(id=1, name="alpha")
|
||||
ROLE_BRAVO = create_discord_role_object(id=2, name="bravo")
|
||||
ROLE_CHARLIE = create_discord_role_object(id=3, name="charlie")
|
||||
ROLE_CHARLIE_2 = create_discord_role_object(id=4, name="Charlie") # Discord roles are case sensitive
|
||||
ROLE_MIKE = create_discord_role_object(id=13, name="mike", managed=True)
|
||||
ROLE_MEMBER = create_discord_role_object(99, 'Member')
|
||||
ROLE_BLUE = create_discord_role_object(98, 'Blue')
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DiscordRequest:
|
||||
"""Helper for comparing requests made to the Discord API."""
|
||||
method: str
|
||||
url: str
|
||||
text: str = dataclasses.field(compare=False, default=None)
|
||||
|
||||
def json(self):
|
||||
return json.loads(self.text)
|
||||
|
||||
|
||||
# Putting all requests to Discord into objects so we can compare them better
|
||||
DiscordRequest = namedtuple('DiscordRequest', ['method', 'url'])
|
||||
user_get_current_request = DiscordRequest(
|
||||
method='GET',
|
||||
url=f'{DISCORD_API_BASE_URL}users/@me'
|
||||
@@ -87,8 +101,7 @@ remove_guild_member_request = DiscordRequest(
|
||||
|
||||
|
||||
def clear_cache():
|
||||
default_cache = caches['default']
|
||||
redis = default_cache.get_master_client()
|
||||
redis = get_redis_client()
|
||||
redis.flushall()
|
||||
logger.info('Cache flushed')
|
||||
|
||||
@@ -103,13 +116,13 @@ def reset_testdata():
|
||||
Notification.objects.all().delete()
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.core.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=False)
|
||||
@requests_mock.Mocker()
|
||||
class TestServiceFeatures(TransactionTestCase):
|
||||
fixtures = ['disable_analytics.json']
|
||||
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
@@ -193,7 +206,7 @@ class TestServiceFeatures(TransactionTestCase):
|
||||
requests_mocker.patch(modify_guild_member_request.url, status_code=204)
|
||||
|
||||
# exhausting rate limit
|
||||
client = DiscordUser.objects._bot_client()
|
||||
client = create_bot_client()
|
||||
client._redis.set(
|
||||
name=client._KEY_GLOBAL_RATE_LIMIT_REMAINING,
|
||||
value=0,
|
||||
@@ -209,7 +222,6 @@ class TestServiceFeatures(TransactionTestCase):
|
||||
requests_made = [
|
||||
DiscordRequest(r.method, r.url) for r in requests_mocker.request_history
|
||||
]
|
||||
|
||||
self.assertListEqual(requests_made, list())
|
||||
|
||||
def test_when_member_is_demoted_to_guest_then_his_account_is_deleted(
|
||||
@@ -247,7 +259,7 @@ class TestServiceFeatures(TransactionTestCase):
|
||||
# request mocks
|
||||
requests_mocker.get(
|
||||
guild_member_request.url,
|
||||
json={'user': create_user_info(), 'roles': ['3', '13', '99']}
|
||||
json=create_discord_guild_member_object(roles=[3, 13, 99])
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_roles_request.url,
|
||||
@@ -283,10 +295,7 @@ class TestServiceFeatures(TransactionTestCase):
|
||||
):
|
||||
requests_mocker.get(
|
||||
guild_member_request.url,
|
||||
json={
|
||||
'user': create_user_info(),
|
||||
'roles': ['13', '99']
|
||||
}
|
||||
json=create_discord_guild_member_object(roles=[13, 99])
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_roles_request.url,
|
||||
@@ -315,10 +324,7 @@ class TestServiceFeatures(TransactionTestCase):
|
||||
):
|
||||
requests_mocker.get(
|
||||
guild_member_request.url,
|
||||
json={
|
||||
'user': {'id': str(TEST_USER_ID), 'username': TEST_MAIN_NAME},
|
||||
'roles': ['13', '99']
|
||||
}
|
||||
json=create_discord_guild_member_object(roles=['13', '99'])
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_roles_request.url,
|
||||
@@ -344,11 +350,33 @@ class TestServiceFeatures(TransactionTestCase):
|
||||
self.assertTrue(DiscordUser.objects.user_has_account(self.user))
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@requests_mock.Mocker()
|
||||
class StateTestCase(TestCase):
|
||||
class TestTasks(NoSocketsTestCase):
|
||||
def test_should_update_username(self, requests_mocker):
|
||||
# given
|
||||
user = create_user()
|
||||
discord_user = create_discord_user(user)
|
||||
discord_user_obj = create_discord_user_object()
|
||||
data = create_discord_guild_member_object(user=discord_user_obj)
|
||||
requests_mocker.get(guild_member_request.url, json=data)
|
||||
# when
|
||||
tasks.update_username.delay(user.pk)
|
||||
# then
|
||||
discord_user.refresh_from_db()
|
||||
self.assertEqual(discord_user.username, discord_user_obj["username"])
|
||||
self.assertEqual(
|
||||
discord_user.discriminator, discord_user_obj["discriminator"]
|
||||
)
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@requests_mock.Mocker()
|
||||
class StateTestCase(NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
clear_cache()
|
||||
@@ -432,6 +460,7 @@ class StateTestCase(TestCase):
|
||||
self.user.discord
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.core.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@requests_mock.Mocker()
|
||||
@@ -450,24 +479,25 @@ class TestUserFeatures(WebTest):
|
||||
)
|
||||
add_permissions_to_members()
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session')
|
||||
@patch(MODULE_PATH + '.views.messages', spec=True)
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
|
||||
def test_user_activation_normal(
|
||||
self, requests_mocker, mock_OAuth2Session, mock_messages
|
||||
):
|
||||
# setup
|
||||
requests_mocker.get(
|
||||
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
|
||||
guild_infos_request.url, json=create_discord_guild_object()
|
||||
)
|
||||
requests_mocker.get(
|
||||
user_get_current_request.url,
|
||||
json=create_user_info(
|
||||
TEST_USER_ID, TEST_USER_NAME, TEST_USER_DISCRIMINATOR
|
||||
)
|
||||
user_get_current_request.url, json=create_discord_user_object()
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_roles_request.url,
|
||||
json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE, ROLE_MEMBER]
|
||||
guild_roles_request.url, json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MEMBER]
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_member_request.url,
|
||||
status_code=404,
|
||||
json=create_discord_error_response_unknown_member()
|
||||
)
|
||||
requests_mocker.put(add_guild_member_request.url, status_code=201)
|
||||
|
||||
@@ -505,33 +535,93 @@ class TestUserFeatures(WebTest):
|
||||
for r in requests_mocker.request_history:
|
||||
obj = DiscordRequest(r.method, r.url)
|
||||
requests_made.append(obj)
|
||||
self.assertIn(add_guild_member_request, requests_made)
|
||||
|
||||
expected = [
|
||||
guild_infos_request,
|
||||
user_get_current_request,
|
||||
guild_roles_request,
|
||||
add_guild_member_request
|
||||
]
|
||||
self.assertListEqual(requests_made, expected)
|
||||
@patch(MODULE_PATH + '.views.messages', spec=True)
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
|
||||
def test_should_activate_existing_user_and_keep_managed_and_reserved_roles(
|
||||
self, requests_mocker, mock_OAuth2Session, mock_messages
|
||||
):
|
||||
# setup
|
||||
requests_mocker.get(
|
||||
guild_infos_request.url, json=create_discord_guild_object()
|
||||
)
|
||||
requests_mocker.get(
|
||||
user_get_current_request.url, json=create_discord_user_object()
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_roles_request.url, json=[
|
||||
ROLE_ALPHA, ROLE_CHARLIE, ROLE_MEMBER, ROLE_MIKE
|
||||
]
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_member_request.url,
|
||||
json=create_discord_guild_member_object(roles=[1, 3, 13])
|
||||
)
|
||||
requests_mocker.patch(modify_guild_member_request.url, status_code=204)
|
||||
ReservedGroupName.objects.create(
|
||||
name="charlie", reason="dummy", created_by="xyz"
|
||||
)
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session')
|
||||
authentication_code = 'auth_code'
|
||||
oauth_url = 'https://www.example.com/oauth'
|
||||
state = ''
|
||||
mock_OAuth2Session.return_value.authorization_url.return_value = \
|
||||
oauth_url, state
|
||||
|
||||
# login
|
||||
self.app.set_user(self.member)
|
||||
|
||||
# user opens services page
|
||||
services_page = self.app.get(reverse('services:services'))
|
||||
self.assertEqual(services_page.status_code, 200)
|
||||
|
||||
# user clicks Discord service activation link on page
|
||||
response = services_page.click(href=reverse('discord:activate'))
|
||||
|
||||
# check we got a redirect to Discord OAuth
|
||||
self.assertRedirects(
|
||||
response, expected_url=oauth_url, fetch_redirect_response=False
|
||||
)
|
||||
|
||||
# simulate Discord callback
|
||||
response = self.app.get(
|
||||
reverse('discord:callback'), params={'code': authentication_code}
|
||||
)
|
||||
|
||||
# user got a success message
|
||||
self.assertTrue(mock_messages.success.called)
|
||||
self.assertFalse(mock_messages.error.called)
|
||||
|
||||
my_request = None
|
||||
for r in requests_mocker.request_history:
|
||||
obj = DiscordRequest(r.method, r.url, r.text)
|
||||
if obj == modify_guild_member_request:
|
||||
my_request = obj
|
||||
break
|
||||
else:
|
||||
self.fail("Request not found")
|
||||
self.assertSetEqual(set(my_request.json()["roles"]), {3, 13, 99})
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages', spec=True)
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
|
||||
def test_user_activation_failed(
|
||||
self, requests_mocker, mock_OAuth2Session, mock_messages
|
||||
):
|
||||
# setup
|
||||
requests_mocker.get(
|
||||
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
|
||||
guild_infos_request.url, json=create_discord_guild_object()
|
||||
)
|
||||
requests_mocker.get(
|
||||
user_get_current_request.url,
|
||||
json=create_user_info(
|
||||
TEST_USER_ID, TEST_USER_NAME, TEST_USER_DISCRIMINATOR
|
||||
)
|
||||
user_get_current_request.url, json=create_discord_user_object()
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_roles_request.url,
|
||||
json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE, ROLE_MEMBER]
|
||||
guild_roles_request.url, json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MEMBER]
|
||||
)
|
||||
requests_mocker.get(
|
||||
guild_member_request.url,
|
||||
status_code=404,
|
||||
json=create_discord_error_response_unknown_member()
|
||||
)
|
||||
|
||||
mock_exception = HTTPError('error')
|
||||
@@ -573,20 +663,13 @@ class TestUserFeatures(WebTest):
|
||||
for r in requests_mocker.request_history:
|
||||
obj = DiscordRequest(r.method, r.url)
|
||||
requests_made.append(obj)
|
||||
self.assertIn(add_guild_member_request, requests_made)
|
||||
|
||||
expected = [
|
||||
guild_infos_request,
|
||||
user_get_current_request,
|
||||
guild_roles_request,
|
||||
add_guild_member_request
|
||||
]
|
||||
self.assertListEqual(requests_made, expected)
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.views.messages', spec=True)
|
||||
def test_user_deactivation_normal(self, requests_mocker, mock_messages):
|
||||
# setup
|
||||
requests_mocker.get(
|
||||
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
|
||||
guild_infos_request.url, json=create_discord_guild_object()
|
||||
)
|
||||
requests_mocker.delete(remove_guild_member_request.url, status_code=204)
|
||||
DiscordUser.objects.create(user=self.member, uid=TEST_USER_ID)
|
||||
@@ -612,15 +695,13 @@ class TestUserFeatures(WebTest):
|
||||
for r in requests_mocker.request_history:
|
||||
obj = DiscordRequest(r.method, r.url)
|
||||
requests_made.append(obj)
|
||||
self.assertIn(remove_guild_member_request, requests_made)
|
||||
|
||||
expected = [guild_infos_request, remove_guild_member_request]
|
||||
self.assertListEqual(requests_made, expected)
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.views.messages', spec=True)
|
||||
def test_user_deactivation_fails(self, requests_mocker, mock_messages):
|
||||
# setup
|
||||
requests_mocker.get(
|
||||
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
|
||||
guild_infos_request.url, json=create_discord_guild_object()
|
||||
)
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = Mock()
|
||||
@@ -650,11 +731,9 @@ class TestUserFeatures(WebTest):
|
||||
for r in requests_mocker.request_history:
|
||||
obj = DiscordRequest(r.method, r.url)
|
||||
requests_made.append(obj)
|
||||
self.assertIn(remove_guild_member_request, requests_made)
|
||||
|
||||
expected = [guild_infos_request, remove_guild_member_request]
|
||||
self.assertListEqual(requests_made, expected)
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.views.messages', spec=True)
|
||||
def test_user_add_new_server(self, requests_mocker, mock_messages):
|
||||
# setup
|
||||
mock_exception = HTTPError(Mock(**{"response.status_code": 400}))
|
||||
@@ -686,14 +765,13 @@ class TestUserFeatures(WebTest):
|
||||
services_page = self.app.get(reverse('services:services'))
|
||||
self.assertEqual(services_page.status_code, 200)
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
@patch(MODULE_PATH + ".core.default_bot_client", spec=True)
|
||||
def test_server_name_is_updated_by_task(
|
||||
self, requests_mocker
|
||||
self, requests_mocker, mock_bot_client
|
||||
):
|
||||
# setup
|
||||
requests_mocker.get(
|
||||
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
|
||||
)
|
||||
mock_bot_client.guild_name.return_value = "Test Guild"
|
||||
# run task to update usernames
|
||||
tasks.update_all_usernames()
|
||||
|
||||
|
||||
@@ -1,364 +1,395 @@
|
||||
from unittest.mock import patch, Mock
|
||||
import urllib
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.contrib.auth.models import Group, User
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from . import (
|
||||
from ..app_settings import DISCORD_APP_ID, DISCORD_APP_SECRET, DISCORD_CALLBACK_URL
|
||||
from ..discord_client import (
|
||||
DISCORD_OAUTH_BASE_URL,
|
||||
DISCORD_OAUTH_TOKEN_URL,
|
||||
DiscordApiBackoff,
|
||||
DiscordClient,
|
||||
RolesSet,
|
||||
)
|
||||
from ..discord_client.tests.factories import (
|
||||
TEST_GUILD_ID,
|
||||
TEST_USER_NAME,
|
||||
TEST_USER_ID,
|
||||
TEST_MAIN_NAME,
|
||||
TEST_MAIN_ID,
|
||||
MODULE_PATH,
|
||||
ROLE_ALPHA,
|
||||
ROLE_BRAVO,
|
||||
ROLE_CHARLIE,
|
||||
TEST_USER_NAME,
|
||||
create_role,
|
||||
create_user,
|
||||
)
|
||||
from ..discord_client.tests import create_matched_role
|
||||
from ..app_settings import (
|
||||
DISCORD_APP_ID,
|
||||
DISCORD_APP_SECRET,
|
||||
DISCORD_CALLBACK_URL,
|
||||
)
|
||||
from ..discord_client import DiscordClient, DiscordApiBackoff
|
||||
from ..models import DiscordUser
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
from . import MODULE_PATH, TEST_MAIN_NAME
|
||||
|
||||
logger = set_logger_to_file(MODULE_PATH + '.managers', __file__)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects._exchange_auth_code_for_token')
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_group_names')
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_formatted_nick')
|
||||
class TestAddUser(TestCase):
|
||||
@patch(MODULE_PATH + '.managers.create_bot_client', spec=True)
|
||||
@patch(
|
||||
MODULE_PATH + '.models.DiscordUser.objects._exchange_auth_code_for_token', spec=True
|
||||
)
|
||||
@patch(MODULE_PATH + '.managers.calculate_roles_for_user', spec=True)
|
||||
@patch(MODULE_PATH + '.managers.user_formatted_nick', spec=True)
|
||||
class TestAddUser(NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
self.user_info = {
|
||||
'id': TEST_USER_ID,
|
||||
'name': TEST_USER_NAME,
|
||||
'username': TEST_USER_NAME,
|
||||
'discriminator': '1234',
|
||||
}
|
||||
self.access_token = 'accesstoken'
|
||||
|
||||
def test_can_create_user_no_roles_no_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([]), None
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertIsNone(kwargs['role_ids'])
|
||||
self.assertFalse(kwargs['role_ids'])
|
||||
self.assertIsNone(kwargs['nick'])
|
||||
|
||||
def test_can_create_user_with_roles_no_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
roles = [
|
||||
create_matched_role(ROLE_ALPHA),
|
||||
create_matched_role(ROLE_BRAVO),
|
||||
create_matched_role(ROLE_CHARLIE)
|
||||
]
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
roles_calculated = RolesSet([role_a, role_b])
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = ['a', 'b', 'c']
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = roles
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = roles_calculated, None
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertSetEqual(set(kwargs['role_ids']), {1, 2, 3})
|
||||
self.assertSetEqual(set(kwargs['role_ids']), {1, 2})
|
||||
self.assertIsNone(kwargs['nick'])
|
||||
|
||||
def test_can_activate_existing_user_with_roles_no_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
roles = [
|
||||
create_matched_role(ROLE_ALPHA),
|
||||
create_matched_role(ROLE_BRAVO),
|
||||
create_matched_role(ROLE_CHARLIE)
|
||||
]
|
||||
# given
|
||||
role_a = create_role(id=1)
|
||||
role_b = create_role(id=2)
|
||||
roles_calculated = RolesSet([role_a, role_b])
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = ['a', 'b', 'c']
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = roles
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = None
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = roles_calculated, False
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = None
|
||||
mock_create_bot_client.return_value.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
|
||||
_, kwargs = mock_create_bot_client.return_value.modify_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertSetEqual(set(kwargs['role_ids']), {1, 2, 3})
|
||||
self.assertSetEqual(set(kwargs['role_ids']), {1, 2})
|
||||
self.assertIsNone(kwargs['nick'])
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', True)
|
||||
def test_can_create_user_no_roles_with_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([]), None
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertIsNone(kwargs['role_ids'])
|
||||
self.assertFalse(kwargs['role_ids'])
|
||||
self.assertEqual(kwargs['nick'], TEST_MAIN_NAME)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', True)
|
||||
def test_can_activate_existing_user_no_roles_with_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = None
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([]), False
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = None
|
||||
mock_create_bot_client.return_value.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
|
||||
_, kwargs = mock_create_bot_client.return_value.modify_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertIsNone(kwargs['role_ids'])
|
||||
self.assertFalse(kwargs['role_ids'])
|
||||
self.assertEqual(kwargs['nick'], TEST_MAIN_NAME)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', False)
|
||||
def test_can_create_user_no_roles_and_without_nick_if_turned_off(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([]), None
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertIsNone(kwargs['role_ids'])
|
||||
self.assertFalse(kwargs['role_ids'])
|
||||
self.assertIsNone(kwargs['nick'])
|
||||
|
||||
def test_can_activate_existing_guild_member(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
roles_calculated = RolesSet([create_role()])
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = None
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = roles_calculated, False
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = None
|
||||
mock_create_bot_client.return_value.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.modify_guild_member.called)
|
||||
|
||||
def test_can_activate_existing_member_with_roles(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
roles_calculated = RolesSet([create_role(id=1)])
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = roles_calculated, False
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = None
|
||||
mock_create_bot_client.return_value.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
_, kwargs = mock_create_bot_client.return_value.modify_guild_member.call_args
|
||||
self.assertSetEqual(set(kwargs['role_ids']), {1})
|
||||
|
||||
def test_can_activate_existing_guild_member_failure(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
roles_calculated = RolesSet([create_role()])
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = None
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = False
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = roles_calculated, False
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = None
|
||||
mock_create_bot_client.return_value.modify_guild_member.return_value = False
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.modify_guild_member.called)
|
||||
|
||||
def test_return_false_when_user_creation_fails(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = False
|
||||
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([]), None
|
||||
mock_create_bot_client.return_value.add_guild_member.return_value = False
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.add_guild_member.called)
|
||||
|
||||
def test_return_false_when_on_api_backoff(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.add_guild_member.side_effect = \
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([]), None
|
||||
mock_create_bot_client.return_value.add_guild_member.side_effect = \
|
||||
DiscordApiBackoff(999)
|
||||
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.add_guild_member.called)
|
||||
|
||||
def test_return_false_on_http_error(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_calculate_roles_for_user,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
mock_create_bot_client,
|
||||
mock_DiscordClient,
|
||||
):
|
||||
# given
|
||||
discord_user = create_user(id=TEST_USER_ID)
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = []
|
||||
mock_DiscordClient.return_value.current_user.return_value = discord_user
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([]), None
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = Mock()
|
||||
mock_exception.response.status_code = 500
|
||||
mock_DiscordClient.return_value.add_guild_member.side_effect = mock_exception
|
||||
|
||||
mock_create_bot_client.return_value.add_guild_member.side_effect = mock_exception
|
||||
# when
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.add_guild_member.called)
|
||||
|
||||
|
||||
class TestOauthHelpers(TestCase):
|
||||
class TestOauthHelpers(NoSocketsTestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_APP_ID', '123456')
|
||||
def test_generate_bot_add_url(self):
|
||||
bot_add_url = DiscordUser.objects.generate_bot_add_url()
|
||||
|
||||
auth_url = DiscordClient.OAUTH_BASE_URL
|
||||
auth_url = DISCORD_OAUTH_BASE_URL
|
||||
real_bot_add_url = (
|
||||
f'{auth_url}?client_id=123456&scope=bot'
|
||||
f'&permissions={DiscordUser.objects.BOT_PERMISSIONS}'
|
||||
@@ -368,12 +399,12 @@ class TestOauthHelpers(TestCase):
|
||||
def test_generate_oauth_redirect_url(self):
|
||||
oauth_url = DiscordUser.objects.generate_oauth_redirect_url()
|
||||
|
||||
self.assertIn(DiscordClient.OAUTH_BASE_URL, oauth_url)
|
||||
self.assertIn(DISCORD_OAUTH_BASE_URL, oauth_url)
|
||||
self.assertIn('+'.join(DiscordUser.objects.SCOPES), oauth_url)
|
||||
self.assertIn(DISCORD_APP_ID, oauth_url)
|
||||
self.assertIn(urllib.parse.quote_plus(DISCORD_CALLBACK_URL), oauth_url)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session')
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
|
||||
def test_process_callback_code(self, oauth):
|
||||
instance = oauth.return_value
|
||||
instance.fetch_token.return_value = {'access_token': 'mywonderfultoken'}
|
||||
@@ -386,52 +417,13 @@ class TestOauthHelpers(TestCase):
|
||||
self.assertEqual(kwargs['redirect_uri'], DISCORD_CALLBACK_URL)
|
||||
self.assertTrue(instance.fetch_token.called)
|
||||
args, kwargs = instance.fetch_token.call_args
|
||||
self.assertEqual(args[0], DiscordClient.OAUTH_TOKEN_URL)
|
||||
self.assertEqual(args[0], DISCORD_OAUTH_TOKEN_URL)
|
||||
self.assertEqual(kwargs['client_secret'], DISCORD_APP_SECRET)
|
||||
self.assertEqual(kwargs['code'], '12345')
|
||||
self.assertEqual(token, 'mywonderfultoken')
|
||||
|
||||
|
||||
class TestUserFormattedNick(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
|
||||
def test_return_nick_when_user_has_main(self):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
result = DiscordUser.objects.user_formatted_nick(self.user)
|
||||
expected = TEST_MAIN_NAME
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_return_none_if_user_has_no_main(self):
|
||||
result = DiscordUser.objects.user_formatted_nick(self.user)
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
class TestUserGroupNames(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.group_1 = Group.objects.create(name='Group 1')
|
||||
cls.group_2 = Group.objects.create(name='Group 2')
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_member(TEST_USER_NAME)
|
||||
|
||||
def test_return_groups_and_state_names_for_user(self):
|
||||
self.user.groups.add(self.group_1)
|
||||
result = DiscordUser.objects.user_group_names(self.user)
|
||||
expected = ['Group 1', 'Member']
|
||||
self.assertSetEqual(set(result), set(expected))
|
||||
|
||||
def test_return_state_only_if_user_has_no_groups(self):
|
||||
result = DiscordUser.objects.user_group_names(self.user)
|
||||
expected = ['Member']
|
||||
self.assertSetEqual(set(result), set(expected))
|
||||
|
||||
|
||||
class TestUserHasAccount(TestCase):
|
||||
class TestUserHasAccount(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -453,59 +445,22 @@ class TestUserHasAccount(TestCase):
|
||||
self.assertFalse(DiscordUser.objects.user_has_account('abc'))
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
@patch(MODULE_PATH + '.managers.logger')
|
||||
class TestServerName(TestCase):
|
||||
class TestOtherMethods(NoSocketsTestCase):
|
||||
@patch(MODULE_PATH + '.managers.core_group_to_role', spec=True)
|
||||
def test_should_call_group_to_role(self, mock_core_group_to_role):
|
||||
# given
|
||||
role = create_role(id=1, name="alpha", managed=False)
|
||||
mock_core_group_to_role.return_value = role
|
||||
# when
|
||||
result = DiscordUser.objects.group_to_role(Mock())
|
||||
# then
|
||||
self.assertEqual(result["id"], 1)
|
||||
self.assertEqual(result["name"], "alpha")
|
||||
self.assertEqual(result["managed"], False)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
|
||||
def test_returns_name_when_api_returns_it(self, mock_logger, mock_DiscordClient):
|
||||
server_name = "El Dorado"
|
||||
mock_DiscordClient.return_value.guild_name.return_value = server_name
|
||||
|
||||
self.assertEqual(DiscordUser.objects.server_name(), server_name)
|
||||
self.assertFalse(mock_logger.warning.called)
|
||||
|
||||
def test_returns_empty_string_when_api_throws_http_error(
|
||||
self, mock_logger, mock_DiscordClient
|
||||
):
|
||||
mock_exception = HTTPError('Test exception')
|
||||
mock_exception.response = Mock(**{"status_code": 440})
|
||||
mock_DiscordClient.return_value.guild_name.side_effect = mock_exception
|
||||
|
||||
self.assertEqual(DiscordUser.objects.server_name(), "")
|
||||
self.assertFalse(mock_logger.warning.called)
|
||||
|
||||
def test_returns_empty_string_when_api_throws_service_error(
|
||||
self, mock_logger, mock_DiscordClient
|
||||
):
|
||||
mock_DiscordClient.return_value.guild_name.side_effect = DiscordApiBackoff(1000)
|
||||
|
||||
self.assertEqual(DiscordUser.objects.server_name(), "")
|
||||
self.assertFalse(mock_logger.warning.called)
|
||||
|
||||
def test_returns_empty_string_when_api_throws_unexpected_error(
|
||||
self, mock_logger, mock_DiscordClient
|
||||
):
|
||||
mock_DiscordClient.return_value.guild_name.side_effect = RuntimeError
|
||||
|
||||
self.assertEqual(DiscordUser.objects.server_name(), "")
|
||||
self.assertTrue(mock_logger.warning.called)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestRoleForGroup(TestCase):
|
||||
def test_return_role_if_found(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.match_role_from_name.return_value = ROLE_ALPHA
|
||||
|
||||
group = Group.objects.create(name='alpha')
|
||||
self.assertEqual(DiscordUser.objects.group_to_role(group), ROLE_ALPHA)
|
||||
|
||||
def test_return_empty_dict_if_not_found(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.match_role_from_name.return_value = dict()
|
||||
|
||||
group = Group.objects.create(name='unknown')
|
||||
self.assertEqual(DiscordUser.objects.group_to_role(group), dict())
|
||||
@patch(MODULE_PATH + '.managers.core_server_name', spec=True)
|
||||
def test_should_call_server_name(self, mock_core_server_name):
|
||||
# when
|
||||
DiscordUser.objects.server_name()
|
||||
# then
|
||||
self.assertTrue(mock_core_server_name.called)
|
||||
|
||||
@@ -1,34 +1,27 @@
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from . import (
|
||||
TEST_USER_NAME,
|
||||
from ..discord_client import DiscordApiBackoff, RolesSet
|
||||
from ..discord_client.tests.factories import (
|
||||
TEST_USER_ID,
|
||||
TEST_MAIN_NAME,
|
||||
TEST_MAIN_ID,
|
||||
MODULE_PATH,
|
||||
ROLE_ALPHA,
|
||||
ROLE_BRAVO,
|
||||
ROLE_CHARLIE,
|
||||
ROLE_CHARLIE_2,
|
||||
ROLE_MIKE,
|
||||
TEST_USER_NAME,
|
||||
create_guild_member,
|
||||
create_role,
|
||||
)
|
||||
from ..discord_client import DiscordClient, DiscordApiBackoff
|
||||
from ..discord_client.tests import create_matched_role
|
||||
from ..discord_client.tests.factories import create_user as create_guild_user
|
||||
from ..models import DiscordUser
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
from . import MODULE_PATH, TEST_MAIN_ID, TEST_MAIN_NAME
|
||||
from .factories import create_discord_user, create_user
|
||||
|
||||
logger = set_logger_to_file(MODULE_PATH + '.models', __file__)
|
||||
|
||||
|
||||
class TestBasicsAndHelpers(TestCase):
|
||||
class TestBasicsAndHelpers(NoSocketsTestCase):
|
||||
|
||||
def test_str(self):
|
||||
user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
@@ -43,8 +36,8 @@ class TestBasicsAndHelpers(TestCase):
|
||||
self.assertEqual(repr(discord_user), expected)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestUpdateNick(TestCase):
|
||||
@patch(MODULE_PATH + '.models.default_bot_client', spec=True)
|
||||
class TestUpdateNick(NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
@@ -52,119 +45,92 @@ class TestUpdateNick(TestCase):
|
||||
user=self.user, uid=TEST_USER_ID
|
||||
)
|
||||
|
||||
def test_can_update(self, mock_DiscordClient):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
def test_can_update(self, mock_default_bot_client):
|
||||
# given
|
||||
AuthUtils.add_main_character_2(
|
||||
self.user, TEST_MAIN_NAME, TEST_MAIN_ID, disconnect_signals=True
|
||||
)
|
||||
mock_default_bot_client.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = self.discord_user.update_nickname()
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_dont_update_if_user_has_no_main(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = False
|
||||
self.assertTrue(mock_default_bot_client.modify_guild_member.called)
|
||||
|
||||
def test_dont_update_if_user_has_no_main(self, mock_default_bot_client):
|
||||
# given
|
||||
mock_default_bot_client.modify_guild_member.return_value = False
|
||||
# when
|
||||
result = self.discord_user.update_nickname()
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_return_none_if_user_no_longer_a_member(self, mock_DiscordClient):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = None
|
||||
self.assertFalse(mock_default_bot_client.modify_guild_member.called)
|
||||
|
||||
def test_return_none_if_user_no_longer_a_member(self, mock_default_bot_client):
|
||||
# given
|
||||
AuthUtils.add_main_character_2(
|
||||
self.user, TEST_MAIN_NAME, TEST_MAIN_ID, disconnect_signals=True
|
||||
)
|
||||
mock_default_bot_client.modify_guild_member.return_value = None
|
||||
# when
|
||||
result = self.discord_user.update_nickname()
|
||||
# then
|
||||
self.assertIsNone(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_false(self, mock_DiscordClient):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = False
|
||||
self.assertTrue(mock_default_bot_client.modify_guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_false(self, mock_default_bot_client):
|
||||
# given
|
||||
AuthUtils.add_main_character_2(
|
||||
self.user, TEST_MAIN_NAME, TEST_MAIN_ID, disconnect_signals=True
|
||||
)
|
||||
mock_default_bot_client.modify_guild_member.return_value = False
|
||||
# when
|
||||
result = self.discord_user.update_nickname()
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
self.assertTrue(mock_default_bot_client.modify_guild_member.called)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestUpdateUsername(TestCase):
|
||||
@patch(MODULE_PATH + '.models.default_bot_client.guild_member', spec=True)
|
||||
class TestUpdateUsername(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
cls.user = create_user()
|
||||
|
||||
def setUp(self):
|
||||
self.discord_user = DiscordUser.objects.create(
|
||||
user=self.user,
|
||||
uid=TEST_USER_ID,
|
||||
username=TEST_MAIN_NAME,
|
||||
discriminator='1234'
|
||||
)
|
||||
|
||||
def test_can_update(self, mock_DiscordClient):
|
||||
def test_can_update(self, mock_guild_member):
|
||||
# given
|
||||
discord_user = create_discord_user(user=self.user)
|
||||
new_username = 'New name'
|
||||
new_discriminator = '9876'
|
||||
user_info = {
|
||||
'user': {
|
||||
'id': str(TEST_USER_ID),
|
||||
'username': new_username,
|
||||
'discriminator': new_discriminator,
|
||||
}
|
||||
}
|
||||
mock_DiscordClient.return_value.guild_member.return_value = user_info
|
||||
|
||||
result = self.discord_user.update_username()
|
||||
guild_user = create_guild_user(
|
||||
username='New name', discriminator=new_discriminator
|
||||
)
|
||||
mock_guild_member.return_value = create_guild_member(user=guild_user)
|
||||
# when
|
||||
result = discord_user.update_username()
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
|
||||
self.discord_user.refresh_from_db()
|
||||
self.assertEqual(self.discord_user.username, new_username)
|
||||
self.assertEqual(self.discord_user.discriminator, new_discriminator)
|
||||
self.assertTrue(mock_guild_member.called)
|
||||
discord_user.refresh_from_db()
|
||||
self.assertEqual(discord_user.username, new_username)
|
||||
self.assertEqual(discord_user.discriminator, new_discriminator)
|
||||
|
||||
def test_return_none_if_user_no_longer_a_member(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.guild_member.return_value = None
|
||||
result = self.discord_user.update_username()
|
||||
def test_return_none_if_user_no_longer_a_member(self, mock_guild_member):
|
||||
# given
|
||||
discord_user = create_discord_user(user=self.user)
|
||||
mock_guild_member.return_value = None
|
||||
# when
|
||||
result = discord_user.update_username()
|
||||
# then
|
||||
self.assertIsNone(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_false(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.guild_member.return_value = False
|
||||
result = self.discord_user.update_username()
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_corrput_data_1(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.guild_member.return_value = {'invalid': True}
|
||||
result = self.discord_user.update_username()
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_corrput_data_2(self, mock_DiscordClient):
|
||||
user_info = {
|
||||
'user': {
|
||||
'id': str(TEST_USER_ID),
|
||||
'discriminator': '1234',
|
||||
}
|
||||
}
|
||||
mock_DiscordClient.return_value.guild_member.return_value = user_info
|
||||
result = self.discord_user.update_username()
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_corrput_data_3(self, mock_DiscordClient):
|
||||
user_info = {
|
||||
'user': {
|
||||
'id': str(TEST_USER_ID),
|
||||
'username': TEST_USER_NAME,
|
||||
}
|
||||
}
|
||||
mock_DiscordClient.return_value.guild_member.return_value = user_info
|
||||
result = self.discord_user.update_username()
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
|
||||
self.assertTrue(mock_guild_member.called)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.models.notify')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestDeleteUser(TestCase):
|
||||
@patch(MODULE_PATH + '.models.notify', spec=True)
|
||||
@patch(MODULE_PATH + '.models.create_bot_client', spec=True)
|
||||
class TestDeleteUser(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -176,272 +142,168 @@ class TestDeleteUser(TestCase):
|
||||
user=self.user, uid=TEST_USER_ID
|
||||
)
|
||||
|
||||
def test_can_delete_user(self, mock_DiscordClient, mock_notify):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
def test_can_delete_user(self, mock_create_bot_client, mock_notify):
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = True
|
||||
# when
|
||||
result = self.discord_user.delete_user()
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
def test_can_delete_user_and_notify_user(self, mock_DiscordClient, mock_notify):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
def test_can_delete_user_and_notify_user(self, mock_create_bot_client, mock_notify):
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = True
|
||||
# when
|
||||
result = self.discord_user.delete_user(notify_user=True)
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_notify.called)
|
||||
|
||||
def test_can_delete_user_when_member_is_unknown(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
self, mock_create_bot_client, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = None
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = None
|
||||
# when
|
||||
result = self.discord_user.delete_user()
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
def test_return_false_when_api_fails(self, mock_DiscordClient, mock_notify):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = False
|
||||
def test_return_false_when_api_fails(self, mock_create_bot_client, mock_notify):
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = False
|
||||
# when
|
||||
result = self.discord_user.delete_user()
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dont_notify_if_user_was_already_deleted_and_return_none(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
self, mock_create_bot_client, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = None
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = None
|
||||
DiscordUser.objects.get(pk=self.discord_user.pk).delete()
|
||||
# when
|
||||
result = self.discord_user.delete_user()
|
||||
# then
|
||||
self.assertIsNone(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
def test_raise_exception_on_api_backoff(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
self, mock_create_bot_client, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
|
||||
DiscordApiBackoff(999)
|
||||
# when/then
|
||||
with self.assertRaises(DiscordApiBackoff):
|
||||
self.discord_user.delete_user()
|
||||
|
||||
def test_return_false_on_api_backoff_and_exception_handling_on(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
self, mock_create_bot_client, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
|
||||
DiscordApiBackoff(999)
|
||||
# when
|
||||
result = self.discord_user.delete_user(handle_api_exceptions=True)
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_raise_exception_on_http_error(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
self, mock_create_bot_client, mock_notify
|
||||
):
|
||||
# given
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = Mock()
|
||||
mock_exception.response.status_code = 500
|
||||
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
|
||||
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
|
||||
mock_exception
|
||||
|
||||
# when/then
|
||||
with self.assertRaises(HTTPError):
|
||||
self.discord_user.delete_user()
|
||||
|
||||
def test_return_false_on_http_error_and_exception_handling_on(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
self, mock_create_bot_client, mock_notify
|
||||
):
|
||||
# given
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = Mock()
|
||||
mock_exception.response.status_code = 500
|
||||
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
|
||||
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
|
||||
mock_exception
|
||||
# when
|
||||
result = self.discord_user.delete_user(handle_api_exceptions=True)
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_group_names')
|
||||
class TestUpdateGroups(TestCase):
|
||||
@patch(MODULE_PATH + '.models.default_bot_client', spec=True)
|
||||
@patch(MODULE_PATH + '.models.calculate_roles_for_user', spec=True)
|
||||
class TestUpdateGroups(NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
self.discord_user = DiscordUser.objects.create(
|
||||
user=self.user, uid=TEST_USER_ID
|
||||
)
|
||||
self.guild_roles = [ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE]
|
||||
self.roles_requested = [
|
||||
create_matched_role(ROLE_ALPHA), create_matched_role(ROLE_BRAVO)
|
||||
]
|
||||
user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
self.discord_user = DiscordUser.objects.create(user=user, uid=TEST_USER_ID)
|
||||
|
||||
def test_update_if_needed(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
):
|
||||
roles_current = [1]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
|
||||
mock_DiscordClient.return_value.guild_member.return_value = \
|
||||
{'roles': roles_current}
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
result = self.discord_user.update_groups()
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
|
||||
self.assertEqual(set(kwargs['role_ids']), {1, 2})
|
||||
|
||||
def test_should_update_and_preserve_managed_and_reserved_roles(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
def test_should_update_when_roles_have_changed(
|
||||
self, mock_calculate_roles_for_user, mock_client
|
||||
):
|
||||
# given
|
||||
roles_current = [1, 3, 4, 13]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.return_value = [
|
||||
ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE, ROLE_CHARLIE_2
|
||||
]
|
||||
mock_DiscordClient.return_value.guild_member.return_value = {
|
||||
'roles': roles_current
|
||||
}
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
ReservedGroupName.objects.create(
|
||||
name="charlie", reason="dummy", created_by="xyz"
|
||||
)
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), True
|
||||
mock_client.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = self.discord_user.update_groups()
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
|
||||
self.assertEqual(set(kwargs['role_ids']), {1, 2, 3, 4, 13})
|
||||
self.assertTrue(mock_client.modify_guild_member.called)
|
||||
|
||||
def test_dont_update_if_not_needed(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
def test_should_not_update_when_roles_have_not_changed(
|
||||
self, mock_calculate_roles_for_user, mock_client
|
||||
):
|
||||
roles_current = [1, 2, 13]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
|
||||
mock_DiscordClient.return_value.guild_member.return_value = \
|
||||
{'roles': roles_current}
|
||||
|
||||
# given
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), False
|
||||
mock_client.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = self.discord_user.update_groups()
|
||||
# then
|
||||
self.assertTrue(result)
|
||||
self.assertFalse(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
self.assertFalse(mock_client.modify_guild_member.called)
|
||||
|
||||
def test_update_if_user_has_no_roles_on_discord(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
def test_should_not_update_when_user_not_guild_member(
|
||||
self, mock_calculate_roles_for_user, mock_client
|
||||
):
|
||||
roles_current = []
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
|
||||
mock_DiscordClient.return_value.guild_member.return_value = \
|
||||
{'roles': roles_current}
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
result = self.discord_user.update_groups()
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
|
||||
self.assertEqual(set(kwargs['role_ids']), {1, 2})
|
||||
|
||||
def test_return_none_if_user_no_longer_a_member(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_DiscordClient.return_value.guild_member.return_value = None
|
||||
|
||||
# given
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), None
|
||||
mock_client.modify_guild_member.return_value = True
|
||||
# when
|
||||
result = self.discord_user.update_groups()
|
||||
# then
|
||||
self.assertIsNone(result)
|
||||
self.assertFalse(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
self.assertFalse(mock_client.modify_guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_false(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
def test_should_return_false_when_update_failed(
|
||||
self, mock_calculate_roles_for_user, mock_client
|
||||
):
|
||||
roles_current = [1]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
|
||||
mock_DiscordClient.return_value.guild_member.return_value = \
|
||||
{'roles': roles_current}
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = False
|
||||
|
||||
# given
|
||||
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), True
|
||||
mock_client.modify_guild_member.return_value = False
|
||||
# when
|
||||
result = self.discord_user.update_groups()
|
||||
# then
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_raise_exception_if_member_has_unknown_roles(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
):
|
||||
roles_current = [99]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
|
||||
mock_DiscordClient.return_value.guild_member.return_value = \
|
||||
{'roles': roles_current}
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.discord_user.update_groups()
|
||||
|
||||
def test_refresh_guild_roles_user_roles_dont_not_match(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
):
|
||||
def my_guild_roles(guild_id, use_cache=True):
|
||||
if use_cache:
|
||||
return [ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE]
|
||||
else:
|
||||
return [ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE]
|
||||
|
||||
roles_current = [3]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.side_effect = my_guild_roles
|
||||
mock_DiscordClient.return_value.guild_member.return_value = \
|
||||
{'roles': roles_current}
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
result = self.discord_user.update_groups()
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(mock_DiscordClient.return_value.guild_roles.call_count, 2)
|
||||
|
||||
def test_raise_exception_if_member_info_is_invalid(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_group_names.return_value = []
|
||||
mock_DiscordClient.return_value.match_or_create_roles_from_names\
|
||||
.return_value = self.roles_requested
|
||||
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
|
||||
mock_DiscordClient.return_value.guild_member.return_value = \
|
||||
{'user': 'dummy'}
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.discord_user.update_groups()
|
||||
self.assertTrue(mock_client.modify_guild_member.called)
|
||||
|
||||
@@ -3,18 +3,18 @@ from unittest.mock import MagicMock, patch
|
||||
from celery.exceptions import Retry
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import Group
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from . import TEST_USER_NAME, TEST_USER_ID, TEST_MAIN_NAME, TEST_MAIN_ID
|
||||
from ..models import DiscordUser
|
||||
from ..discord_client import DiscordApiBackoff
|
||||
from .. import tasks
|
||||
from ..discord_client import DiscordApiBackoff
|
||||
from ..discord_client.tests.factories import TEST_USER_ID, TEST_USER_NAME
|
||||
from ..models import DiscordUser
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
from . import TEST_MAIN_ID, TEST_MAIN_NAME
|
||||
|
||||
MODULE_PATH = 'allianceauth.services.modules.discord.tasks'
|
||||
logger = set_logger_to_file(MODULE_PATH, __file__)
|
||||
@@ -22,7 +22,7 @@ logger = set_logger_to_file(MODULE_PATH, __file__)
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.update_groups')
|
||||
@patch(MODULE_PATH + ".logger")
|
||||
class TestUpdateGroups(TestCase):
|
||||
class TestUpdateGroups(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -110,7 +110,7 @@ class TestUpdateGroups(TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.update_nickname')
|
||||
class TestUpdateNickname(TestCase):
|
||||
class TestUpdateNickname(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -163,7 +163,7 @@ class TestUpdateNickname(TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.update_username')
|
||||
class TestUpdateUsername(TestCase):
|
||||
class TestUpdateUsername(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -179,7 +179,7 @@ class TestUpdateUsername(TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.delete_user')
|
||||
class TestDeleteUser(TestCase):
|
||||
class TestDeleteUser(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -213,7 +213,7 @@ class TestDeleteUser(TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.update_groups')
|
||||
class TestTaskPerformUserAction(TestCase):
|
||||
class TestTaskPerformUserAction(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -236,7 +236,7 @@ class TestTaskPerformUserAction(TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.objects.server_name')
|
||||
@patch(MODULE_PATH + ".logger")
|
||||
class TestTaskUpdateServername(TestCase):
|
||||
class TestTaskUpdateServername(NoSocketsTestCase):
|
||||
|
||||
def test_normal(self, mock_logger, mock_server_name):
|
||||
tasks.update_servername()
|
||||
@@ -281,7 +281,7 @@ class TestTaskUpdateServername(TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.objects.server_name')
|
||||
class TestTaskPerformUsersAction(TestCase):
|
||||
class TestTaskPerformUsersAction(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -300,8 +300,8 @@ class TestTaskPerformUsersAction(TestCase):
|
||||
tasks._task_perform_users_action(mock_task, 'server_name')
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
class TestBulkTasks(TestCase):
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
class TestBulkTasks(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from ..utils import clean_setting
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.test import RequestFactory
|
||||
from django.urls import reverse
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.utils.testing import NoSocketsTestCase
|
||||
|
||||
from . import MODULE_PATH, add_permissions_to_members, TEST_USER_NAME, TEST_USER_ID
|
||||
from ..discord_client import DiscordClient
|
||||
from ..discord_client.tests.factories import TEST_USER_ID, TEST_USER_NAME
|
||||
from ..models import DiscordUser
|
||||
from ..utils import set_logger_to_file
|
||||
from ..views import (
|
||||
discord_callback,
|
||||
reset_discord,
|
||||
activate_discord,
|
||||
deactivate_discord,
|
||||
discord_add_bot,
|
||||
activate_discord
|
||||
discord_callback,
|
||||
reset_discord,
|
||||
)
|
||||
|
||||
from . import MODULE_PATH, add_permissions_to_members
|
||||
|
||||
logger = set_logger_to_file(MODULE_PATH + '.views', __file__)
|
||||
|
||||
|
||||
class SetupClassMixin(TestCase):
|
||||
class SetupClassMixin(NoSocketsTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -33,7 +33,7 @@ class SetupClassMixin(TestCase):
|
||||
cls.services_url = reverse('services:services')
|
||||
|
||||
|
||||
class TestActivateDiscord(SetupClassMixin, TestCase):
|
||||
class TestActivateDiscord(SetupClassMixin, NoSocketsTestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.views.DiscordUser.objects.generate_oauth_redirect_url')
|
||||
def test_redirects_to_correct_url(self, mock_generate_oauth_redirect_url):
|
||||
@@ -47,31 +47,37 @@ class TestActivateDiscord(SetupClassMixin, TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestDeactivateDiscord(SetupClassMixin, TestCase):
|
||||
@patch(MODULE_PATH + '.models.create_bot_client')
|
||||
class TestDeactivateDiscord(SetupClassMixin, NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_when_successful_show_success_message(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
self, mock_create_bot_client, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = True
|
||||
request = self.factory.get(reverse('discord:deactivate'))
|
||||
request.user = self.user
|
||||
# when
|
||||
response = deactivate_discord(request)
|
||||
# then
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertTrue(mock_messages.success.called)
|
||||
self.assertFalse(mock_messages.error.called)
|
||||
|
||||
def test_when_unsuccessful_show_error_message(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
self, mock_create_bot_client, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = False
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = False
|
||||
request = self.factory.get(reverse('discord:deactivate'))
|
||||
request.user = self.user
|
||||
# when
|
||||
response = deactivate_discord(request)
|
||||
# then
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertFalse(mock_messages.success.called)
|
||||
@@ -79,30 +85,36 @@ class TestDeactivateDiscord(SetupClassMixin, TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient')
|
||||
class TestResetDiscord(SetupClassMixin, TestCase):
|
||||
@patch(MODULE_PATH + '.models.create_bot_client')
|
||||
class TestResetDiscord(SetupClassMixin, NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_when_successful_redirect_to_activate(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
self, mock_create_bot_client, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = True
|
||||
request = self.factory.get(reverse('discord:reset'))
|
||||
request.user = self.user
|
||||
# when
|
||||
response = reset_discord(request)
|
||||
# then
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, reverse("discord:activate"))
|
||||
self.assertFalse(mock_messages.error.called)
|
||||
|
||||
def test_when_unsuccessful_message_error_and_redirect_to_service(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
self, mock_create_bot_client, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = False
|
||||
# given
|
||||
mock_create_bot_client.return_value.remove_guild_member.return_value = False
|
||||
request = self.factory.get(reverse('discord:reset'))
|
||||
request.user = self.user
|
||||
# when
|
||||
response = reset_discord(request)
|
||||
# then
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertTrue(mock_messages.error.called)
|
||||
@@ -110,7 +122,7 @@ class TestResetDiscord(SetupClassMixin, TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.views.DiscordUser.objects.add_user')
|
||||
class TestDiscordCallback(SetupClassMixin, TestCase):
|
||||
class TestDiscordCallback(SetupClassMixin, NoSocketsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
@@ -155,7 +167,7 @@ class TestDiscordCallback(SetupClassMixin, TestCase):
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.views.DiscordUser.objects.generate_bot_add_url')
|
||||
class TestDiscordAddBot(TestCase):
|
||||
class TestDiscordAddBot(NoSocketsTestCase):
|
||||
|
||||
def test_add_bot(self, mock_generate_bot_add_url):
|
||||
bot_url = 'https://www.example.com/bot'
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from django.contrib import admin
|
||||
|
||||
from .models import AuthTS, Teamspeak3User, StateGroup
|
||||
from django.contrib.auth.models import Group
|
||||
from .models import AuthTS, Teamspeak3User, StateGroup, TSgroup
|
||||
from ...admin import ServicesUserAdmin
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
|
||||
|
||||
@admin.register(Teamspeak3User)
|
||||
@@ -25,6 +26,16 @@ class AuthTSgroupAdmin(admin.ModelAdmin):
|
||||
fields = ('auth_group', 'ts_group')
|
||||
filter_horizontal = ('ts_group',)
|
||||
|
||||
def formfield_for_foreignkey(self, db_field, request, **kwargs):
|
||||
if db_field.name == 'auth_group':
|
||||
kwargs['queryset'] = Group.objects.exclude(name__in=ReservedGroupName.objects.values_list('name', flat=True))
|
||||
return super().formfield_for_foreignkey(db_field, request, **kwargs)
|
||||
|
||||
def formfield_for_manytomany(self, db_field, request, **kwargs):
|
||||
if db_field.name == 'ts_group':
|
||||
kwargs['queryset'] = TSgroup.objects.exclude(ts_group_name__in=ReservedGroupName.objects.values_list('name', flat=True))
|
||||
return super().formfield_for_manytomany(db_field, request, **kwargs)
|
||||
|
||||
def _ts_group(self, obj):
|
||||
return [x for x in obj.ts_group.all().order_by('ts_group_id')]
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from django.conf import settings
|
||||
|
||||
from .util.ts3 import TS3Server, TeamspeakError
|
||||
from .models import TSgroup
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -156,32 +157,25 @@ class Teamspeak3Manager:
|
||||
logger.info(f"Removed user id {uid} from group id {groupid} on TS3 server.")
|
||||
|
||||
def _sync_ts_group_db(self):
|
||||
logger.debug("_sync_ts_group_db function called.")
|
||||
try:
|
||||
remote_groups = self._group_list()
|
||||
local_groups = TSgroup.objects.all()
|
||||
logger.debug("Comparing remote groups to TSgroup objects: %s" % local_groups)
|
||||
for key in remote_groups:
|
||||
logger.debug(f"Typecasting remote_group value at position {key} to int: {remote_groups[key]}")
|
||||
remote_groups[key] = int(remote_groups[key])
|
||||
managed_groups = {g:int(remote_groups[g]) for g in remote_groups if g in set(remote_groups.keys()) - set(ReservedGroupName.objects.values_list('name', flat=True))}
|
||||
remove = TSgroup.objects.exclude(ts_group_id__in=managed_groups.values())
|
||||
|
||||
if remove:
|
||||
logger.debug(f"Deleting {remove.count()} TSgroup models: not found on server, or reserved name.")
|
||||
remove.delete()
|
||||
|
||||
add = {g:managed_groups[g] for g in managed_groups if managed_groups[g] in set(managed_groups.values()) - set(TSgroup.objects.values_list("ts_group_id", flat=True))}
|
||||
if add:
|
||||
logger.debug(f"Adding {len(add)} new TSgroup models.")
|
||||
models = [TSgroup(ts_group_name=name, ts_group_id=add[name]) for name in add]
|
||||
TSgroup.objects.bulk_create(models)
|
||||
|
||||
for group in local_groups:
|
||||
logger.debug("Checking local group %s" % group)
|
||||
if group.ts_group_id not in remote_groups.values():
|
||||
logger.debug(
|
||||
f"Local group id {group.ts_group_id} not found on server. Deleting model {group}")
|
||||
TSgroup.objects.filter(ts_group_id=group.ts_group_id).delete()
|
||||
for key in remote_groups:
|
||||
g = TSgroup(ts_group_id=remote_groups[key], ts_group_name=key)
|
||||
q = TSgroup.objects.filter(ts_group_id=g.ts_group_id)
|
||||
if not q:
|
||||
logger.debug("Local group does not exist for TS group {}. Creating TSgroup model {}".format(
|
||||
remote_groups[key], g))
|
||||
g.save()
|
||||
except TeamspeakError as e:
|
||||
logger.error("Error occured while syncing TS group db: %s" % str(e))
|
||||
except:
|
||||
logger.exception("An unhandled exception has occured while syncing TS groups.")
|
||||
logger.error(f"Error occurred while syncing TS group db: {str(e)}")
|
||||
except Exception:
|
||||
logger.exception(f"An unhandled exception has occurred while syncing TS groups.")
|
||||
|
||||
def add_user(self, user, fmt_name):
|
||||
username_clean = self.__santatize_username(fmt_name[:30])
|
||||
@@ -240,7 +234,7 @@ class Teamspeak3Manager:
|
||||
logger.exception(f"Failed to delete user id {uid} from TS3 - received response {ret}")
|
||||
return False
|
||||
else:
|
||||
logger.warn("User with id %s not found on TS3 server. Assuming succesful deletion." % uid)
|
||||
logger.warning("User with id %s not found on TS3 server. Assuming succesful deletion." % uid)
|
||||
return True
|
||||
|
||||
def check_user_exists(self, uid):
|
||||
@@ -270,7 +264,8 @@ class Teamspeak3Manager:
|
||||
addgroups.append(ts_groups[ts_group_key])
|
||||
for user_ts_group_key in user_ts_groups:
|
||||
if user_ts_groups[user_ts_group_key] not in ts_groups.values():
|
||||
remgroups.append(user_ts_groups[user_ts_group_key])
|
||||
if not ReservedGroupName.objects.filter(name=user_ts_group_key).exists():
|
||||
remgroups.append(user_ts_groups[user_ts_group_key])
|
||||
|
||||
for g in addgroups:
|
||||
logger.info(f"Adding Teamspeak user {userid} into group {g}")
|
||||
|
||||
@@ -5,16 +5,18 @@ from django import urls
|
||||
from django.contrib.auth.models import User, Group, Permission
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db.models import signals
|
||||
from django.contrib.admin import AdminSite
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from .auth_hooks import Teamspeak3Service
|
||||
from .models import Teamspeak3User, AuthTS, TSgroup, StateGroup
|
||||
from .tasks import Teamspeak3Tasks
|
||||
from .signals import m2m_changed_authts_group, post_save_authts, post_delete_authts
|
||||
from .admin import AuthTSgroupAdmin
|
||||
|
||||
from .manager import Teamspeak3Manager
|
||||
from .util.ts3 import TeamspeakError
|
||||
from allianceauth.authentication.models import State
|
||||
from allianceauth.groupmanagement.models import ReservedGroupName
|
||||
|
||||
MODULE_PATH = 'allianceauth.services.modules.teamspeak3'
|
||||
DEFAULT_AUTH_GROUP = 'Member'
|
||||
@@ -315,6 +317,9 @@ class Teamspeak3SignalsTestCase(TestCase):
|
||||
|
||||
|
||||
class Teamspeak3ManagerTestCase(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.reserved = ReservedGroupName.objects.create(name='reserved', reason='tests', created_by='Bob, praise be!')
|
||||
|
||||
@staticmethod
|
||||
def my_side_effect(*args, **kwargs):
|
||||
@@ -334,8 +339,135 @@ class Teamspeak3ManagerTestCase(TestCase):
|
||||
manager._server = server
|
||||
|
||||
# create test data
|
||||
user = User.objects.create_user("dummy")
|
||||
user.profile.state = State.objects.filter(name="Member").first()
|
||||
user = AuthUtils.create_user("dummy")
|
||||
AuthUtils.assign_state(user, AuthUtils.get_member_state())
|
||||
|
||||
# perform test
|
||||
manager.add_user(user, "Dummy User")
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_get_userid')
|
||||
@mock.patch.object(Teamspeak3Manager, '_user_group_list')
|
||||
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
|
||||
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
|
||||
def test_update_groups_add(self, remove, add, groups, userid):
|
||||
"""Add to one group"""
|
||||
userid.return_value = 1
|
||||
groups.return_value = {'test': 1}
|
||||
|
||||
Teamspeak3Manager().update_groups(1, {'test': 1, 'dummy': 2})
|
||||
self.assertEqual(add.call_count, 1)
|
||||
self.assertEqual(remove.call_count, 0)
|
||||
self.assertEqual(add.call_args[0][1], 2)
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_get_userid')
|
||||
@mock.patch.object(Teamspeak3Manager, '_user_group_list')
|
||||
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
|
||||
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
|
||||
def test_update_groups_remove(self, remove, add, groups, userid):
|
||||
"""Remove from one group"""
|
||||
userid.return_value = 1
|
||||
groups.return_value = {'test': '1', 'dummy': '2'}
|
||||
|
||||
Teamspeak3Manager().update_groups(1, {'test': 1})
|
||||
self.assertEqual(add.call_count, 0)
|
||||
self.assertEqual(remove.call_count, 1)
|
||||
self.assertEqual(remove.call_args[0][1], 2)
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_get_userid')
|
||||
@mock.patch.object(Teamspeak3Manager, '_user_group_list')
|
||||
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
|
||||
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
|
||||
def test_update_groups_remove_reserved(self, remove, add, groups, userid):
|
||||
"""Remove from one group, but do not touch reserved group"""
|
||||
userid.return_value = 1
|
||||
groups.return_value = {'test': 1, 'dummy': 2, self.reserved.name: 3}
|
||||
|
||||
Teamspeak3Manager().update_groups(1, {'test': 1})
|
||||
self.assertEqual(add.call_count, 0)
|
||||
self.assertEqual(remove.call_count, 1)
|
||||
self.assertEqual(remove.call_args[0][1], 2)
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_group_list')
|
||||
def test_sync_group_db_create(self, group_list):
|
||||
"""Populate the list of all TSgroups"""
|
||||
group_list.return_value = {'allowed':'1', 'also allowed':'2'}
|
||||
Teamspeak3Manager()._sync_ts_group_db()
|
||||
self.assertEqual(TSgroup.objects.all().count(), 2)
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_group_list')
|
||||
def test_sync_group_db_delete(self, group_list):
|
||||
"""Populate the list of all TSgroups, and delete one which no longer exists"""
|
||||
TSgroup.objects.create(ts_group_name='deleted', ts_group_id=3)
|
||||
group_list.return_value = {'allowed': '1'}
|
||||
Teamspeak3Manager()._sync_ts_group_db()
|
||||
self.assertEqual(TSgroup.objects.all().count(), 1)
|
||||
self.assertFalse(TSgroup.objects.filter(ts_group_name='deleted').exists())
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_group_list')
|
||||
def test_sync_group_db_dont_create_reserved(self, group_list):
|
||||
"""Populate the list of all TSgroups, ignoring a reserved group name"""
|
||||
group_list.return_value = {'allowed': '1', 'reserved': '4'}
|
||||
Teamspeak3Manager()._sync_ts_group_db()
|
||||
self.assertEqual(TSgroup.objects.all().count(), 1)
|
||||
self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists())
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_group_list')
|
||||
def test_sync_group_db_delete_reserved(self, group_list):
|
||||
"""Populate the list of all TSgroups, deleting the TSgroup model for one which has become reserved"""
|
||||
TSgroup.objects.create(ts_group_name='reserved', ts_group_id=4)
|
||||
group_list.return_value = {'allowed': '1', 'reserved': '4'}
|
||||
Teamspeak3Manager()._sync_ts_group_db()
|
||||
self.assertEqual(TSgroup.objects.all().count(), 1)
|
||||
self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists())
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_group_list')
|
||||
def test_sync_group_db_partial_addition(self, group_list):
|
||||
"""Some TSgroups already exist in database, add new ones"""
|
||||
TSgroup.objects.create(ts_group_name='allowed', ts_group_id=1)
|
||||
group_list.return_value = {'allowed': '1', 'also allowed': '2'}
|
||||
Teamspeak3Manager()._sync_ts_group_db()
|
||||
self.assertEqual(TSgroup.objects.all().count(), 2)
|
||||
|
||||
@mock.patch.object(Teamspeak3Manager, '_group_list')
|
||||
def test_sync_group_db_partial_removal(self, group_list):
|
||||
"""One TSgroup has been deleted on server, so remove its model"""
|
||||
TSgroup.objects.create(ts_group_name='allowed', ts_group_id=1)
|
||||
TSgroup.objects.create(ts_group_name='also allowed', ts_group_id=2)
|
||||
group_list.return_value = {'allowed': '1'}
|
||||
Teamspeak3Manager()._sync_ts_group_db()
|
||||
self.assertEqual(TSgroup.objects.all().count(), 1)
|
||||
|
||||
|
||||
class MockRequest:
|
||||
pass
|
||||
|
||||
|
||||
class MockSuperUser:
|
||||
def has_perm(self, perm, obj=None):
|
||||
return True
|
||||
|
||||
|
||||
request = MockRequest()
|
||||
request.user = MockSuperUser()
|
||||
|
||||
|
||||
class Teamspeak3AdminTestCase(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.site = AdminSite()
|
||||
cls.admin = AuthTSgroupAdmin(AuthTS, cls.site)
|
||||
cls.group = Group.objects.create(name='test')
|
||||
cls.ts_group = TSgroup.objects.create(ts_group_name='test')
|
||||
|
||||
def test_field_queryset_no_reserved_names(self):
|
||||
"""Ensure all groups are listed when no reserved names"""
|
||||
form = self.admin.get_form(request)
|
||||
self.assertQuerysetEqual(form.base_fields['auth_group']._get_queryset(), Group.objects.all())
|
||||
self.assertQuerysetEqual(form.base_fields['ts_group']._get_queryset(), TSgroup.objects.all())
|
||||
|
||||
def test_field_queryset_reserved_names(self):
|
||||
"""Ensure reserved group names are filtered out"""
|
||||
ReservedGroupName.objects.bulk_create([ReservedGroupName(name='test', reason='tests', created_by='Bob')])
|
||||
form = self.admin.get_form(request)
|
||||
self.assertQuerysetEqual(form.base_fields['auth_group']._get_queryset(), Group.objects.none())
|
||||
self.assertQuerysetEqual(form.base_fields['ts_group']._get_queryset(), TSgroup.objects.none())
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
from django.contrib.auth.models import User, Group, Permission
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
@@ -8,7 +9,7 @@ from django.db.models.signals import pre_delete
|
||||
from django.db.models.signals import pre_save
|
||||
from django.dispatch import receiver
|
||||
from .hooks import ServicesHook
|
||||
from .tasks import disable_user
|
||||
from .tasks import disable_user, update_groups_for_user
|
||||
|
||||
from allianceauth.authentication.models import State, UserProfile
|
||||
from allianceauth.authentication.signals import state_changed
|
||||
@@ -19,21 +20,27 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@receiver(m2m_changed, sender=User.groups.through)
|
||||
def m2m_changed_user_groups(sender, instance, action, *args, **kwargs):
|
||||
logger.debug(f"Received m2m_changed from {instance} groups with action {action}")
|
||||
|
||||
def trigger_service_group_update():
|
||||
logger.debug("Triggering service group update for %s" % instance)
|
||||
# Iterate through Service hooks
|
||||
for svc in ServicesHook.get_services():
|
||||
try:
|
||||
svc.validate_user(instance)
|
||||
svc.update_groups(instance)
|
||||
except:
|
||||
logger.exception(f'Exception running update_groups for services module {svc} on user {instance}')
|
||||
|
||||
if instance.pk and (action == "post_add" or action == "post_remove" or action == "post_clear"):
|
||||
logger.debug("Waiting for commit to trigger service group update for %s" % instance)
|
||||
transaction.on_commit(trigger_service_group_update)
|
||||
logger.debug(
|
||||
"%s: Received m2m_changed from groups with action %s", instance, action
|
||||
)
|
||||
if instance.pk and (
|
||||
action == "post_add" or action == "post_remove" or action == "post_clear"
|
||||
):
|
||||
if isinstance(instance, User):
|
||||
logger.debug(
|
||||
"Waiting for commit to trigger service group update for %s", instance
|
||||
)
|
||||
transaction.on_commit(partial(update_groups_for_user.delay, instance.pk))
|
||||
elif (
|
||||
isinstance(instance, Group)
|
||||
and kwargs.get("model") is User
|
||||
and "pk_set" in kwargs
|
||||
):
|
||||
for user_pk in kwargs["pk_set"]:
|
||||
logger.debug(
|
||||
"%s: Waiting for commit to trigger service group update for user", user_pk
|
||||
)
|
||||
transaction.on_commit(partial(update_groups_for_user.delay, user_pk))
|
||||
|
||||
|
||||
@receiver(m2m_changed, sender=User.user_permissions.through)
|
||||
|
||||
@@ -47,3 +47,20 @@ def disable_user(user):
|
||||
for svc in ServicesHook.get_services():
|
||||
if svc.service_active_for_user(user):
|
||||
svc.delete_user(user)
|
||||
|
||||
|
||||
@shared_task
|
||||
def update_groups_for_user(user_pk: int) -> None:
|
||||
"""Update groups for all services registered to a user."""
|
||||
user = User.objects.get(pk=user_pk)
|
||||
logger.debug("%s: Triggering service group update for user", user)
|
||||
for svc in ServicesHook.get_services():
|
||||
try:
|
||||
svc.validate_user(user)
|
||||
svc.update_groups(user)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
'Exception running update_groups for services module %s on user %s',
|
||||
svc,
|
||||
user
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from copy import deepcopy
|
||||
from unittest import mock
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test import override_settings, TestCase, TransactionTestCase
|
||||
from django.contrib.auth.models import Group, Permission
|
||||
|
||||
from allianceauth.authentication.models import State
|
||||
@@ -9,6 +9,9 @@ from allianceauth.eveonline.models import EveCharacter
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
|
||||
MODULE_PATH = 'allianceauth.services.signals'
|
||||
|
||||
|
||||
class ServicesSignalsTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.member = AuthUtils.create_user('auth_member', disconnect_signals=True)
|
||||
@@ -17,17 +20,12 @@ class ServicesSignalsTestCase(TestCase):
|
||||
)
|
||||
self.none_user = AuthUtils.create_user('none_user', disconnect_signals=True)
|
||||
|
||||
@mock.patch('allianceauth.services.signals.transaction')
|
||||
@mock.patch('allianceauth.services.signals.ServicesHook')
|
||||
def test_m2m_changed_user_groups(self, services_hook, transaction):
|
||||
@mock.patch(MODULE_PATH + '.transaction', spec=True)
|
||||
@mock.patch(MODULE_PATH + '.update_groups_for_user', spec=True)
|
||||
def test_m2m_changed_user_groups(self, update_groups_for_user, transaction):
|
||||
"""
|
||||
Test that update_groups hook function is called on user groups change
|
||||
"""
|
||||
svc = mock.Mock()
|
||||
svc.update_groups.return_value = None
|
||||
svc.validate_user.return_value = None
|
||||
|
||||
services_hook.get_services.return_value = [svc]
|
||||
|
||||
# Overload transaction.on_commit so everything happens synchronously
|
||||
transaction.on_commit = lambda fn: fn()
|
||||
@@ -39,17 +37,11 @@ class ServicesSignalsTestCase(TestCase):
|
||||
self.member.save()
|
||||
|
||||
# Assert
|
||||
self.assertTrue(services_hook.get_services.called)
|
||||
self.assertTrue(update_groups_for_user.delay.called)
|
||||
args, _ = update_groups_for_user.delay.call_args
|
||||
self.assertEqual(self.member.pk, args[0])
|
||||
|
||||
self.assertTrue(svc.update_groups.called)
|
||||
args, kwargs = svc.update_groups.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
self.assertTrue(svc.validate_user.called)
|
||||
args, kwargs = svc.validate_user.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.disable_user')
|
||||
@mock.patch(MODULE_PATH + '.disable_user')
|
||||
def test_pre_delete_user(self, disable_user):
|
||||
"""
|
||||
Test that disable_member is called when a user is deleted
|
||||
@@ -60,7 +52,7 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = disable_user.call_args
|
||||
self.assertEqual(self.none_user, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.disable_user')
|
||||
@mock.patch(MODULE_PATH + '.disable_user')
|
||||
def test_pre_save_user_inactivation(self, disable_user):
|
||||
"""
|
||||
Test a user set inactive has disable_member called
|
||||
@@ -72,7 +64,7 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = disable_user.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.disable_user')
|
||||
@mock.patch(MODULE_PATH + '.disable_user')
|
||||
def test_disable_services_on_loss_of_main_character(self, disable_user):
|
||||
"""
|
||||
Test a user set inactive has disable_member called
|
||||
@@ -84,8 +76,8 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = disable_user.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.transaction')
|
||||
@mock.patch('allianceauth.services.signals.ServicesHook')
|
||||
@mock.patch(MODULE_PATH + '.transaction')
|
||||
@mock.patch(MODULE_PATH + '.ServicesHook')
|
||||
def test_m2m_changed_group_permissions(self, services_hook, transaction):
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
svc = mock.Mock()
|
||||
@@ -116,8 +108,8 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = svc.validate_user.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.transaction')
|
||||
@mock.patch('allianceauth.services.signals.ServicesHook')
|
||||
@mock.patch(MODULE_PATH + '.transaction')
|
||||
@mock.patch(MODULE_PATH + '.ServicesHook')
|
||||
def test_m2m_changed_user_permissions(self, services_hook, transaction):
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
svc = mock.Mock()
|
||||
@@ -145,8 +137,8 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = svc.validate_user.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.transaction')
|
||||
@mock.patch('allianceauth.services.signals.ServicesHook')
|
||||
@mock.patch(MODULE_PATH + '.transaction')
|
||||
@mock.patch(MODULE_PATH + '.ServicesHook')
|
||||
def test_m2m_changed_user_state_permissions(self, services_hook, transaction):
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
svc = mock.Mock()
|
||||
@@ -180,7 +172,7 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = svc.validate_user.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.ServicesHook')
|
||||
@mock.patch(MODULE_PATH + '.ServicesHook')
|
||||
def test_state_changed_services_validation_and_groups_update(self, services_hook):
|
||||
"""Test a user changing state has service accounts validated and groups updated
|
||||
"""
|
||||
@@ -206,8 +198,7 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = svc.update_groups.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
|
||||
@mock.patch('allianceauth.services.signals.ServicesHook')
|
||||
@mock.patch(MODULE_PATH + '.ServicesHook')
|
||||
def test_state_changed_services_validation_and_groups_update_1(self, services_hook):
|
||||
"""Test a user changing main has service accounts validated and sync updated
|
||||
"""
|
||||
@@ -238,7 +229,7 @@ class ServicesSignalsTestCase(TestCase):
|
||||
args, kwargs = svc.sync_nickname.call_args
|
||||
self.assertEqual(self.member, args[0])
|
||||
|
||||
@mock.patch('allianceauth.services.signals.ServicesHook')
|
||||
@mock.patch(MODULE_PATH + '.ServicesHook')
|
||||
def test_state_changed_services_validation_and_groups_update_2(self, services_hook):
|
||||
"""Test a user changing main has service does not have accounts validated
|
||||
and sync updated if the new main is equal to the old main
|
||||
@@ -260,3 +251,71 @@ class ServicesSignalsTestCase(TestCase):
|
||||
self.assertFalse(services_hook.get_services.called)
|
||||
self.assertFalse(svc.validate_user.called)
|
||||
self.assertFalse(svc.sync_nickname.called)
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"allianceauth.services.modules.mumble.auth_hooks.MumbleService.update_groups"
|
||||
)
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
class TestUserGroupBulkUpdate(TransactionTestCase):
|
||||
def test_should_run_user_service_check_when_group_added_to_user(
|
||||
self, mock_update_groups
|
||||
):
|
||||
# given
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
|
||||
group = Group.objects.create(name="Group")
|
||||
mock_update_groups.reset_mock()
|
||||
# when
|
||||
user.groups.add(group)
|
||||
# then
|
||||
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
|
||||
self.assertSetEqual(users_updated, {user})
|
||||
|
||||
def test_should_run_user_service_check_when_multiple_groups_are_added_to_user(
|
||||
self, mock_update_groups
|
||||
):
|
||||
# given
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
|
||||
group_1 = Group.objects.create(name="Group 1")
|
||||
group_2 = Group.objects.create(name="Group 2")
|
||||
mock_update_groups.reset_mock()
|
||||
# when
|
||||
user.groups.add(group_1, group_2)
|
||||
# then
|
||||
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
|
||||
self.assertSetEqual(users_updated, {user})
|
||||
|
||||
def test_should_run_user_service_check_when_user_added_to_group(
|
||||
self, mock_update_groups
|
||||
):
|
||||
# given
|
||||
user = AuthUtils.create_user("Bruce Wayne")
|
||||
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
|
||||
group = Group.objects.create(name="Group")
|
||||
mock_update_groups.reset_mock()
|
||||
# when
|
||||
group.user_set.add(user)
|
||||
# then
|
||||
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
|
||||
self.assertSetEqual(users_updated, {user})
|
||||
|
||||
def test_should_run_user_service_check_when_multiple_users_are_added_to_group(
|
||||
self, mock_update_groups
|
||||
):
|
||||
# given
|
||||
user_1 = AuthUtils.create_user("Bruce Wayne")
|
||||
AuthUtils.add_main_character_2(user_1, "Bruce Wayne", 1001)
|
||||
user_2 = AuthUtils.create_user("Peter Parker")
|
||||
AuthUtils.add_main_character_2(user_2, "Peter Parker", 1002)
|
||||
user_3 = AuthUtils.create_user("Lex Luthor")
|
||||
AuthUtils.add_main_character_2(user_3, "Lex Luthor", 1011)
|
||||
group = Group.objects.create(name="Group")
|
||||
user_1.groups.add(group)
|
||||
mock_update_groups.reset_mock()
|
||||
# when
|
||||
group.user_set.add(user_2, user_3)
|
||||
# then
|
||||
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
|
||||
self.assertSetEqual(users_updated, {user_2, user_3})
|
||||
|
||||
@@ -3,32 +3,50 @@ from unittest import mock
|
||||
from celery_once import AlreadyQueued
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.test import TestCase
|
||||
from django.test import override_settings, TestCase
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
from allianceauth.services.tasks import validate_services
|
||||
from allianceauth.services.tasks import validate_services, update_groups_for_user
|
||||
|
||||
from ..tasks import DjangoBackend
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
|
||||
class ServicesTasksTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.member = AuthUtils.create_user('auth_member')
|
||||
|
||||
@mock.patch('allianceauth.services.tasks.ServicesHook')
|
||||
def test_validate_services(self, services_hook):
|
||||
# given
|
||||
svc = mock.Mock()
|
||||
svc.validate_user.return_value = None
|
||||
|
||||
services_hook.get_services.return_value = [svc]
|
||||
|
||||
# when
|
||||
validate_services.delay(self.member.pk)
|
||||
|
||||
# then
|
||||
self.assertTrue(services_hook.get_services.called)
|
||||
self.assertTrue(svc.validate_user.called)
|
||||
args, kwargs = svc.validate_user.call_args
|
||||
args, _ = svc.validate_user.call_args
|
||||
self.assertEqual(self.member, args[0]) # Assert correct user is passed to service hook function
|
||||
|
||||
@mock.patch('allianceauth.services.tasks.ServicesHook')
|
||||
def test_update_groups_for_user(self, services_hook):
|
||||
# given
|
||||
svc = mock.Mock()
|
||||
svc.validate_user.return_value = None
|
||||
services_hook.get_services.return_value = [svc]
|
||||
# when
|
||||
update_groups_for_user.delay(self.member.pk)
|
||||
# then
|
||||
self.assertTrue(services_hook.get_services.called)
|
||||
self.assertTrue(svc.validate_user.called)
|
||||
args, _ = svc.validate_user.call_args
|
||||
self.assertEqual(self.member, args[0]) # Assert correct user
|
||||
self.assertTrue(svc.update_groups.called)
|
||||
args, _ = svc.update_groups.call_args
|
||||
self.assertEqual(self.member, args[0]) # Assert correct user
|
||||
|
||||
|
||||
class TestDjangoBackend(TestCase):
|
||||
|
||||
|
||||
@@ -267,7 +267,9 @@ ESC to cancel{% endblocktrans %}"id="blah"></i></th>
|
||||
"targets": [4, 5],
|
||||
"type": "num"
|
||||
}
|
||||
]
|
||||
],
|
||||
"stateSave": true,
|
||||
"stateDuration": 0
|
||||
});
|
||||
|
||||
// tooltip
|
||||
|
||||
@@ -95,6 +95,11 @@ ul.list-group.list-group-horizontal > li.list-group-item {
|
||||
.table-aa > tbody > tr:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.task-status-progress-bar {
|
||||
font-size: 15px!important;
|
||||
line-height: normal!important;
|
||||
}
|
||||
}
|
||||
|
||||
/* highlight active menu items
|
||||
|
||||
@@ -1,58 +1,20 @@
|
||||
$(document).ready(function () {
|
||||
'use strict';
|
||||
|
||||
/**
|
||||
* check time
|
||||
* @param i
|
||||
* @returns {string}
|
||||
*/
|
||||
let checkTime = function (i) {
|
||||
if (i < 10) {
|
||||
i = '0' + i;
|
||||
}
|
||||
|
||||
return i;
|
||||
};
|
||||
|
||||
/**
|
||||
* render a JS clock for Eve Time
|
||||
* @param element
|
||||
* @param utcOffset
|
||||
*/
|
||||
let renderClock = function (element, utcOffset) {
|
||||
let today = new Date();
|
||||
let h = today.getUTCHours();
|
||||
let m = today.getUTCMinutes();
|
||||
|
||||
h = h + utcOffset;
|
||||
|
||||
if (h > 24) {
|
||||
h = h - 24;
|
||||
}
|
||||
|
||||
if (h < 0) {
|
||||
h = h + 24;
|
||||
}
|
||||
|
||||
h = checkTime(h);
|
||||
m = checkTime(m);
|
||||
const renderClock = function (element) {
|
||||
const datetimeNow = new Date();
|
||||
const h = String(datetimeNow.getUTCHours()).padStart(2, '0');
|
||||
const m = String(datetimeNow.getUTCMinutes()).padStart(2, '0');
|
||||
|
||||
element.html(h + ':' + m);
|
||||
|
||||
setTimeout(function () {
|
||||
renderClock(element, 0);
|
||||
}, 500);
|
||||
};
|
||||
|
||||
/**
|
||||
* functions that need to be executed on load
|
||||
*/
|
||||
let init = function () {
|
||||
renderClock($('.eve-time-wrapper .eve-time-clock'), 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* start the show
|
||||
*/
|
||||
init();
|
||||
// Start the Eve time clock in the top menu bar
|
||||
setInterval(function () {
|
||||
renderClock($('.eve-time-wrapper .eve-time-clock'));
|
||||
}, 500);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
{% load humanize %}
|
||||
{% load admin_status %}
|
||||
|
||||
<div
|
||||
class="progress-bar progress-bar-{{ level }} task-status-progress-bar"
|
||||
role="progressbar"
|
||||
aria-valuenow="{% decimal_widthratio tasks_count tasks_total 100 %}"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"
|
||||
style="width: {% decimal_widthratio tasks_count tasks_total 100 %}%;">
|
||||
<p style="margin-top:5px;">{% widthratio tasks_count tasks_total 100 %}%</p>
|
||||
</div>
|
||||
@@ -1,4 +1,6 @@
|
||||
{% load i18n %}
|
||||
{% load humanize %}
|
||||
|
||||
<div class="col-sm-12">
|
||||
<div class="row vertical-flexbox-row2">
|
||||
<div class="col-sm-6">
|
||||
@@ -75,29 +77,25 @@
|
||||
<div class="panel panel-primary" style="height:50%;">
|
||||
<div class="panel-heading text-center"><h3 class="panel-title">{% translate "Task Queue" %}</h3></div>
|
||||
<div class="panel-body flex-center-horizontal">
|
||||
<div class="progress" style="height: 21px;">
|
||||
<div class="progress-bar
|
||||
{% if task_queue_length > 500 %}
|
||||
progress-bar-danger
|
||||
{% elif task_queue_length > 100 %}
|
||||
progress-bar-warning
|
||||
{% else %}
|
||||
progress-bar-success
|
||||
{% endif %}
|
||||
" role="progressbar" aria-valuenow="{% widthratio task_queue_length 500 100 %}"
|
||||
aria-valuemin="0" aria-valuemax="100"
|
||||
style="width: {% widthratio task_queue_length 500 100 %}%;">
|
||||
</div>
|
||||
<p>
|
||||
{% blocktranslate with total=tasks_total|intcomma latest=earliest_task|timesince|default:"?" %}
|
||||
Status of {{ total }} processed tasks • last {{ latest }}
|
||||
{% endblocktranslate %}
|
||||
</p>
|
||||
<div
|
||||
class="progress"
|
||||
style="height: 21px;"
|
||||
title="{{ tasks_succeeded|intcomma }} succeeded, {{ tasks_retried|intcomma }} retried, {{ tasks_failed|intcomma }} failed"
|
||||
>
|
||||
{% include "allianceauth/admin-status/celery_bar_partial.html" with label="suceeded" level="success" tasks_count=tasks_succeeded %}
|
||||
{% include "allianceauth/admin-status/celery_bar_partial.html" with label="retried" level="info" tasks_count=tasks_retried %}
|
||||
{% include "allianceauth/admin-status/celery_bar_partial.html" with label="failed" level="danger" tasks_count=tasks_failed %}
|
||||
</div>
|
||||
{% if task_queue_length < 0 %}
|
||||
{% translate "Error retrieving task queue length" %}
|
||||
{% else %}
|
||||
{% blocktrans trimmed count tasks=task_queue_length %}
|
||||
{{ tasks }} task
|
||||
{% plural %}
|
||||
{{ tasks }} tasks
|
||||
{% endblocktrans %}
|
||||
{% endif %}
|
||||
<p>
|
||||
{% blocktranslate with queue_length=task_queue_length|default_if_none:"?"|intcomma %}
|
||||
{{ queue_length }} queued tasks
|
||||
{% endblocktranslate %}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
3
allianceauth/templates/bundles/evetime-js.html
Normal file
3
allianceauth/templates/bundles/evetime-js.html
Normal file
@@ -0,0 +1,3 @@
|
||||
{% load static %}
|
||||
|
||||
<script src="{% static 'js/eve-time.js' %}"></script>
|
||||
3
allianceauth/templates/bundles/filterdropdown-js.html
Normal file
3
allianceauth/templates/bundles/filterdropdown-js.html
Normal file
@@ -0,0 +1,3 @@
|
||||
{% load static %}
|
||||
|
||||
<script type="application/javascript" src="{% static 'js/filterDropDown/filterDropDown.min.js' %}"></script>
|
||||
@@ -0,0 +1,3 @@
|
||||
{% load static %}
|
||||
|
||||
<script src="{% static 'js/refresh_notifications.js' %}"></script>
|
||||
3
allianceauth/templates/bundles/timers-js.html
Normal file
3
allianceauth/templates/bundles/timers-js.html
Normal file
@@ -0,0 +1,3 @@
|
||||
{% load static %}
|
||||
|
||||
<script src="{% static 'js/timers.js' %}"></script>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user